Commit 64067980 authored by Sergey Mironov's avatar Sergey Mironov
Browse files

Move a number of FLAGS from lower-level functions to the top-level

parent d0bd3f54
......@@ -100,13 +100,14 @@ class TrainingInstance(object):
def write_instance_to_example_files(instances, tokenizer, max_seq_length,
max_predictions_per_seq, output_files):
max_predictions_per_seq, output_files,
gzip_compress):
"""Create TF example files from `TrainingInstance`s."""
writers = []
for output_file in output_files:
writers.append(
tf.io.TFRecordWriter(
output_file, options="GZIP" if FLAGS.gzip_compress else ""))
output_file, options="GZIP" if gzip_compress else ""))
writer_index = 0
......@@ -185,7 +186,7 @@ def create_float_feature(values):
def create_training_instances(input_files, tokenizer, max_seq_length,
dupe_factor, short_seq_prob, masked_lm_prob,
max_predictions_per_seq, rng):
max_predictions_per_seq, rng, do_whole_word_mask):
"""Create `TrainingInstance`s from raw text."""
all_documents = [[]]
......@@ -221,7 +222,8 @@ def create_training_instances(input_files, tokenizer, max_seq_length,
instances.extend(
create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask))
rng.shuffle(instances)
return instances
......@@ -229,7 +231,8 @@ def create_training_instances(input_files, tokenizer, max_seq_length,
def create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask):
"""Creates `TrainingInstance`s for a single document."""
document = all_documents[document_index]
......@@ -327,7 +330,8 @@ def create_instances_from_document(
(tokens, masked_lm_positions,
masked_lm_labels) = create_masked_lm_predictions(
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask)
instance = TrainingInstance(
tokens=tokens,
segment_ids=segment_ids,
......@@ -347,7 +351,8 @@ MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
def create_masked_lm_predictions(tokens, masked_lm_prob,
max_predictions_per_seq, vocab_words, rng):
max_predictions_per_seq, vocab_words, rng,
do_whole_word_mask):
"""Creates the predictions for the masked LM objective."""
cand_indexes = []
......@@ -363,7 +368,7 @@ def create_masked_lm_predictions(tokens, masked_lm_prob,
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
if (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and
if (do_whole_word_mask and len(cand_indexes) >= 1 and
token.startswith("##")):
cand_indexes[-1].append(i)
else:
......@@ -456,7 +461,7 @@ def main(_):
instances = create_training_instances(
input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
rng)
rng, FLAGS.do_whole_word_mask)
output_files = FLAGS.output_file.split(",")
logging.info("*** Writing to output files ***")
......@@ -464,7 +469,8 @@ def main(_):
logging.info(" %s", output_file)
write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
FLAGS.max_predictions_per_seq, output_files)
FLAGS.max_predictions_per_seq, output_files,
FLAGS.gzip_compress)
if __name__ == "__main__":
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.