Commit 6cd426d9 authored by Jing Li's avatar Jing Li Committed by A. Unique TensorFlower
Browse files

Support online masking for XLNet

PiperOrigin-RevId: 275408074
parent b0581d0a
This diff is collapsed.
......@@ -35,16 +35,33 @@ from official.nlp.xlnet import optimization
from official.nlp.xlnet import training_utils
from official.utils.misc import tpu_lib
flags.DEFINE_integer(
"mask_alpha", default=6, help="How many tokens to form a group.")
flags.DEFINE_integer(
"mask_beta", default=1, help="How many tokens to mask within each group.")
flags.DEFINE_integer(
"num_predict",
default=None,
help="Number of tokens to predict in partial prediction.")
flags.DEFINE_integer("perm_size", 0, help="Window size of permutation.")
# FLAGS for pretrain input preprocessing
flags.DEFINE_integer("perm_size", 0, help="Window size of permutation.")
flags.DEFINE_float("leak_ratio", default=0.1,
help="Percent of masked tokens that are leaked.")
flags.DEFINE_enum("sample_strategy", default="token_span",
enum_values=["single_token", "whole_word", "token_span",
"word_span"],
help="Stragey used to sample prediction targets.")
flags.DEFINE_integer("max_num_tokens", default=5,
help="Maximum number of tokens to sample in a span."
"Effective when token_span strategy is used.")
flags.DEFINE_integer("min_num_tokens", default=1,
help="Minimum number of tokens to sample in a span."
"Effective when token_span strategy is used.")
flags.DEFINE_integer("max_num_words", default=5,
help="Maximum number of whole words to sample in a span."
"Effective when word_span strategy is used.")
flags.DEFINE_integer("min_num_words", default=1,
help="Minimum number of whole words to sample in a span."
"Effective when word_span strategy is used.")
FLAGS = flags.FLAGS
......@@ -74,11 +91,18 @@ def main(unused_argv):
logging.info("***** Number of cores used : %d",
strategy.num_replicas_in_sync)
logging.info("***** Number of hosts used : %d", num_hosts)
online_masking_config = data_utils.OnlineMaskingConfig(
sample_strategy=FLAGS.sample_strategy,
max_num_tokens=FLAGS.max_num_tokens,
min_num_tokens=FLAGS.min_num_tokens,
max_num_words=FLAGS.max_num_words,
min_num_words=FLAGS.min_num_words)
train_input_fn = functools.partial(
data_utils.get_pretrain_input_data, FLAGS.train_batch_size, FLAGS.seq_len,
strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len, FLAGS.perm_size,
FLAGS.mask_alpha, FLAGS.mask_beta, FLAGS.num_predict, FLAGS.bi_data,
FLAGS.uncased, num_hosts)
FLAGS.leak_ratio, FLAGS.num_predict, FLAGS.uncased, online_masking_config,
num_hosts)
total_training_steps = FLAGS.train_steps
steps_per_epoch = int(FLAGS.train_data_size / FLAGS.train_batch_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment