OpenDAS / Megatron-LM / Commits / 840759b8

Commit 840759b8, authored Apr 03, 2020 by Neel Kant
Parent: 63262827

    Lint megatron/data/dataset_utils.py

Showing 1 changed file with 172 additions and 171 deletions.

megatron/data/dataset_utils.py  (+172, -171)
...

@@ -132,6 +132,7 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):

            tokens.pop()
    return True


def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    """Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""
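The docstring above describes the standard BERT packing: [CLS], segment A, [SEP], segment B, [SEP], with token type 0 over the first segment and 1 over the second. A minimal sketch of that convention, assuming exactly this layout (the repo's own implementation is not shown in this hunk):

```python
def merge_segments_sketch(tokens_a, tokens_b, cls_id, sep_id):
    """Sketch of the [CLS] A [SEP] B [SEP] packing the docstring describes."""
    tokens = [cls_id] + tokens_a + [sep_id] + tokens_b + [sep_id]
    # Token type 0 covers [CLS] A [SEP]; token type 1 covers B [SEP].
    tokentypes = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    return tokens, tokentypes

# Hypothetical ids: 101 = [CLS] and 102 = [SEP] in the standard BERT vocab.
tokens, tokentypes = merge_segments_sketch([7, 8], [9], cls_id=101, sep_id=102)
assert tokens == [101, 7, 8, 102, 9, 102]
assert tokentypes == [0, 0, 0, 0, 1, 1]
```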
@@ -163,12 +164,12 @@ MaskedLmInstance = collections.namedtuple("MaskedLmInstance",

def is_start_piece(piece):
    """Check if the current word piece is the starting piece (BERT)."""
    # When a word has been split into
    # WordPieces, the first token does not have any marker and any subsequence
    # tokens are prefixed with ##. So whenever we see the ## token, we
    # append it to the previous set of word indexes.
    return not piece.startswith("##")


def create_masked_lm_predictions(tokens,
...
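Since is_start_piece only tests for the ## continuation prefix, grouping WordPieces back into whole words is straightforward. A quick illustration (the tokenization of "unaffable" is the classic WordPiece example, not taken from this file):

```python
def is_start_piece(piece):
    # A piece starts a new word unless it carries the ## continuation marker.
    return not piece.startswith("##")

# "unaffable" tokenizes to ["un", "##aff", "##able"]: only the first piece
# is a word start, so whole-word masking can treat all three as one unit.
pieces = ["the", "un", "##aff", "##able", "cat"]
assert [is_start_piece(p) for p in pieces] == [True, True, False, False, True]
```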
@@ -181,178 +182,178 @@ def create_masked_lm_predictions(tokens,

                                 do_whole_word_mask=True,
                                 favor_longer_ngram=False,
                                 do_permutation=False):
    """Creates the predictions for the masked LM objective.
    Note: Tokens here are vocab ids and not text tokens."""

    cand_indexes = []
    # Note(mingdachen): We create a list for recording if the piece is
    # the starting piece of current token, where 1 means true, so that
    # on-the-fly whole word masking is possible.
    token_boundary = [0] * len(tokens)

    for (i, token) in enumerate(tokens):
        if token == cls_id or token == sep_id:
            token_boundary[i] = 1
            continue
        # Whole Word Masking means that if we mask all of the wordpieces
        # corresponding to an original word.
        #
        # Note that Whole Word Masking does *not* change the training code
        # at all -- we still predict each WordPiece independently, softmaxed
        # over the entire vocabulary.
        if (do_whole_word_mask and len(cand_indexes) >= 1 and
                not is_start_piece(vocab_id_to_token_dict[token])):
            cand_indexes[-1].append(i)
        else:
            cand_indexes.append([i])
            if is_start_piece(vocab_id_to_token_dict[token]):
                token_boundary[i] = 1

    output_tokens = list(tokens)

    masked_lm_positions = []
    masked_lm_labels = []

    if masked_lm_prob == 0:
        return (output_tokens, masked_lm_positions,
                masked_lm_labels, token_boundary)

    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))

    # Note(mingdachen):
    # By default, we set the probilities to favor shorter ngram sequences.
    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
    pvals = 1. / np.arange(1, max_ngrams + 1)
    pvals /= pvals.sum(keepdims=True)

    if favor_longer_ngram:
        pvals = pvals[::-1]

    ngram_indexes = []
    for idx in range(len(cand_indexes)):
        ngram_index = []
        for n in ngrams:
            ngram_index.append(cand_indexes[idx:idx + n])
        ngram_indexes.append(ngram_index)

    np_rng.shuffle(ngram_indexes)

    masked_lms = []
    covered_indexes = set()
    for cand_index_set in ngram_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        if not cand_index_set:
            continue
        # Note(mingdachen):
        # Skip current piece if they are covered in lm masking or previous ngrams.
        for index_set in cand_index_set[0]:
            for index in index_set:
                if index in covered_indexes:
                    continue

        n = np_rng.choice(ngrams[:len(cand_index_set)],
                          p=pvals[:len(cand_index_set)] /
                          pvals[:len(cand_index_set)].sum(keepdims=True))
        index_set = sum(cand_index_set[n - 1], [])
        n -= 1
        # Note(mingdachen):
        # Repeatedly looking for a candidate that does not exceed the
        # maximum number of predictions by trying shorter ngrams.
        while len(masked_lms) + len(index_set) > num_to_predict:
            if n == 0:
                break
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1

        # If adding a whole-word mask would exceed the maximum number of
        # predictions, then just skip this candidate.
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        is_any_index_covered = False
        for index in index_set:
            if index in covered_indexes:
                is_any_index_covered = True
                break
        if is_any_index_covered:
            continue
        for index in index_set:
            covered_indexes.add(index)

            masked_token = None
            # 80% of the time, replace with [MASK]
            if np_rng.random() < 0.8:
                masked_token = mask_id
            else:
                # 10% of the time, keep original
                if np_rng.random() < 0.5:
                    masked_token = tokens[index]
                # 10% of the time, replace with random word
                else:
                    masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]

            output_tokens[index] = masked_token

            masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
    assert len(masked_lms) <= num_to_predict

    np_rng.shuffle(ngram_indexes)

    select_indexes = set()
    if do_permutation:
        for cand_index_set in ngram_indexes:
            if len(select_indexes) >= num_to_predict:
                break
            if not cand_index_set:
                continue
            # Note(mingdachen):
            # Skip current piece if they are covered in lm masking or previous ngrams.
            for index_set in cand_index_set[0]:
                for index in index_set:
                    if index in covered_indexes or index in select_indexes:
                        continue

            n = np.random.choice(ngrams[:len(cand_index_set)],
                                 p=pvals[:len(cand_index_set)] /
                                 pvals[:len(cand_index_set)].sum(keepdims=True))
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1

            while len(select_indexes) + len(index_set) > num_to_predict:
                if n == 0:
                    break
                index_set = sum(cand_index_set[n - 1], [])
                n -= 1
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(select_indexes) + len(index_set) > num_to_predict:
                continue
            is_any_index_covered = False
            for index in index_set:
                if index in covered_indexes or index in select_indexes:
                    is_any_index_covered = True
                    break
            if is_any_index_covered:
                continue
            for index in index_set:
                select_indexes.add(index)
        assert len(select_indexes) <= num_to_predict

        select_indexes = sorted(select_indexes)
        permute_indexes = list(select_indexes)
        np_rng.shuffle(permute_indexes)
        orig_token = list(output_tokens)

        for src_i, tgt_i in zip(select_indexes, permute_indexes):
            output_tokens[src_i] = orig_token[tgt_i]
            masked_lms.append(MaskedLmInstance(index=src_i,
                                               label=orig_token[src_i]))

    masked_lms = sorted(masked_lms, key=lambda x: x.index)

    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)

    return (output_tokens, masked_lm_positions,
            masked_lm_labels, token_boundary)


def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
...
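The n-gram weighting in this hunk gives a span of length n the (normalized) weight 1/n, so unigrams are sampled most often unless favor_longer_ngram flips the array. A standalone check of that distribution, assuming max_ngrams = 3 (the value itself is a caller-supplied parameter, not fixed here):

```python
import numpy as np

max_ngrams = 3
ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)  # array([1, 2, 3])
pvals = 1. / np.arange(1, max_ngrams + 1)              # [1, 1/2, 1/3]
pvals /= pvals.sum(keepdims=True)                      # normalize to sum to 1

# Unigrams are drawn three times as often as trigrams:
# pvals ~= [0.545, 0.273, 0.182]
assert np.isclose(pvals.sum(), 1.0)
assert np.isclose(pvals[0] / pvals[2], 3.0)

# favor_longer_ngram simply reverses the weights:
pvals_long = pvals[::-1]                               # ~[0.182, 0.273, 0.545]
```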
@@ -367,12 +368,12 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,

    assert len(masked_positions) == len(masked_labels)

    # Tokens and token types.
    filler = [pad_id] * padding_length
    tokens_np = np.array(tokens + filler, dtype=np.int64)
    tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

    # Padding mask.
    padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
                               dtype=np.int64)

    # Lables and loss mask.
...
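The padding step shown above right-pads both streams with pad_id and builds a 1/0 mask over real versus padded positions. A minimal standalone reproduction (pad_id, the sequence, and max_seq_length are made-up values; num_tokens and padding_length are computed earlier in the real function):

```python
import numpy as np

tokens = [101, 7, 8, 102]        # four real token ids (hypothetical)
tokentypes = [0, 0, 0, 0]
pad_id = 0
max_seq_length = 8

num_tokens = len(tokens)
padding_length = max_seq_length - num_tokens

# Right-pad tokens and token types to the full sequence length.
filler = [pad_id] * padding_length
tokens_np = np.array(tokens + filler, dtype=np.int64)
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)

# Padding mask: 1 over real tokens, 0 over padding.
padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
                           dtype=np.int64)
assert padding_mask_np.sum() == num_tokens
assert tokens_np.shape == (max_seq_length,)
```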