Commit b045ce7d authored by Hongkun Yu, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 272777104
parent 0f176f6f
......@@ -43,6 +43,10 @@ CLS_ID = special_symbols["<cls>"]
SEP_ID = special_symbols["<sep>"]
MASK_ID = special_symbols["<mask>"]
EOD_ID = special_symbols["<eod>"]
SEG_ID_P = 0
SEG_ID_Q = 1
SEG_ID_CLS = 2
SEG_ID_PAD = 3
def file_based_input_fn_builder(input_file, name_to_features, batch_size,
......
......@@ -48,8 +48,11 @@ FLAGS = flags.FLAGS
def get_pretrainxlnet_model(model_config, run_config):
model = modeling.PretrainingXLNetModel(model_config, run_config, name="model")
return model
return modeling.PretrainingXLNetModel(
use_proj=True,
xlnet_config=model_config,
run_config=run_config,
name="model")
def main(unused_argv):
......@@ -69,8 +72,7 @@ def main(unused_argv):
if strategy:
logging.info("***** Number of cores used : %d",
strategy.num_replicas_in_sync)
logging.info("***** Number of hosts used : %d",
num_hosts)
logging.info("***** Number of hosts used : %d", num_hosts)
train_input_fn = functools.partial(
data_utils.get_pretrain_input_data, FLAGS.train_batch_size, FLAGS.seq_len,
strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len, FLAGS.perm_size,
......
......@@ -36,11 +36,6 @@ from official.nlp.xlnet import preprocess_utils
SPIECE_UNDERLINE = u"▁"
SEG_ID_P = 0
SEG_ID_Q = 1
SEG_ID_CLS = 2
SEG_ID_PAD = 3
class InputFeatures(object):
"""A single set of features of data."""
......@@ -705,28 +700,28 @@ def convert_examples_to_features(examples, sp_model, max_seq_length, doc_stride,
split_token_index)
token_is_max_context[len(tokens)] = is_max_context
tokens.append(all_doc_tokens[split_token_index])
segment_ids.append(SEG_ID_P)
segment_ids.append(data_utils.SEG_ID_P)
p_mask.append(0)
paragraph_len = len(tokens)
tokens.append(data_utils.SEP_ID)
segment_ids.append(SEG_ID_P)
segment_ids.append(data_utils.SEG_ID_P)
p_mask.append(1)
# note(zhiliny): we put P before Q
# because during pretraining, B is always shorter than A
for token in query_tokens:
tokens.append(token)
segment_ids.append(SEG_ID_Q)
segment_ids.append(data_utils.SEG_ID_Q)
p_mask.append(1)
tokens.append(data_utils.SEP_ID)
segment_ids.append(SEG_ID_Q)
segment_ids.append(data_utils.SEG_ID_Q)
p_mask.append(1)
cls_index = len(segment_ids)
tokens.append(data_utils.CLS_ID)
segment_ids.append(SEG_ID_CLS)
segment_ids.append(data_utils.SEG_ID_CLS)
p_mask.append(0)
input_ids = tokens
......@@ -739,7 +734,7 @@ def convert_examples_to_features(examples, sp_model, max_seq_length, doc_stride,
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(1)
segment_ids.append(SEG_ID_PAD)
segment_ids.append(data_utils.SEG_ID_PAD)
p_mask.append(1)
assert len(input_ids) == max_seq_length
......
......@@ -30,7 +30,6 @@ def create_run_config(is_training, is_finetune, flags):
kwargs = dict(
is_training=is_training,
use_tpu=flags.use_tpu,
use_bfloat16=flags.use_bfloat16,
dropout=flags.dropout,
dropout_att=flags.dropout_att,
init_method=flags.init_method,
......@@ -49,6 +48,7 @@ def create_run_config(is_training, is_finetune, flags):
return RunConfig(**kwargs)
# TODO(hongkuny): refactor XLNetConfig and RunConfig.
class XLNetConfig(object):
"""Configs for XLNet model.
......@@ -131,7 +131,6 @@ class RunConfig(object):
def __init__(self,
is_training,
use_tpu,
use_bfloat16,
dropout,
dropout_att,
init_method='normal',
......@@ -141,13 +140,13 @@ class RunConfig(object):
reuse_len=None,
bi_data=False,
clamp_len=-1,
same_length=False):
same_length=False,
use_cls_mask=True):
"""Initializes RunConfig.
Args:
is_training: bool, whether in training mode.
use_tpu: bool, whether TPUs are used.
use_bfloat16: bool, use bfloat16 instead of float32.
dropout: float, dropout rate.
dropout_att: float, dropout rate on attention probabilities.
init_method: str, the initialization scheme, either "normal" or "uniform".
......@@ -164,6 +163,7 @@ class RunConfig(object):
-1 means no clamping.
same_length: bool, whether to use the same attention length
for each token.
use_cls_mask: bool, whether to introduce cls mask.
"""
self.init_method = init_method
......@@ -173,9 +173,9 @@ class RunConfig(object):
self.dropout = dropout
self.dropout_att = dropout_att
self.use_tpu = use_tpu
self.use_bfloat16 = use_bfloat16
self.mem_len = mem_len
self.reuse_len = reuse_len
self.bi_data = bi_data
self.clamp_len = clamp_len
self.same_length = same_length
self.use_cls_mask = use_cls_mask
This diff is collapsed.
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment