ModelZoo / ResNet50_tensorflow · Commits

Commit 73a911c0
Authored Jul 14, 2020 by Jeremiah Harmsen
Committed by A. Unique TensorFlower, Jul 14, 2020

Internal change

PiperOrigin-RevId: 321119040
Parent: ad166183
Showing 1 changed file with 233 additions and 58 deletions.

official/nlp/data/create_pretraining_data.py (+233, -58)
@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function

 import collections
+import itertools
 import random

 from absl import app
@@ -48,6 +49,12 @@ flags.DEFINE_bool(
     "do_whole_word_mask", False,
     "Whether to use whole word masking rather than per-WordPiece masking.")

+flags.DEFINE_integer(
+    "max_ngram_size", None,
+    "Mask contiguous whole words (n-grams) of up to `max_ngram_size` using a "
+    "weighting scheme to favor shorter n-grams. "
+    "Note: `--do_whole_word_mask=True` must also be set when n-gram masking.")
+
 flags.DEFINE_bool(
     "gzip_compress", False,
     "Whether to use `GZIP` compress option to get compressed TFRecord files.")
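For reference, a typical invocation exercising the new flag might look like the sketch below. The file paths are placeholders, and flag names not visible in this diff (--input_file, --vocab_file) are assumed from the BERT version of this script; per the help text, --do_whole_word_mask=True must accompany --max_ngram_size:

    python create_pretraining_data.py \
      --input_file=/tmp/corpus.txt \
      --output_file=/tmp/tf_examples.tfrecord \
      --vocab_file=/tmp/vocab.txt \
      --do_whole_word_mask=True \
      --max_ngram_size=3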
@@ -192,7 +199,8 @@ def create_training_instances(input_files,
                               masked_lm_prob,
                               max_predictions_per_seq,
                               rng,
-                              do_whole_word_mask=False):
+                              do_whole_word_mask=False,
+                              max_ngram_size=None):
   """Create `TrainingInstance`s from raw text."""
   all_documents = [[]]
@@ -229,7 +237,7 @@ def create_training_instances(input_files,
           create_instances_from_document(
               all_documents, document_index, max_seq_length, short_seq_prob,
               masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-              do_whole_word_mask))
+              do_whole_word_mask, max_ngram_size))

   rng.shuffle(instances)
   return instances
@@ -238,7 +246,8 @@ def create_training_instances(input_files,
 def create_instances_from_document(
     all_documents, document_index, max_seq_length, short_seq_prob,
     masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-    do_whole_word_mask=False):
+    do_whole_word_mask=False,
+    max_ngram_size=None):
   """Creates `TrainingInstance`s for a single document."""
   document = all_documents[document_index]
@@ -337,7 +346,7 @@ def create_instances_from_document(
       (tokens, masked_lm_positions,
        masked_lm_labels) = create_masked_lm_predictions(
            tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-           do_whole_word_mask)
+           do_whole_word_mask, max_ngram_size)
       instance = TrainingInstance(
           tokens=tokens,
           segment_ids=segment_ids,
@@ -355,72 +364,238 @@ def create_instances_from_document(
 MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                           ["index", "label"])

+# A _Gram is a [half-open) interval of token indices which form a word.
+# E.g.,
+#   words:  ["The", "doghouse"]
+#   tokens: ["The", "dog", "##house"]
+#   grams:  [(0,1), (1,3)]
+_Gram = collections.namedtuple("_Gram", ["begin", "end"])
+
+def _window(iterable, size):
+  """Helper to create a sliding window iterator with a given size.
+
+  E.g.,
+    input = [1, 2, 3, 4]
+    _window(input, 1) => [1], [2], [3], [4]
+    _window(input, 2) => [1, 2], [2, 3], [3, 4]
+    _window(input, 3) => [1, 2, 3], [2, 3, 4]
+    _window(input, 4) => [1, 2, 3, 4]
+    _window(input, 5) => None
+
+  Arguments:
+    iterable: elements to iterate over.
+    size: size of the window.
+
+  Yields:
+    Elements of `iterable` batched into a sliding window of length `size`.
+  """
+  i = iter(iterable)
+  window = []
+  try:
+    for e in range(0, size):
+      window.append(next(i))
+    yield window
+  except StopIteration:
+    # handle the case where iterable's length is less than the window size.
+    return
+  for e in i:
+    window = window[1:] + [e]
+    yield window
+
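As a quick sanity check (not part of the commit, assuming the module's namespace): the generator yields exactly the windows shown in the docstring, and when the iterable is shorter than `size` it simply yields nothing rather than a literal None:

    assert list(_window([1, 2, 3, 4], 2)) == [[1, 2], [2, 3], [3, 4]]
    assert list(_window([1, 2, 3, 4], 5)) == []  # input shorter than window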
+def _contiguous(sorted_grams):
+  """Test whether a sequence of grams is contiguous.
+
+  Arguments:
+    sorted_grams: _Grams which are sorted in increasing order.
+
+  Returns:
+    True if `sorted_grams` are touching each other.
+
+  E.g.,
+    _contiguous([(1, 4), (4, 5), (5, 10)]) == True
+    _contiguous([(1, 2), (4, 5)]) == False
+  """
+  for a, b in _window(sorted_grams, 2):
+    if a.end != b.begin:
+      return False
+  return True
+
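A brief illustration (again, not part of the commit). The arguments must be _Gram namedtuples, since the function reads their `begin` and `end` fields; the docstring's bare tuples are shorthand:

    assert _contiguous([_Gram(1, 4), _Gram(4, 5), _Gram(5, 10)])
    assert not _contiguous([_Gram(1, 2), _Gram(4, 5)])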
+def _masking_ngrams(grams, max_ngram_size, max_masked_tokens, rng):
+  """Create a list of masking {1, ..., n}-grams from a list of one-grams.
+
+  This is an extension of 'whole word masking' to mask multiple, contiguous
+  words (e.g., "the red boat").
+
+  Each input gram represents the token indices of a single word,
+     words:  ["the", "red", "boat"]
+     tokens: ["the", "red", "boa", "##t"]
+     grams:  [(0,1), (1,2), (2,4)]
+
+  For a `max_ngram_size` of three, possible output masks include:
+    1-grams: (0,1), (1,2), (2,4)
+    2-grams: (0,2), (1,4)
+    3-grams: (0,4)
+
+  Output masks will not overlap and contain no more than `max_masked_tokens`
+  total tokens. E.g., for the example above with `max_masked_tokens` as three,
+  valid outputs are,
+       [(0,1), (1,2)]  # "the", "red" covering two tokens
+       [(1,2), (2,4)]  # "red", "boa", "##t" covering three tokens
+
+  The length of the selected n-gram follows a zipf weighting to
+  favor shorter n-gram sizes (weight(1)=1, weight(2)=1/2, weight(3)=1/3, ...).
+
+  Arguments:
+    grams: List of one-grams.
+    max_ngram_size: Maximum number of contiguous one-grams combined to create
+      an n-gram.
+    max_masked_tokens: Maximum total number of tokens to be masked.
+    rng: `random.Random` generator.
+
+  Returns:
+    A list of n-grams to be used as masks.
+  """
+  if not grams:
+    return None
+
+  grams = sorted(grams)
+  num_tokens = grams[-1].end
+
+  # Ensure our grams are valid (i.e., they don't overlap).
+  for a, b in _window(grams, 2):
+    if a.end > b.begin:
+      raise ValueError("overlapping grams: {}".format(grams))
+
+  # Build map from n-gram length to list of n-grams.
+  ngrams = {i: [] for i in range(1, max_ngram_size + 1)}
+  for gram_size in range(1, max_ngram_size + 1):
+    for g in _window(grams, gram_size):
+      if _contiguous(g):
+        # Add an n-gram which spans these one-grams.
+        ngrams[gram_size].append(_Gram(g[0].begin, g[-1].end))
+
+  # Shuffle each list of n-grams.
+  for v in ngrams.values():
+    rng.shuffle(v)
+
+  # Create the weighting for n-gram length selection.
+  # Stored cumulatively for `random.choices` below.
+  cummulative_weights = list(
+      itertools.accumulate([1. / n for n in range(1, max_ngram_size + 1)]))
+
+  output_ngrams = []
+  # Keep a bitmask of which tokens have been masked.
+  masked_tokens = [False] * num_tokens
+  # Loop until we have enough masked tokens or there are no more candidate
+  # n-grams of any length.
+  # Each code path should ensure one or more elements from `ngrams` are
+  # removed to guarantee this loop terminates.
+  while (sum(masked_tokens) < max_masked_tokens and
+         sum(len(s) for s in ngrams.values())):
+    # Pick an n-gram size based on our weights.
+    sz = random.choices(range(1, max_ngram_size + 1),
+                        cum_weights=cummulative_weights)[0]
+
+    # Ensure this size doesn't result in too many masked tokens.
+    # E.g., a two-gram contains _at least_ two tokens.
+    if sum(masked_tokens) + sz > max_masked_tokens:
+      # All n-grams of this length are too long and can be removed from
+      # consideration.
+      ngrams[sz].clear()
+      continue
+
+    # All of the n-grams of this size have been used.
+    if not ngrams[sz]:
+      continue
+
+    # Choose a random n-gram of the given size.
+    gram = ngrams[sz].pop()
+    num_gram_tokens = gram.end - gram.begin
+
+    # Check if this would add too many tokens.
+    if num_gram_tokens + sum(masked_tokens) > max_masked_tokens:
+      continue
+
+    # Check if any of the tokens in this gram have already been masked.
+    if sum(masked_tokens[gram.begin:gram.end]):
+      continue
+
+    # Found a usable n-gram! Mark its tokens as masked and add it to return.
+    masked_tokens[gram.begin:gram.end] = [True] * (gram.end - gram.begin)
+    output_ngrams.append(gram)
+  return output_ngrams
+
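To make the zipf weighting concrete: for max_ngram_size=3 the weights are 1, 1/2, and 1/3, so cummulative_weights is [1.0, 1.5, 1.833...] and the selection probabilities are roughly 0.55, 0.27, and 0.18 for 1-, 2-, and 3-grams. Below is a usage sketch with invented toy grams; note that because the n-gram size is drawn via the module-level random.choices rather than rng, exact output varies from run to run even with a seeded rng:

    rng = random.Random(12345)
    grams = [_Gram(0, 1), _Gram(1, 2), _Gram(2, 4)]  # "the", "red", "boa ##t"
    masks = _masking_ngrams(grams, max_ngram_size=3, max_masked_tokens=3,
                            rng=rng)
    # e.g., [_Gram(begin=1, end=2), _Gram(begin=2, end=4)] -- non-overlapping
    # spans covering at most three tokens in total.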
+def _wordpieces_to_grams(tokens):
+  """Reconstitute grams (words) from `tokens`.
+
+  E.g.,
+     tokens: ['[CLS]', 'That', 'lit', '##tle', 'blue', 'tru', '##ck', '[SEP]']
+     grams:  [          [1,2), [2,         4), [4,5) , [5,       7)]
+
+  Arguments:
+    tokens: list of wordpieces
+
+  Returns:
+    List of _Grams representing spans of whole words
+    (without "[CLS]" and "[SEP]").
+  """
+  grams = []
+  gram_start_pos = None
+  for i, token in enumerate(tokens):
+    if gram_start_pos is not None and token.startswith("##"):
+      continue
+    if gram_start_pos is not None:
+      grams.append(_Gram(gram_start_pos, i))
+    if token not in ["[CLS]", "[SEP]"]:
+      gram_start_pos = i
+    else:
+      gram_start_pos = None
+  if gram_start_pos is not None:
+    grams.append(_Gram(gram_start_pos, len(tokens)))
+  return grams
+
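A small check (not part of the commit) confirming the docstring's example; the final word "tru ##ck" spans the half-open interval [5, 7):

    tokens = ["[CLS]", "That", "lit", "##tle", "blue", "tru", "##ck", "[SEP]"]
    assert _wordpieces_to_grams(tokens) == [
        _Gram(1, 2), _Gram(2, 4), _Gram(4, 5), _Gram(5, 7)]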
-def create_masked_lm_predictions(tokens, masked_lm_prob,
-                                 max_predictions_per_seq, vocab_words, rng,
-                                 do_whole_word_mask):
-  """Creates the predictions for the masked LM objective."""
-
-  cand_indexes = []
-  for (i, token) in enumerate(tokens):
-    if token == "[CLS]" or token == "[SEP]":
-      continue
-    # Whole Word Masking means that if we mask all of the wordpieces
-    # corresponding to an original word. When a word has been split into
-    # WordPieces, the first token does not have any marker and any subsequence
-    # tokens are prefixed with ##. So whenever we see the ## token, we
-    # append it to the previous set of word indexes.
-    #
-    # Note that Whole Word Masking does *not* change the training code
-    # at all -- we still predict each WordPiece independently, softmaxed
-    # over the entire vocabulary.
-    if (do_whole_word_mask and len(cand_indexes) >= 1 and
-        token.startswith("##")):
-      cand_indexes[-1].append(i)
-    else:
-      cand_indexes.append([i])
-
-  rng.shuffle(cand_indexes)
-
-  output_tokens = list(tokens)
-
-  num_to_predict = min(max_predictions_per_seq,
-                       max(1, int(round(len(tokens) * masked_lm_prob))))
-
-  masked_lms = []
-  covered_indexes = set()
-  for index_set in cand_indexes:
-    if len(masked_lms) >= num_to_predict:
-      break
-    # If adding a whole-word mask would exceed the maximum number of
-    # predictions, then just skip this candidate.
-    if len(masked_lms) + len(index_set) > num_to_predict:
-      continue
-    is_any_index_covered = False
-    for index in index_set:
-      if index in covered_indexes:
-        is_any_index_covered = True
-        break
-    if is_any_index_covered:
-      continue
-    for index in index_set:
-      covered_indexes.add(index)
-
-      masked_token = None
-      # 80% of the time, replace with [MASK]
-      if rng.random() < 0.8:
-        masked_token = "[MASK]"
-      else:
-        # 10% of the time, keep original
-        if rng.random() < 0.5:
-          masked_token = tokens[index]
-        # 10% of the time, replace with random word
-        else:
-          masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
-
-      output_tokens[index] = masked_token
-
-      masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
+def create_masked_lm_predictions(tokens, masked_lm_prob,
+                                 max_predictions_per_seq, vocab_words, rng,
+                                 do_whole_word_mask,
+                                 max_ngram_size=None):
+  """Creates the predictions for the masked LM objective."""
+  if do_whole_word_mask:
+    grams = _wordpieces_to_grams(tokens)
+  else:
+    # Here we consider each token to be a word to allow for sub-word masking.
+    if max_ngram_size:
+      raise ValueError("cannot use ngram masking without whole word masking")
+    grams = [_Gram(i, i + 1) for i in range(0, len(tokens))
+             if tokens[i] not in ["[CLS]", "[SEP]"]]
+
+  num_to_predict = min(max_predictions_per_seq,
+                       max(1, int(round(len(tokens) * masked_lm_prob))))
+  # Generate masks. If `max_ngram_size` in [0, None] it means we're doing
+  # whole word masking or token level masking. Both of these can be treated
+  # as the `max_ngram_size=1` case.
+  masked_grams = _masking_ngrams(grams, max_ngram_size or 1,
+                                 num_to_predict, rng)
+  masked_lms = []
+  output_tokens = list(tokens)
+  for gram in masked_grams:
+    # 80% of the time, replace all n-gram tokens with [MASK]
+    if rng.random() < 0.8:
+      replacement_action = lambda idx: "[MASK]"
+    else:
+      # 10% of the time, keep all the original n-gram tokens.
+      if rng.random() < 0.5:
+        replacement_action = lambda idx: tokens[idx]
+      # 10% of the time, replace each n-gram token with a random word.
+      else:
+        replacement_action = lambda idx: rng.choice(vocab_words)
+
+    for idx in range(gram.begin, gram.end):
+      output_tokens[idx] = replacement_action(idx)
+      masked_lms.append(MaskedLmInstance(index=idx, label=tokens[idx]))

   assert len(masked_lms) <= num_to_predict
   masked_lms = sorted(masked_lms, key=lambda x: x.index)
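Putting it together, a minimal sketch of the new entry point; the toy tokens, vocabulary, and seed are invented for illustration, and exact masks vary between runs for the reason noted above. The three-tuple return shape matches the call site in create_instances_from_document:

    rng = random.Random(12345)
    tokens = ["[CLS]", "the", "red", "boa", "##t", "[SEP]"]
    vocab_words = ["the", "red", "boat", "dog", "house"]
    (output_tokens, masked_lm_positions,
     masked_lm_labels) = create_masked_lm_predictions(
         tokens, masked_lm_prob=0.5, max_predictions_per_seq=2,
         vocab_words=vocab_words, rng=rng, do_whole_word_mask=True,
         max_ngram_size=2)
    # e.g., output_tokens == ["[CLS]", "[MASK]", "[MASK]", "boa", "##t", "[SEP]"]
    # with masked_lm_positions == [1, 2] and masked_lm_labels == ["the", "red"]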
@@ -467,7 +642,7 @@ def main(_):
   instances = create_training_instances(
       input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
       FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
-      rng, FLAGS.do_whole_word_mask)
+      rng, FLAGS.do_whole_word_mask, FLAGS.max_ngram_size)

   output_files = FLAGS.output_file.split(",")
   logging.info("*** Writing to output files ***")