OpenDAS / Megatron-LM / Commits

Commit 451d1325, authored May 11, 2020 by Neel Kant
Add salient span masking, realm_dataset_utils and significant refactor
Parent: 5511c258
Showing 2 changed files with 210 additions and 381 deletions

megatron/data/realm_dataset.py        +19   -381
megatron/data/realm_dataset_utils.py  +191  -0
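Salient span masking (used in REALM) masks every wordpiece of one named-entity span instead of random subwords. As orientation before the diff, here is a minimal, self-contained sketch of that idea; the token ids, the helper name, and the span below are illustrative only and are not part of this commit.

# Illustrative sketch of salient span masking (toy ids, hypothetical helper; not the repo's API).
MASK_ID = 103  # stand-in for the [MASK] token id

def mask_salient_span(token_ids, span_start, span_end, mask_id=MASK_ID):
    """Replace every token covering one salient span (e.g. a named entity) with [MASK]."""
    masked = list(token_ids)
    positions = list(range(span_start, span_end))
    labels = [token_ids[i] for i in positions]
    for i in positions:
        masked[i] = mask_id
    return masked, positions, labels

# Toy ids for "he finally went to paris"; the entity span covers the last token.
print(mask_salient_span([2009, 4068, 2253, 2000, 3000], 4, 5))
# -> ([2009, 4068, 2253, 2000, 103], [4], [3000])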
megatron/data/realm_dataset.py  (view file @ 451d1325)
import itertools
import os
import random
import time

import numpy as np
# import spacy
import torch
from torch.utils.data import Dataset

from megatron import get_tokenizer, print_rank_0, mpu
from megatron import get_tokenizer
from megatron.data.bert_dataset import BertDataset
from megatron.data.realm_dataset_utils import build_realm_training_sample, get_block_samples_mapping
from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy, is_start_piece
def build_simple_training_sample(sample, target_seq_length, max_seq_length,
                                 vocab_id_list, vocab_id_to_token_dict,
                                 cls_id, sep_id, mask_id, pad_id,
                                 masked_lm_prob, np_rng):
    tokens = list(itertools.chain(*sample))[:max_seq_length - 2]
    tokens, tokentypes = create_single_tokens_and_tokentypes(tokens, cls_id, sep_id)

    max_predictions_per_seq = masked_lm_prob * max_seq_length
    (tokens, masked_positions, masked_labels, _) = create_masked_lm_predictions(
        tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
        cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)

    tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
        = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                                   masked_labels, pad_id, max_seq_length)

    train_sample = {
        'tokens': tokens_np,
        'labels': labels_np,
        'loss_mask': loss_mask_np,
        'pad_mask': padding_mask_np
    }
    return train_sample
# qa_nlp = spacy.load('en_core_web_lg')
def salient_span_mask(tokens, vocab_id_list, vocab_id_to_token_dict,
                      cls_id, sep_id, mask_id, np_rng, do_permutation=False):
    """Creates the predictions for the masked LM objective.
    Note: Tokens here are vocab ids and not text tokens."""
    cand_indexes = []
    # Note(mingdachen): We create a list for recording if the piece is
    # the starting piece of current token, where 1 means true, so that
    # on-the-fly whole word masking is possible.
    token_boundary = [0] * len(tokens)

    for (i, token) in enumerate(tokens):
        if token == cls_id or token == sep_id:
            token_boundary[i] = 1
            continue
        # Whole Word Masking means that if we mask all of the wordpieces
        # corresponding to an original word.
        #
        # Note that Whole Word Masking does *not* change the training code
        # at all -- we still predict each WordPiece independently, softmaxed
        # over the entire vocabulary.
        if len(cand_indexes) >= 1 and not is_start_piece(vocab_id_to_token_dict[token]):
            cand_indexes[-1].append(i)
        else:
            cand_indexes.append([i])
            if is_start_piece(vocab_id_to_token_dict[token]):
                token_boundary[i] = 1

    output_tokens = list(tokens)
    masked_lm_positions = []
    masked_lm_labels = []

    ngram_indexes = []
    for idx in range(len(cand_indexes)):
        ngram_index = []
        for n in ngrams:
            ngram_index.append(cand_indexes[idx:idx + n])
        ngram_indexes.append(ngram_index)

    np_rng.shuffle(ngram_indexes)

    masked_lms = []
    covered_indexes = set()
    for cand_index_set in ngram_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        if not cand_index_set:
            continue
        # Note(mingdachen):
        # Skip current piece if they are covered in lm masking or previous ngrams.
        for index_set in cand_index_set[0]:
            for index in index_set:
                if index in covered_indexes:
                    continue

        n = np_rng.choice(ngrams[:len(cand_index_set)],
                          p=pvals[:len(cand_index_set)] /
                          pvals[:len(cand_index_set)].sum(keepdims=True))
        index_set = sum(cand_index_set[n - 1], [])
        n -= 1
        # Note(mingdachen):
        # Repeatedly looking for a candidate that does not exceed the
        # maximum number of predictions by trying shorter ngrams.
        while len(masked_lms) + len(index_set) > num_to_predict:
            if n == 0:
                break
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1
        # If adding a whole-word mask would exceed the maximum number of
        # predictions, then just skip this candidate.
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        is_any_index_covered = False
        for index in index_set:
            if index in covered_indexes:
                is_any_index_covered = True
                break
        if is_any_index_covered:
            continue
        for index in index_set:
            covered_indexes.add(index)

            masked_token = None
            # 80% of the time, replace with [MASK]
            if np_rng.random() < 0.8:
                masked_token = mask_id
            else:
                # 10% of the time, keep original
                if np_rng.random() < 0.5:
                    masked_token = tokens[index]
                # 10% of the time, replace with random word
                else:
                    masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]

            output_tokens[index] = masked_token
            masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
    assert len(masked_lms) <= num_to_predict

    np_rng.shuffle(ngram_indexes)

    select_indexes = set()
    if do_permutation:
        for cand_index_set in ngram_indexes:
            if len(select_indexes) >= num_to_predict:
                break
            if not cand_index_set:
                continue
            # Note(mingdachen):
            # Skip current piece if they are covered in lm masking or previous ngrams.
            for index_set in cand_index_set[0]:
                for index in index_set:
                    if index in covered_indexes or index in select_indexes:
                        continue

            n = np.random.choice(ngrams[:len(cand_index_set)],
                                 p=pvals[:len(cand_index_set)] /
                                 pvals[:len(cand_index_set)].sum(keepdims=True))
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1

            while len(select_indexes) + len(index_set) > num_to_predict:
                if n == 0:
                    break
                index_set = sum(cand_index_set[n - 1], [])
                n -= 1
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(select_indexes) + len(index_set) > num_to_predict:
                continue
            is_any_index_covered = False
            for index in index_set:
                if index in covered_indexes or index in select_indexes:
                    is_any_index_covered = True
                    break
            if is_any_index_covered:
                continue
            for index in index_set:
                select_indexes.add(index)
        assert len(select_indexes) <= num_to_predict

        select_indexes = sorted(select_indexes)
        permute_indexes = list(select_indexes)
        np_rng.shuffle(permute_indexes)
        orig_token = list(output_tokens)

        for src_i, tgt_i in zip(select_indexes, permute_indexes):
            output_tokens[src_i] = orig_token[tgt_i]
            masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))

    masked_lms = sorted(masked_lms, key=lambda x: x.index)

    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)
    return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
class REALMDataset(Dataset):
    ...
@@ -223,8 +30,10 @@ class REALMDataset(Dataset):
        self.short_seq_prob = short_seq_prob
        self.rng = random.Random(self.seed)

        self.samples_mapping = self.get_samples_mapping(
            data_prefix, num_epochs, max_num_samples)
        self.samples_mapping = get_block_samples_mapping(
            block_dataset, title_dataset, data_prefix, num_epochs,
            max_num_samples, max_seq_length, seed, name)
        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
    ...
@@ -238,120 +47,23 @@ class REALMDataset(Dataset):
    def __getitem__(self, idx):
        start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
        seq_length = self.max_seq_length
        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
        assert len(block) > 1
        np_rng = np.random.RandomState(seed=(self.seed + idx))

        sample = build_simple_training_sample(block, seq_length,
        sample = build_realm_training_sample(block,
                                             self.max_seq_length,
                                             self.vocab_id_list,
                                             self.vocab_id_to_token_list,
                                             self.cls_id,
                                             self.sep_id,
                                             self.mask_id,
                                             self.pad_id,
                                             self.masked_lm_prob,
                                             np_rng)
        sample.update({'query_block_indices': np.array([block_idx]).astype(np.int64)})

        return sample
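# Illustrative aside (not part of this file): how one samples_mapping row is consumed
# by __getitem__ above, with toy in-memory data standing in for the indexed block_dataset.
# All values below are hypothetical.
def _samples_mapping_row_demo():
    import numpy as np
    block_dataset = [[101, 2023], [2003, 1037], [3231, 102]]   # toy "sentences" of token ids
    samples_mapping = np.array([[0, 3, 7, 42]])                # start_idx, end_idx, doc_idx, block_idx
    start_idx, end_idx, doc_idx, block_idx = samples_mapping[0]
    block = [list(block_dataset[i]) for i in range(start_idx, end_idx)]
    # block now holds the sentences that make up one retrieval block, and
    # block_idx (42) is what __getitem__ stores under 'query_block_indices'.
    return block, block_idx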
    def get_samples_mapping(self, data_prefix, num_epochs, max_num_samples):
        if not num_epochs:
            if not max_num_samples:
                raise ValueError("Need to specify either max_num_samples "
                                 "or num_epochs")
            num_epochs = np.iinfo(np.int32).max - 1
        if not max_num_samples:
            max_num_samples = np.iinfo(np.int64).max - 1

        # Filename of the index mapping
        indexmap_filename = data_prefix
        indexmap_filename += '_{}_indexmap'.format(self.name)
        if num_epochs != (np.iinfo(np.int32).max - 1):
            indexmap_filename += '_{}ep'.format(num_epochs)
        if max_num_samples != (np.iinfo(np.int64).max - 1):
            indexmap_filename += '_{}mns'.format(max_num_samples)
        indexmap_filename += '_{}msl'.format(self.max_seq_length)
        indexmap_filename += '_{}s'.format(self.seed)
        indexmap_filename += '.npy'

        # Build the indexed mapping if not exist.
        if torch.distributed.get_rank() == 0 and \
                not os.path.isfile(indexmap_filename):
            print(' > WARNING: could not find index map file {}, building '
                  'the indices on rank 0 ...'.format(indexmap_filename))

            # Make sure the types match the helpers input types.
            assert self.block_dataset.doc_idx.dtype == np.int64
            assert self.block_dataset.sizes.dtype == np.int32

            # Build samples mapping
            verbose = torch.distributed.get_rank() == 0
            start_time = time.time()
            print_rank_0(' > building samples index mapping for {} ...'.format(self.name))
            from megatron.data.dataset_utils import compile_helper
            compile_helper()
            from megatron.data import helpers
            samples_mapping = helpers.build_blocks_mapping(
                self.block_dataset.doc_idx,
                self.block_dataset.sizes,
                self.title_dataset.sizes,
                num_epochs,
                max_num_samples,
                self.max_seq_length - 3,  # account for added tokens
                self.seed,
                verbose)
            print_rank_0(' > done building samples index mapping')
            np.save(indexmap_filename, samples_mapping, allow_pickle=True)
            print_rank_0(' > saved the index mapping in {}'.format(indexmap_filename))
            # Make sure all the ranks have built the mapping
            print_rank_0(' > elapsed time to build and save samples mapping '
                         '(seconds): {:4f}'.format(time.time() - start_time))

        # This should be a barrier but nccl barrier assumes
        # device_index=rank which is not the case for model
        # parallel case
        counts = torch.cuda.LongTensor([1])
        torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
        assert counts[0].item() == torch.distributed.get_world_size(
            group=mpu.get_data_parallel_group())

        # Load indexed dataset.
        print_rank_0(' > loading indexed mapping from {}'.format(indexmap_filename))
        start_time = time.time()
        samples_mapping = np.load(indexmap_filename, allow_pickle=True)
        print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
            time.time() - start_time))
        print_rank_0('    total number of samples: {}'.format(samples_mapping.shape[0]))

        return samples_mapping
def create_single_tokens_and_tokentypes(_tokens, cls_id, sep_id):
    tokens = []
    tokens.append(cls_id)
    tokens.extend(list(_tokens))
    tokens.append(sep_id)
    tokentypes = [0] * len(tokens)
    return tokens, tokentypes
def spacy_ner(block_text):
    candidates = {}
    block = qa_nlp(block_text)
    starts = []
    answers = []
    for ent in block.ents:
        starts.append(int(ent.start_char))
        answers.append(str(ent.text))
    candidates['starts'] = starts
    candidates['answers'] = answers
class ICTDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
    ...
@@ -368,8 +80,9 @@ class ICTDataset(Dataset):
        self.rng = random.Random(self.seed)
        self.use_titles = use_titles

        self.samples_mapping = self.get_samples_mapping(
            data_prefix, num_epochs, max_num_samples)
        self.samples_mapping = get_block_samples_mapping(
            block_dataset, title_dataset, data_prefix, num_epochs,
            max_num_samples, max_seq_length, seed, name)
        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
    ...
@@ -453,78 +166,3 @@ class ICTDataset(Dataset):
        pad_mask = [1] * len(tokens) + [0] * num_pad
        tokens += [self.pad_id] * num_pad

        return tokens, pad_mask
    def get_samples_mapping(self, data_prefix, num_epochs, max_num_samples):
        if not num_epochs:
            if not max_num_samples:
                raise ValueError("Need to specify either max_num_samples "
                                 "or num_epochs")
            num_epochs = np.iinfo(np.int32).max - 1
        if not max_num_samples:
            max_num_samples = np.iinfo(np.int64).max - 1

        # Filename of the index mapping
        indexmap_filename = data_prefix
        indexmap_filename += '_{}_indexmap'.format(self.name)
        if num_epochs != (np.iinfo(np.int32).max - 1):
            indexmap_filename += '_{}ep'.format(num_epochs)
        if max_num_samples != (np.iinfo(np.int64).max - 1):
            indexmap_filename += '_{}mns'.format(max_num_samples)
        indexmap_filename += '_{}msl'.format(self.max_seq_length)
        indexmap_filename += '_{}s'.format(self.seed)
        indexmap_filename += '.npy'

        # Build the indexed mapping if not exist.
        if torch.distributed.get_rank() == 0 and \
                not os.path.isfile(indexmap_filename):
            print(' > WARNING: could not find index map file {}, building '
                  'the indices on rank 0 ...'.format(indexmap_filename))

            # Make sure the types match the helpers input types.
            assert self.block_dataset.doc_idx.dtype == np.int64
            assert self.block_dataset.sizes.dtype == np.int32

            # Build samples mapping
            verbose = torch.distributed.get_rank() == 0
            start_time = time.time()
            print_rank_0(' > building samples index mapping for {} ...'.format(self.name))
            from megatron.data.dataset_utils import compile_helper
            compile_helper()
            from megatron.data import helpers
            samples_mapping = helpers.build_blocks_mapping(
                self.block_dataset.doc_idx,
                self.block_dataset.sizes,
                self.title_dataset.sizes,
                num_epochs,
                max_num_samples,
                self.max_seq_length - 3,  # account for added tokens
                self.seed,
                verbose)
            print_rank_0(' > done building samples index mapping')
            np.save(indexmap_filename, samples_mapping, allow_pickle=True)
            print_rank_0(' > saved the index mapping in {}'.format(indexmap_filename))
            # Make sure all the ranks have built the mapping
            print_rank_0(' > elapsed time to build and save samples mapping '
                         '(seconds): {:4f}'.format(time.time() - start_time))

        # This should be a barrier but nccl barrier assumes
        # device_index=rank which is not the case for model
        # parallel case
        counts = torch.cuda.LongTensor([1])
        torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
        assert counts[0].item() == torch.distributed.get_world_size(
            group=mpu.get_data_parallel_group())

        # Load indexed dataset.
        print_rank_0(' > loading indexed mapping from {}'.format(indexmap_filename))
        start_time = time.time()
        samples_mapping = np.load(indexmap_filename, allow_pickle=True)
        print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
            time.time() - start_time))
        print_rank_0('    total number of samples: {}'.format(samples_mapping.shape[0]))

        return samples_mapping
megatron/data/realm_dataset_utils.py  (view file @ 451d1325)
import itertools
import os
import random
import time

import numpy as np
import spacy
import torch

from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
from megatron import get_tokenizer, print_rank_0, mpu

SPACY_NER = spacy.load('en_core_web_lg')
def build_realm_training_sample(sample, max_seq_length,
                                vocab_id_list, vocab_id_to_token_dict,
                                cls_id, sep_id, mask_id, pad_id,
                                masked_lm_prob, np_rng):
    tokens = list(itertools.chain(*sample))[:max_seq_length - 2]
    tokens, tokentypes = create_single_tokens_and_tokentypes(tokens, cls_id, sep_id)

    try:
        masked_tokens, masked_positions, masked_labels = salient_span_mask(tokens, mask_id)
    except TypeError:
        # this means the above returned None, and None isn't iterable.
        # TODO: consider coding style.
        print("No salient span found.", flush=True)
        max_predictions_per_seq = masked_lm_prob * max_seq_length
        masked_tokens, masked_positions, masked_labels, _ = create_masked_lm_predictions(
            tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
            cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)

    tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
        = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                                   masked_labels, pad_id, max_seq_length)

    train_sample = {
        'tokens': tokens_np,
        'labels': labels_np,
        'loss_mask': loss_mask_np,
        'pad_mask': padding_mask_np
    }
    return train_sample
def create_single_tokens_and_tokentypes(_tokens, cls_id, sep_id):
    tokens = []
    tokens.append(cls_id)
    tokens.extend(list(_tokens))
    tokens.append(sep_id)
    tokentypes = [0] * len(tokens)
    return tokens, tokentypes
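# Usage note (illustrative, hypothetical ids): create_single_tokens_and_tokentypes([7, 8, 9], cls_id=101, sep_id=102)
# returns ([101, 7, 8, 9, 102], [0, 0, 0, 0, 0]) -- [CLS] and [SEP] are added and every token type is 0.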
def join_str_list(str_list):
    """Join a list of strings, handling spaces appropriately"""
    result = ""
    for s in str_list:
        if s.startswith("##"):
            result += s[2:]
        else:
            result += " " + s
    return result
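# Usage note (illustrative tokens, not from the repo): join_str_list(["the", "fox", "jump", "##ed"])
# returns " the fox jumped" -- "##" continuation pieces are glued on, other pieces get a leading space.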
def id_to_str_pos_map(token_ids, tokenizer):
    """Given a list of ids, return a list of integers which correspond to the starting index
    of the corresponding token in the original string (with spaces, without artifacts e.g. ##)"""
    token_strs = tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
    pos_map = [0]
    for i in range(len(token_strs) - 1):
        len_prev = len(token_strs[i])
        # do not add the length of the "##"
        if token_strs[i].startswith("##"):
            len_prev -= 2

        # add the length of the space if needed
        if token_strs[i + 1].startswith("##"):
            pos_map.append(pos_map[-1] + len_prev)
        else:
            pos_map.append(pos_map[-1] + len_prev + 1)

    # make sure total size is correct
    offset = -2 if token_strs[-1].startswith("##") else 0
    total_len = pos_map[-1] + len(token_strs[-1]) + offset
    assert total_len == len(join_str_list(token_strs))
    return pos_map
def salient_span_mask(tokens, mask_id):
    """Creates the predictions for the masked LM objective.
    Note: Tokens here are vocab ids and not text tokens."""
    tokenizer = get_tokenizer()
    tokens_str = join_str_list(tokenizer.tokenize(tokens))

    # need to get all named entities
    entities = SPACY_NER(tokens_str).ents
    if len(entities) == 0:
        return None
    selected_entity = np.random.choice(entities)

    token_pos_map = id_to_str_pos_map(tokens, tokenizer)
    mask_start = mask_end = token_pos_map.index(selected_entity.start_char)
    while mask_end < len(token_pos_map) and token_pos_map[mask_end] < selected_entity.end_char:
        mask_end += 1

    labels = tokens.copy()
    output_tokens = tokens.copy()
    for id_idx in range(mask_start, mask_end):
        output_tokens[id_idx] = mask_id

    return output_tokens, list(range(mask_start, mask_end)), labels
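# Illustrative aside (not part of this file): how a character span from the NER step is
# mapped back to token indices using a position map like the one id_to_str_pos_map returns.
# The pos_map values and entity span below are hypothetical.
def _char_span_to_token_span_demo():
    pos_map = [0, 7, 13, 18, 21]    # starting character of each of 5 tokens
    start_char, end_char = 13, 20   # entity span in the joined string
    mask_start = mask_end = pos_map.index(start_char)
    while mask_end < len(pos_map) and pos_map[mask_end] < end_char:
        mask_end += 1
    return mask_start, mask_end     # -> (2, 4): tokens 2 and 3 fall inside the entity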
def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
                              max_num_samples, max_seq_length, seed, name):
    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{}s'.format(seed)
    indexmap_filename += '.npy'

    # Build the indexed mapping if not exist.
    if torch.distributed.get_rank() == 0 and \
            not os.path.isfile(indexmap_filename):
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))

        # Make sure the types match the helpers input types.
        assert block_dataset.doc_idx.dtype == np.int64
        assert block_dataset.sizes.dtype == np.int32

        # Build samples mapping
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        print_rank_0(' > building samples index mapping for {} ...'.format(name))
        from megatron.data.dataset_utils import compile_helper
        compile_helper()
        from megatron.data import helpers
        samples_mapping = helpers.build_blocks_mapping(
            block_dataset.doc_idx,
            block_dataset.sizes,
            title_dataset.sizes,
            num_epochs,
            max_num_samples,
            max_seq_length - 3,  # account for added tokens
            seed,
            verbose)
        print_rank_0(' > done building samples index mapping')
        np.save(indexmap_filename, samples_mapping, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(indexmap_filename))
        # Make sure all the ranks have built the mapping
        print_rank_0(' > elapsed time to build and save samples mapping '
                     '(seconds): {:4f}'.format(time.time() - start_time))

    # This should be a barrier but nccl barrier assumes
    # device_index=rank which is not the case for model
    # parallel case
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    assert counts[0].item() == torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())

    # Load indexed dataset.
    print_rank_0(' > loading indexed mapping from {}'.format(indexmap_filename))
    start_time = time.time()
    samples_mapping = np.load(indexmap_filename, allow_pickle=True)
    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(samples_mapping.shape[0]))

    return samples_mapping
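The index-map filename built above encodes the sampling parameters, which is what lets a previously built mapping be found on disk and reused instead of rebuilt. A minimal sketch of that naming convention follows; the prefix and parameter values are made up, not taken from this repo.

import numpy as np

def index_map_name(data_prefix, name, num_epochs, max_num_samples, max_seq_length, seed):
    # Mirrors the naming logic in get_block_samples_mapping above.
    filename = data_prefix + '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        filename += '_{}mns'.format(max_num_samples)
    filename += '_{}msl'.format(max_seq_length)
    filename += '_{}s'.format(seed)
    return filename + '.npy'

print(index_map_name('wiki-blocks', 'realm', np.iinfo(np.int32).max - 1, 1000000, 288, 1234))
# -> wiki-blocks_realm_indexmap_1000000mns_288msl_1234s.npy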