Commit b045ce7d authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 272777104
parent 0f176f6f
...@@ -43,6 +43,10 @@ CLS_ID = special_symbols["<cls>"] ...@@ -43,6 +43,10 @@ CLS_ID = special_symbols["<cls>"]
SEP_ID = special_symbols["<sep>"] SEP_ID = special_symbols["<sep>"]
MASK_ID = special_symbols["<mask>"] MASK_ID = special_symbols["<mask>"]
EOD_ID = special_symbols["<eod>"] EOD_ID = special_symbols["<eod>"]
SEG_ID_P = 0
SEG_ID_Q = 1
SEG_ID_CLS = 2
SEG_ID_PAD = 3
def file_based_input_fn_builder(input_file, name_to_features, batch_size, def file_based_input_fn_builder(input_file, name_to_features, batch_size,
......
...@@ -48,8 +48,11 @@ FLAGS = flags.FLAGS ...@@ -48,8 +48,11 @@ FLAGS = flags.FLAGS
def get_pretrainxlnet_model(model_config, run_config): def get_pretrainxlnet_model(model_config, run_config):
model = modeling.PretrainingXLNetModel(model_config, run_config, name="model") return modeling.PretrainingXLNetModel(
return model use_proj=True,
xlnet_config=model_config,
run_config=run_config,
name="model")
def main(unused_argv): def main(unused_argv):
...@@ -69,8 +72,7 @@ def main(unused_argv): ...@@ -69,8 +72,7 @@ def main(unused_argv):
if strategy: if strategy:
logging.info("***** Number of cores used : %d", logging.info("***** Number of cores used : %d",
strategy.num_replicas_in_sync) strategy.num_replicas_in_sync)
logging.info("***** Number of hosts used : %d", logging.info("***** Number of hosts used : %d", num_hosts)
num_hosts)
train_input_fn = functools.partial( train_input_fn = functools.partial(
data_utils.get_pretrain_input_data, FLAGS.train_batch_size, FLAGS.seq_len, data_utils.get_pretrain_input_data, FLAGS.train_batch_size, FLAGS.seq_len,
strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len, FLAGS.perm_size, strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len, FLAGS.perm_size,
......
...@@ -36,11 +36,6 @@ from official.nlp.xlnet import preprocess_utils ...@@ -36,11 +36,6 @@ from official.nlp.xlnet import preprocess_utils
SPIECE_UNDERLINE = u"▁" SPIECE_UNDERLINE = u"▁"
SEG_ID_P = 0
SEG_ID_Q = 1
SEG_ID_CLS = 2
SEG_ID_PAD = 3
class InputFeatures(object): class InputFeatures(object):
"""A single set of features of data.""" """A single set of features of data."""
...@@ -705,28 +700,28 @@ def convert_examples_to_features(examples, sp_model, max_seq_length, doc_stride, ...@@ -705,28 +700,28 @@ def convert_examples_to_features(examples, sp_model, max_seq_length, doc_stride,
split_token_index) split_token_index)
token_is_max_context[len(tokens)] = is_max_context token_is_max_context[len(tokens)] = is_max_context
tokens.append(all_doc_tokens[split_token_index]) tokens.append(all_doc_tokens[split_token_index])
segment_ids.append(SEG_ID_P) segment_ids.append(data_utils.SEG_ID_P)
p_mask.append(0) p_mask.append(0)
paragraph_len = len(tokens) paragraph_len = len(tokens)
tokens.append(data_utils.SEP_ID) tokens.append(data_utils.SEP_ID)
segment_ids.append(SEG_ID_P) segment_ids.append(data_utils.SEG_ID_P)
p_mask.append(1) p_mask.append(1)
# note(zhiliny): we put P before Q # note(zhiliny): we put P before Q
# because during pretraining, B is always shorter than A # because during pretraining, B is always shorter than A
for token in query_tokens: for token in query_tokens:
tokens.append(token) tokens.append(token)
segment_ids.append(SEG_ID_Q) segment_ids.append(data_utils.SEG_ID_Q)
p_mask.append(1) p_mask.append(1)
tokens.append(data_utils.SEP_ID) tokens.append(data_utils.SEP_ID)
segment_ids.append(SEG_ID_Q) segment_ids.append(data_utils.SEG_ID_Q)
p_mask.append(1) p_mask.append(1)
cls_index = len(segment_ids) cls_index = len(segment_ids)
tokens.append(data_utils.CLS_ID) tokens.append(data_utils.CLS_ID)
segment_ids.append(SEG_ID_CLS) segment_ids.append(data_utils.SEG_ID_CLS)
p_mask.append(0) p_mask.append(0)
input_ids = tokens input_ids = tokens
...@@ -739,7 +734,7 @@ def convert_examples_to_features(examples, sp_model, max_seq_length, doc_stride, ...@@ -739,7 +734,7 @@ def convert_examples_to_features(examples, sp_model, max_seq_length, doc_stride,
while len(input_ids) < max_seq_length: while len(input_ids) < max_seq_length:
input_ids.append(0) input_ids.append(0)
input_mask.append(1) input_mask.append(1)
segment_ids.append(SEG_ID_PAD) segment_ids.append(data_utils.SEG_ID_PAD)
p_mask.append(1) p_mask.append(1)
assert len(input_ids) == max_seq_length assert len(input_ids) == max_seq_length
......
...@@ -30,7 +30,6 @@ def create_run_config(is_training, is_finetune, flags): ...@@ -30,7 +30,6 @@ def create_run_config(is_training, is_finetune, flags):
kwargs = dict( kwargs = dict(
is_training=is_training, is_training=is_training,
use_tpu=flags.use_tpu, use_tpu=flags.use_tpu,
use_bfloat16=flags.use_bfloat16,
dropout=flags.dropout, dropout=flags.dropout,
dropout_att=flags.dropout_att, dropout_att=flags.dropout_att,
init_method=flags.init_method, init_method=flags.init_method,
...@@ -49,6 +48,7 @@ def create_run_config(is_training, is_finetune, flags): ...@@ -49,6 +48,7 @@ def create_run_config(is_training, is_finetune, flags):
return RunConfig(**kwargs) return RunConfig(**kwargs)
# TODO(hongkuny): refactor XLNetConfig and RunConfig.
class XLNetConfig(object): class XLNetConfig(object):
"""Configs for XLNet model. """Configs for XLNet model.
...@@ -131,7 +131,6 @@ class RunConfig(object): ...@@ -131,7 +131,6 @@ class RunConfig(object):
def __init__(self, def __init__(self,
is_training, is_training,
use_tpu, use_tpu,
use_bfloat16,
dropout, dropout,
dropout_att, dropout_att,
init_method='normal', init_method='normal',
...@@ -141,13 +140,13 @@ class RunConfig(object): ...@@ -141,13 +140,13 @@ class RunConfig(object):
reuse_len=None, reuse_len=None,
bi_data=False, bi_data=False,
clamp_len=-1, clamp_len=-1,
same_length=False): same_length=False,
use_cls_mask=True):
"""Initializes RunConfig. """Initializes RunConfig.
Args: Args:
is_training: bool, whether in training mode. is_training: bool, whether in training mode.
use_tpu: bool, whether TPUs are used. use_tpu: bool, whether TPUs are used.
use_bfloat16: bool, use bfloat16 instead of float32.
dropout: float, dropout rate. dropout: float, dropout rate.
dropout_att: float, dropout rate on attention probabilities. dropout_att: float, dropout rate on attention probabilities.
init_method: str, the initialization scheme, either "normal" or "uniform". init_method: str, the initialization scheme, either "normal" or "uniform".
...@@ -164,6 +163,7 @@ class RunConfig(object): ...@@ -164,6 +163,7 @@ class RunConfig(object):
-1 means no clamping. -1 means no clamping.
same_length: bool, whether to use the same attention length same_length: bool, whether to use the same attention length
for each token. for each token.
use_cls_mask: bool, whether to introduce cls mask.
""" """
self.init_method = init_method self.init_method = init_method
...@@ -173,9 +173,9 @@ class RunConfig(object): ...@@ -173,9 +173,9 @@ class RunConfig(object):
self.dropout = dropout self.dropout = dropout
self.dropout_att = dropout_att self.dropout_att = dropout_att
self.use_tpu = use_tpu self.use_tpu = use_tpu
self.use_bfloat16 = use_bfloat16
self.mem_len = mem_len self.mem_len = mem_len
self.reuse_len = reuse_len self.reuse_len = reuse_len
self.bi_data = bi_data self.bi_data = bi_data
self.clamp_len = clamp_len self.clamp_len = clamp_len
self.same_length = same_length self.same_length = same_length
self.use_cls_mask = use_cls_mask
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment