Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
8bf2b3be
Commit
8bf2b3be
authored
Apr 02, 2020
by
A. Unique TensorFlower
Browse files
Merge pull request #8355 from stagedml:move-flags
PiperOrigin-RevId: 304544459
parents
f2dcb8d4
64067980
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
10 deletions
+16
-10
official/nlp/data/create_pretraining_data.py
official/nlp/data/create_pretraining_data.py
+16
-10
No files found.
official/nlp/data/create_pretraining_data.py
View file @
8bf2b3be
...
@@ -100,13 +100,14 @@ class TrainingInstance(object):
...
@@ -100,13 +100,14 @@ class TrainingInstance(object):
def
write_instance_to_example_files
(
instances
,
tokenizer
,
max_seq_length
,
def
write_instance_to_example_files
(
instances
,
tokenizer
,
max_seq_length
,
max_predictions_per_seq
,
output_files
):
max_predictions_per_seq
,
output_files
,
gzip_compress
):
"""Create TF example files from `TrainingInstance`s."""
"""Create TF example files from `TrainingInstance`s."""
writers
=
[]
writers
=
[]
for
output_file
in
output_files
:
for
output_file
in
output_files
:
writers
.
append
(
writers
.
append
(
tf
.
io
.
TFRecordWriter
(
tf
.
io
.
TFRecordWriter
(
output_file
,
options
=
"GZIP"
if
FLAGS
.
gzip_compress
else
""
))
output_file
,
options
=
"GZIP"
if
gzip_compress
else
""
))
writer_index
=
0
writer_index
=
0
...
@@ -185,7 +186,7 @@ def create_float_feature(values):
...
@@ -185,7 +186,7 @@ def create_float_feature(values):
def
create_training_instances
(
input_files
,
tokenizer
,
max_seq_length
,
def
create_training_instances
(
input_files
,
tokenizer
,
max_seq_length
,
dupe_factor
,
short_seq_prob
,
masked_lm_prob
,
dupe_factor
,
short_seq_prob
,
masked_lm_prob
,
max_predictions_per_seq
,
rng
):
max_predictions_per_seq
,
rng
,
do_whole_word_mask
):
"""Create `TrainingInstance`s from raw text."""
"""Create `TrainingInstance`s from raw text."""
all_documents
=
[[]]
all_documents
=
[[]]
...
@@ -221,7 +222,8 @@ def create_training_instances(input_files, tokenizer, max_seq_length,
...
@@ -221,7 +222,8 @@ def create_training_instances(input_files, tokenizer, max_seq_length,
instances
.
extend
(
instances
.
extend
(
create_instances_from_document
(
create_instances_from_document
(
all_documents
,
document_index
,
max_seq_length
,
short_seq_prob
,
all_documents
,
document_index
,
max_seq_length
,
short_seq_prob
,
masked_lm_prob
,
max_predictions_per_seq
,
vocab_words
,
rng
))
masked_lm_prob
,
max_predictions_per_seq
,
vocab_words
,
rng
,
do_whole_word_mask
))
rng
.
shuffle
(
instances
)
rng
.
shuffle
(
instances
)
return
instances
return
instances
...
@@ -229,7 +231,8 @@ def create_training_instances(input_files, tokenizer, max_seq_length,
...
@@ -229,7 +231,8 @@ def create_training_instances(input_files, tokenizer, max_seq_length,
def
create_instances_from_document
(
def
create_instances_from_document
(
all_documents
,
document_index
,
max_seq_length
,
short_seq_prob
,
all_documents
,
document_index
,
max_seq_length
,
short_seq_prob
,
masked_lm_prob
,
max_predictions_per_seq
,
vocab_words
,
rng
):
masked_lm_prob
,
max_predictions_per_seq
,
vocab_words
,
rng
,
do_whole_word_mask
):
"""Creates `TrainingInstance`s for a single document."""
"""Creates `TrainingInstance`s for a single document."""
document
=
all_documents
[
document_index
]
document
=
all_documents
[
document_index
]
...
@@ -327,7 +330,8 @@ def create_instances_from_document(
...
@@ -327,7 +330,8 @@ def create_instances_from_document(
(
tokens
,
masked_lm_positions
,
(
tokens
,
masked_lm_positions
,
masked_lm_labels
)
=
create_masked_lm_predictions
(
masked_lm_labels
)
=
create_masked_lm_predictions
(
tokens
,
masked_lm_prob
,
max_predictions_per_seq
,
vocab_words
,
rng
)
tokens
,
masked_lm_prob
,
max_predictions_per_seq
,
vocab_words
,
rng
,
do_whole_word_mask
)
instance
=
TrainingInstance
(
instance
=
TrainingInstance
(
tokens
=
tokens
,
tokens
=
tokens
,
segment_ids
=
segment_ids
,
segment_ids
=
segment_ids
,
...
@@ -347,7 +351,8 @@ MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
...
@@ -347,7 +351,8 @@ MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
def
create_masked_lm_predictions
(
tokens
,
masked_lm_prob
,
def
create_masked_lm_predictions
(
tokens
,
masked_lm_prob
,
max_predictions_per_seq
,
vocab_words
,
rng
):
max_predictions_per_seq
,
vocab_words
,
rng
,
do_whole_word_mask
):
"""Creates the predictions for the masked LM objective."""
"""Creates the predictions for the masked LM objective."""
cand_indexes
=
[]
cand_indexes
=
[]
...
@@ -363,7 +368,7 @@ def create_masked_lm_predictions(tokens, masked_lm_prob,
...
@@ -363,7 +368,7 @@ def create_masked_lm_predictions(tokens, masked_lm_prob,
# Note that Whole Word Masking does *not* change the training code
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
# over the entire vocabulary.
if
(
FLAGS
.
do_whole_word_mask
and
len
(
cand_indexes
)
>=
1
and
if
(
do_whole_word_mask
and
len
(
cand_indexes
)
>=
1
and
token
.
startswith
(
"##"
)):
token
.
startswith
(
"##"
)):
cand_indexes
[
-
1
].
append
(
i
)
cand_indexes
[
-
1
].
append
(
i
)
else
:
else
:
...
@@ -456,7 +461,7 @@ def main(_):
...
@@ -456,7 +461,7 @@ def main(_):
instances
=
create_training_instances
(
instances
=
create_training_instances
(
input_files
,
tokenizer
,
FLAGS
.
max_seq_length
,
FLAGS
.
dupe_factor
,
input_files
,
tokenizer
,
FLAGS
.
max_seq_length
,
FLAGS
.
dupe_factor
,
FLAGS
.
short_seq_prob
,
FLAGS
.
masked_lm_prob
,
FLAGS
.
max_predictions_per_seq
,
FLAGS
.
short_seq_prob
,
FLAGS
.
masked_lm_prob
,
FLAGS
.
max_predictions_per_seq
,
rng
)
rng
,
FLAGS
.
do_whole_word_mask
)
output_files
=
FLAGS
.
output_file
.
split
(
","
)
output_files
=
FLAGS
.
output_file
.
split
(
","
)
logging
.
info
(
"*** Writing to output files ***"
)
logging
.
info
(
"*** Writing to output files ***"
)
...
@@ -464,7 +469,8 @@ def main(_):
...
@@ -464,7 +469,8 @@ def main(_):
logging
.
info
(
" %s"
,
output_file
)
logging
.
info
(
" %s"
,
output_file
)
write_instance_to_example_files
(
instances
,
tokenizer
,
FLAGS
.
max_seq_length
,
write_instance_to_example_files
(
instances
,
tokenizer
,
FLAGS
.
max_seq_length
,
FLAGS
.
max_predictions_per_seq
,
output_files
)
FLAGS
.
max_predictions_per_seq
,
output_files
,
FLAGS
.
gzip_compress
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment