OpenDAS / Megatron-LM / Commits

Commit e7045139, authored Jun 24, 2020 by Neel Kant
Parent: b8bb0b49

    Add stanza NER salient span masking

Showing 6 changed files with 122 additions and 15 deletions:

    megatron/arguments.py                  +4   -0
    megatron/data/dataset_utils.py         +7   -1
    megatron/data/realm_dataset.py         +13  -2
    megatron/data/realm_dataset_utils.py   +88  -7
    megatron/tokenizer/tokenizer.py        +5   -5
    megatron/training.py                   +5   -0
megatron/arguments.py

@@ -394,6 +394,10 @@ def _add_data_args(parser):
     group.add_argument('--use-random-spans', action='store_true')
     group.add_argument('--allow-trivial-doc', action='store_true')
     group.add_argument('--ner-data-path', type=str, default=None)
+    group.add_argument('--cased-data-path', type=str, default=None,
+                       help='path to cased data to use for NER salient span masking')
+    group.add_argument('--cased-vocab', type=str, default=None,
+                       help='path to cased vocab file to use for NER salient span masking')
     return parser
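Note: the two new --cased-* flags above are only meaningful as a pair, since the cased block data has to be decoded with the matching cased vocabulary before it can be handed to the NER models. The commit itself does not validate that pairing; a minimal sketch of such a check (a hypothetical helper, not part of this diff) could be:

# Hypothetical sanity check, not part of this commit: stanza NER salient
# span masking needs both the cased data and the cased vocab.
def validate_cased_args(args):
    if (args.cased_data_path is None) != (args.cased_vocab is None):
        raise ValueError('--cased-data-path and --cased-vocab must be '
                         'supplied together for NER salient span masking')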
megatron/data/dataset_utils.py

@@ -387,7 +387,7 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
     num_tokens = len(tokens)
     padding_length = max_seq_length - num_tokens
     assert padding_length >= 0
-    assert len(tokentypes) == num_tokens
+    assert len(tokentypes) == num_tokens, (len(tokentypes), num_tokens)
     assert len(masked_positions) == len(masked_labels), (len(masked_positions), len(masked_labels))

     # Tokens and token types.

@@ -491,6 +491,12 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                            data_impl, skip_warmup)
         kwargs.update({'ner_dataset': ner_dataset})
+    elif args.cased_data_path is not None:
+        cased_dataset = get_indexed_dataset_(args.cased_data_path,
+                                             data_impl, skip_warmup)
+        kwargs.update({'cased_block_dataset': cased_dataset,
+                       'cased_vocab': args.cased_vocab})
     dataset = REALMDataset(
         block_dataset=indexed_dataset,
         title_dataset=title_dataset,
megatron/data/realm_dataset.py

@@ -20,7 +20,7 @@ class REALMDataset(Dataset):
     """
     def __init__(self, name, block_dataset, title_dataset, data_prefix,
                  num_epochs, max_num_samples, masked_lm_prob,
-                 max_seq_length, short_seq_prob, seed, ner_dataset=None):
+                 max_seq_length, short_seq_prob, seed, ner_dataset=None, cased_block_dataset=None, cased_vocab=None):
         self.name = name
         self.seed = seed
         self.max_seq_length = max_seq_length

@@ -29,7 +29,13 @@ class REALMDataset(Dataset):
         self.title_dataset = title_dataset
         self.short_seq_prob = short_seq_prob
         self.rng = random.Random(self.seed)
         self.ner_dataset = ner_dataset
+        self.cased_block_dataset = cased_block_dataset
+        self.cased_tokenizer = None
+        if self.cased_block_dataset is not None:
+            from megatron.tokenizer.tokenizer import BertWordPieceTokenizer
+            self.cased_tokenizer = BertWordPieceTokenizer(vocab_file=cased_vocab,
+                                                          lower_case=False)
         self.samples_mapping = get_block_samples_mapping(
             block_dataset, title_dataset, data_prefix, num_epochs,

@@ -49,7 +55,6 @@
     def __getitem__(self, idx):
         start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
         block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
-        # print([len(list(self.block_dataset[i])) for i in range(start_idx, end_idx)], flush=True)
         assert len(block) > 1
         block_ner_mask = None

@@ -57,6 +62,10 @@
             block_ner_mask = [list(self.ner_dataset[i]) for i in range(start_idx, end_idx)]
             # print([len(list(self.ner_dataset[i])) for i in range(start_idx, end_idx)], flush=True)
+        cased_tokens = None
+        if self.cased_block_dataset is not None:
+            cased_tokens = [list(self.cased_block_dataset[i]) for i in range(start_idx, end_idx)]
+
         np_rng = np.random.RandomState(seed=(self.seed + idx))
         sample = build_realm_training_sample(block,

@@ -69,6 +78,8 @@
                                              self.pad_id,
                                              self.masked_lm_prob,
                                              block_ner_mask,
+                                             cased_tokens,
+                                             self.cased_tokenizer,
                                              np_rng)
         sample.update({'query_block_indices': np.array([block_idx]).astype(np.int64)})
         return sample
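Note: the cased block dataset is consumed only in __getitem__, where cased_tokens is sliced with the same start_idx:end_idx range as the uncased block. Downstream, get_stanza_ner_mask zips the two sentence lists together and transfers character offsets computed on the cased text onto positions in the uncased ids, so the two indexed datasets appear to be assumed token-for-token parallel. A rough offline check of that assumption (hypothetical, not part of the commit):

# Hypothetical offline check: every block in the cased indexed dataset is
# assumed to hold exactly as many wordpieces as its uncased counterpart.
def check_parallel(uncased_dataset, cased_dataset, num_samples=1000):
    for i in range(min(num_samples, len(uncased_dataset))):
        if len(uncased_dataset[i]) != len(cased_dataset[i]):
            print('length mismatch at block', i,
                  len(uncased_dataset[i]), len(cased_dataset[i]))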
megatron/data/realm_dataset_utils.py

@@ -6,6 +6,12 @@ import time
 import numpy as np
 import spacy
 import torch
+try:
+    import stanza
+    processors_dict = {'tokenize': 'default', 'mwt': 'default', 'ner': 'conll03'}
+    stanza_pipeline = stanza.Pipeline('en', processors=processors_dict, use_gpu=True)
+except:
+    pass

 from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
 from megatron import get_args, get_tokenizer, print_rank_0, mpu
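Note: the module-level pipeline above is what get_stanza_ner_mask queries once per sentence; stanza supplies CoNLL-03 entities and spaCy is consulted separately for DATE spans, which CoNLL-03 does not label. A standalone sketch of the output the masking code relies on (entity text plus character offsets), assuming the stanza English models and spaCy's en_core_web_lg are installed (the 'mwt' entry from the diff's processors dict is omitted here for brevity):

import spacy
import stanza

# Assumes stanza's English CoNLL-03 NER model has been downloaded, e.g. via
# the stanza.download call added in megatron/training.py below.
nlp_stanza = stanza.Pipeline('en',
                             processors={'tokenize': 'default', 'ner': 'conll03'},
                             use_gpu=False)
nlp_spacy = spacy.load('en_core_web_lg')

text = 'Neel Kant pushed this change to Megatron-LM on June 24, 2020.'
for ent in nlp_stanza(text).ents:
    print('stanza:', ent.type, repr(ent.text), ent.start_char, ent.end_char)
for ent in nlp_spacy(text).ents:
    if ent.label_ == 'DATE':  # only DATE spans are merged into the entity list
        print('spacy: ', ent.label_, repr(ent.text), ent.start_char, ent.end_char)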
@@ -16,7 +22,8 @@ SPACY_NER = spacy.load('en_core_web_lg')
 def build_realm_training_sample(sample, max_seq_length,
                                 vocab_id_list, vocab_id_to_token_dict,
                                 cls_id, sep_id, mask_id, pad_id,
-                                masked_lm_prob, block_ner_mask, np_rng):
+                                masked_lm_prob, block_ner_mask, cased_tokens,
+                                cased_tokenizer, np_rng):
     tokens = list(itertools.chain(*sample))[:max_seq_length - 2]
     tokens, tokentypes = create_single_tokens_and_tokentypes(tokens, cls_id, sep_id)

@@ -35,8 +42,20 @@ def build_realm_training_sample(sample, max_seq_length,
         masked_tokens, masked_positions, masked_labels = get_arrays_using_ner_mask(
             tokens, block_ner_mask, mask_id)
     else:
         try:
             masked_tokens, masked_positions, masked_labels = salient_span_mask(tokens, mask_id)
+        except TypeError:
+            if args.cased_data_path is not None:
+                total_len = sum(len(l) for l in sample)
+                # truncate the last sentence to make it so that the whole thing has length max_seq_length - 2
+                if total_len > max_seq_length - 2:
+                    offset = -(total_len - (max_seq_length - 2))
+                    sample[-1] = sample[-1][:offset]
+                masked_tokens, masked_positions, masked_labels = get_stanza_ner_mask(
+                    sample, cased_tokens, cased_tokenizer, cls_id, sep_id, mask_id)
+            else:
+                masked_tokens, masked_positions, masked_labels = salient_span_mask(tokens, mask_id)
         except:
             # print("+" * 100, flush=True)
             # print('could not create salient span', flush=True)
             # print("+" * 100, flush=True)
             # this means the above returned None, and None isn't iterable.
             # TODO: consider coding style.
             max_predictions_per_seq = masked_lm_prob * max_seq_length
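Note: the fallback chain above hinges on an implementation detail spelled out in the trailing comment: salient_span_mask returns None when it finds no salient span, so unpacking its result raises TypeError, which is what routes a sample into the stanza path when cased data is configured, while the bare except that follows appears to fall back to standard masked-LM masking (the rest of that branch is collapsed in this view). A tiny self-contained illustration of that pattern, with a stand-in masker:

# Stand-in for salient_span_mask: returns a triple when a span is found
# and None otherwise, so callers can catch TypeError on unpacking.
def maybe_mask(tokens):
    if len(tokens) < 3:
        return None
    return tokens[:1], [0], tokens[:1]

for sample in ([7, 8], [7, 8, 9, 10]):
    try:
        masked_tokens, masked_positions, masked_labels = maybe_mask(sample)
        print('masked positions:', masked_positions)
    except TypeError:
        print('no salient span in', sample, '-> fall back to another masker')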
@@ -57,6 +76,67 @@ def build_realm_training_sample(sample, max_seq_length,
     return train_sample


+def get_stanza_ner_mask(tokens, cased_tokens, cased_tokenizer, cls_id, sep_id, mask_id):
+    """Use stanza to generate NER salient span masks in the loop"""
+    # assuming that the default tokenizer is uncased.
+    uncased_tokenizer = get_tokenizer()
+    block_ner_mask = []
+    for cased_sent_ids, uncased_sent_ids in zip(cased_tokens, tokens):
+        # print('>')
+        token_pos_map = id_to_str_pos_map(uncased_sent_ids, uncased_tokenizer)
+
+        # get the cased string and do NER with both toolkits
+        cased_sent_str = join_str_list(cased_tokenizer.tokenizer.convert_ids_to_tokens(cased_sent_ids))
+        entities = stanza_pipeline(cased_sent_str).ents
+        spacy_entities = SPACY_NER(cased_sent_str).ents
+
+        # CoNLL doesn't do dates, so we scan with spacy to get the dates.
+        entities = [e for e in entities if e.text != 'CLS']
+        entities.extend([e for e in spacy_entities if (e.text != 'CLS' and e.label_ == 'DATE')])
+
+        # randomize which entities to look at, and set a target of 12% of tokens being masked
+        entity_indices = np.arange(len(entities))
+        np.random.shuffle(entity_indices)
+        target_num_masks = int(len(cased_sent_ids) * 0.12)
+
+        masked_positions = []
+        for entity_idx in entity_indices[:3]:
+            # if we have enough masks then break.
+            if len(masked_positions) > target_num_masks:
+                break
+            selected_entity = entities[entity_idx]
+            # print(">> selected entity: {}".format(selected_entity.text), flush=True)
+
+            mask_start = mask_end = 0
+            set_mask_start = False
+            # loop for checking where mask should start and end.
+            while mask_end < len(token_pos_map) and token_pos_map[mask_end] < selected_entity.end_char:
+                if token_pos_map[mask_start] > selected_entity.start_char:
+                    set_mask_start = True
+                if not set_mask_start:
+                    mask_start += 1
+                mask_end += 1
+
+            # add offset to indices since our input was list of sentences
+            masked_positions.extend(range(mask_start - 1, mask_end))
+
+        ner_mask = [0] * len(uncased_sent_ids)
+        for pos in masked_positions:
+            ner_mask[pos] = 1
+        block_ner_mask.extend(ner_mask)
+
+    # len_tokens = [len(l) for l in tokens]
+    # print(len_tokens, flush=True)
+    # print([sum(len_tokens[:i + 1]) for i in range(len(tokens))], flush=True)
+    tokens = list(itertools.chain(*tokens))
+    tokens = [cls_id] + tokens + [sep_id]
+    block_ner_mask = [0] + block_ner_mask + [0]
+    return get_arrays_using_ner_mask(tokens, block_ner_mask, mask_id)
+
+
 def get_arrays_using_ner_mask(tokens, block_ner_mask, mask_id):
     tokenizer = get_tokenizer()
     tokens_str = join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(tokens))
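Note: the core of get_stanza_ner_mask is transferring each entity's character span (start_char, end_char) onto wordpiece positions through token_pos_map, which id_to_str_pos_map presumably builds as the running character offset of each uncased wordpiece in the detokenized sentence. A simplified, self-contained sketch of that span-to-position mapping (whole-word tokens joined by spaces; real WordPiece '##' continuations and the mask_start - 1 offset handling are ignored):

# Simplified illustration: record each token's (start_char, end_char) in the
# space-joined sentence, then mask every token overlapping the entity span.
def char_spans(tokens):
    spans, offset = [], 0
    for tok in tokens:
        spans.append((offset, offset + len(tok)))
        offset += len(tok) + 1  # +1 for the joining space
    return spans

tokens = ['neel', 'kant', 'works', 'at', 'nvidia', 'in', 'california']
text = ' '.join(tokens)
ent_start = text.index('nvidia')            # as if NER returned this span
ent_end = ent_start + len('nvidia')

masked_positions = [i for i, (s, e) in enumerate(char_spans(tokens))
                    if s < ent_end and e > ent_start]
print(masked_positions)  # -> [4], the position of 'nvidia'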
@@ -65,16 +145,17 @@ def get_arrays_using_ner_mask(tokens, block_ner_mask, mask_id):
     masked_positions = []
     masked_labels = []
     for i in range(len(tokens)):
         if block_ner_mask[i] == 1:
             masked_positions.append(i)
             masked_labels.append(tokens[i])
             masked_tokens[i] = mask_id
-    # print("-" * 100 + '\n',
-    #       "TOKEN STR\n", tokens_str + '\n',
-    #       "OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(masked_tokens)), flush=True)
+    # print("\nTOKEN STR\n", tokens_str + '\n',
+    #       "OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(masked_tokens)) + '\n',
+    #       "FRAC_MASKED: {}\n".format(len(masked_labels) / len(tokens)),
+    #       "-" * 100 + '\n',
+    #       flush=True)
     return masked_tokens, masked_positions, masked_labels
megatron/tokenizer/tokenizer.py

@@ -31,11 +31,11 @@ def build_tokenizer(args):
     # Select and instantiate the tokenizer.
     assert args.vocab_file is not None
     if args.tokenizer_type == 'BertWordPieceLowerCase':
-        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
-                                            lower_case=True)
+        tokenizer = BertWordPieceTokenizer(vocab_file=args.vocab_file,
+                                           lower_case=True)
     elif args.tokenizer_type == 'BertWordPieceCase':
-        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
-                                            lower_case=False)
+        tokenizer = BertWordPieceTokenizer(vocab_file=args.vocab_file,
+                                           lower_case=False)
     elif args.tokenizer_type == 'GPT2BPETokenizer':
         assert args.merge_file is not None
         tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)

@@ -124,7 +124,7 @@ class AbstractTokenizer(ABC):
                                   'tokenizer'.format(self.name))


-class _BertWordPieceTokenizer(AbstractTokenizer):
+class BertWordPieceTokenizer(AbstractTokenizer):
     """Original BERT wordpiece tokenizer."""

     def __init__(self, vocab_file, lower_case=True):
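Note: dropping the leading underscore makes the wordpiece tokenizer importable outside this module, which is what lets REALMDataset build a cased tokenizer next to the default uncased one. A minimal usage sketch (the vocab path is a placeholder):

from megatron.tokenizer.tokenizer import BertWordPieceTokenizer

# 'bert-large-cased-vocab.txt' is a placeholder path for a cased BERT vocab.
cased_tokenizer = BertWordPieceTokenizer(vocab_file='bert-large-cased-vocab.txt',
                                         lower_case=False)
token_ids = cased_tokenizer.tokenize('Stanza tags Neel Kant as a PER entity.')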
megatron/training.py

@@ -87,6 +87,11 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
     args = get_args()
     timers = get_timers()

+    if args.rank == 0 and args.cased_data_path is not None:
+        import stanza
+        stanza.download('en', processors={'ner': 'conll03'}, dir='stanza')
+
     # Model, optimizer, and learning rate.
     timers('model and optimizer').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
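Note: downloading the CoNLL-03 model only on rank 0 keeps every worker from fetching it simultaneously; the module-level try/except around stanza.Pipeline in realm_dataset_utils.py then tolerates ranks where the files are not present yet. If the other ranks should instead wait for the download before building their pipelines, the usual pattern is a barrier after the call (a sketch, assuming torch.distributed is already initialized at this point in pretrain, which is not verified here):

import torch
import stanza

# Sketch, not part of this commit: rank 0 fetches the CoNLL-03 NER model and
# every other rank blocks until the download has finished.
if torch.distributed.get_rank() == 0:
    stanza.download('en', processors={'ner': 'conll03'}, dir='stanza')
torch.distributed.barrier()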