Commit 8bf2b3be authored by A. Unique TensorFlower
Browse files

Merge pull request #8355 from stagedml:move-flags

PiperOrigin-RevId: 304544459
parents f2dcb8d4 64067980
...@@ -100,13 +100,14 @@ class TrainingInstance(object): ...@@ -100,13 +100,14 @@ class TrainingInstance(object):
def write_instance_to_example_files(instances, tokenizer, max_seq_length, def write_instance_to_example_files(instances, tokenizer, max_seq_length,
max_predictions_per_seq, output_files): max_predictions_per_seq, output_files,
gzip_compress):
"""Create TF example files from `TrainingInstance`s.""" """Create TF example files from `TrainingInstance`s."""
writers = [] writers = []
for output_file in output_files: for output_file in output_files:
writers.append( writers.append(
tf.io.TFRecordWriter( tf.io.TFRecordWriter(
output_file, options="GZIP" if FLAGS.gzip_compress else "")) output_file, options="GZIP" if gzip_compress else ""))
writer_index = 0 writer_index = 0
...@@ -185,7 +186,7 @@ def create_float_feature(values): ...@@ -185,7 +186,7 @@ def create_float_feature(values):
def create_training_instances(input_files, tokenizer, max_seq_length, def create_training_instances(input_files, tokenizer, max_seq_length,
dupe_factor, short_seq_prob, masked_lm_prob, dupe_factor, short_seq_prob, masked_lm_prob,
max_predictions_per_seq, rng): max_predictions_per_seq, rng, do_whole_word_mask):
"""Create `TrainingInstance`s from raw text.""" """Create `TrainingInstance`s from raw text."""
all_documents = [[]] all_documents = [[]]
...@@ -221,7 +222,8 @@ def create_training_instances(input_files, tokenizer, max_seq_length, ...@@ -221,7 +222,8 @@ def create_training_instances(input_files, tokenizer, max_seq_length,
instances.extend( instances.extend(
create_instances_from_document( create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob, all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng)) masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask))
rng.shuffle(instances) rng.shuffle(instances)
return instances return instances
...@@ -229,7 +231,8 @@ def create_training_instances(input_files, tokenizer, max_seq_length, ...@@ -229,7 +231,8 @@ def create_training_instances(input_files, tokenizer, max_seq_length,
def create_instances_from_document( def create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob, all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng): masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask):
"""Creates `TrainingInstance`s for a single document.""" """Creates `TrainingInstance`s for a single document."""
document = all_documents[document_index] document = all_documents[document_index]
...@@ -327,7 +330,8 @@ def create_instances_from_document( ...@@ -327,7 +330,8 @@ def create_instances_from_document(
(tokens, masked_lm_positions, (tokens, masked_lm_positions,
masked_lm_labels) = create_masked_lm_predictions( masked_lm_labels) = create_masked_lm_predictions(
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng) tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask)
instance = TrainingInstance( instance = TrainingInstance(
tokens=tokens, tokens=tokens,
segment_ids=segment_ids, segment_ids=segment_ids,
...@@ -347,7 +351,8 @@ MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ...@@ -347,7 +351,8 @@ MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
def create_masked_lm_predictions(tokens, masked_lm_prob, def create_masked_lm_predictions(tokens, masked_lm_prob,
max_predictions_per_seq, vocab_words, rng): max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask):
"""Creates the predictions for the masked LM objective.""" """Creates the predictions for the masked LM objective."""
cand_indexes = [] cand_indexes = []
...@@ -363,7 +368,7 @@ def create_masked_lm_predictions(tokens, masked_lm_prob, ...@@ -363,7 +368,7 @@ def create_masked_lm_predictions(tokens, masked_lm_prob,
# Note that Whole Word Masking does *not* change the training code # Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed # at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary. # over the entire vocabulary.
if (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and if (do_whole_word_mask and len(cand_indexes) >= 1 and
token.startswith("##")): token.startswith("##")):
cand_indexes[-1].append(i) cand_indexes[-1].append(i)
else: else:
...@@ -456,7 +461,7 @@ def main(_): ...@@ -456,7 +461,7 @@ def main(_):
instances = create_training_instances( instances = create_training_instances(
input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor, input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq, FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
rng) rng, FLAGS.do_whole_word_mask)
output_files = FLAGS.output_file.split(",") output_files = FLAGS.output_file.split(",")
logging.info("*** Writing to output files ***") logging.info("*** Writing to output files ***")
...@@ -464,7 +469,8 @@ def main(_): ...@@ -464,7 +469,8 @@ def main(_):
logging.info(" %s", output_file) logging.info(" %s", output_file)
write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length, write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
FLAGS.max_predictions_per_seq, output_files) FLAGS.max_predictions_per_seq, output_files,
FLAGS.gzip_compress)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment