Commit a1fca621 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 270313909
parent 4fa82ae1
@@ -175,13 +175,6 @@ def main(unused_argv):
   input_meta_data["n_layer"] = FLAGS.n_layer
   input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
   input_meta_data["n_class"] = FLAGS.n_class
-  print("DEBUG: ", str(input_meta_data))
-
-  def logits_init_fn():
-    return tf.zeros(
-        shape=(input_meta_data["batch_size_per_core"],
-               input_meta_data["n_class"]),
-        dtype=tf.float32)
 
   with tf.device(get_primary_cpu_task(use_remote_tpu)):
     training_utils.train(
@@ -190,7 +183,6 @@ def main(unused_argv):
         input_meta_data=input_meta_data,
         eval_fn=eval_fn,
         metric_fn=get_metric_fn,
-        logits_init_fn=logits_init_fn,
         train_input_fn=train_input_fn,
         test_input_fn=test_input_fn,
         init_checkpoint=FLAGS.init_checkpoint,
......
@@ -111,12 +111,6 @@ def main(unused_argv):
   model_fn = functools.partial(get_pretrainxlnet_model, model_config,
                                run_config)
 
-  def logits_init_fn():
-    return tf.zeros(
-        shape=(FLAGS.num_predict, input_meta_data["batch_size_per_core"],
-               FLAGS.d_model),
-        dtype=tf.float32)
-
   with tf.device(get_primary_cpu_task(use_remote_tpu)):
     training_utils.train(
         strategy=strategy,
@@ -124,7 +118,6 @@ def main(unused_argv):
         input_meta_data=input_meta_data,
         eval_fn=None,
         metric_fn=None,
-        logits_init_fn=logits_init_fn,
         train_input_fn=train_input_fn,
         test_input_fn=None,
         init_checkpoint=FLAGS.init_checkpoint,
......
@@ -275,10 +275,6 @@ def main(unused_argv):
   model_fn = functools.partial(get_qaxlnet_model, model_config, run_config,
                                FLAGS.start_n_top, FLAGS.end_n_top)
 
-  def logits_init_fn():
-    return tf.zeros(
-        shape=(input_meta_data["batch_size_per_core"]), dtype=tf.float32)
-
   logging.info("start reading pickle file...")
   with tf.io.gfile.GFile(input_meta_data["test_feature_path"], "rb") as f:
     eval_features = pickle.load(f)
@@ -295,7 +291,6 @@ def main(unused_argv):
         input_meta_data=input_meta_data,
         eval_fn=eval_fn,
         metric_fn=None,
-        logits_init_fn=logits_init_fn,
         train_input_fn=train_input_fn,
         test_input_fn=test_input_fn,
         init_checkpoint=FLAGS.init_checkpoint,
......
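All three run scripts change the same way: the local `logits_init_fn` helper (which built a dummy logits tensor per core) is deleted, and `logits_init_fn=` disappears from the `training_utils.train(...)` call. Below is a minimal sketch of a call site after this commit; the setup objects and the exact flag names are illustrative assumptions taken from the hunks above, not verbatim code from the full run scripts.

```python
# Hedged sketch of a post-commit call site. `strategy`, `model_fn`,
# `input_meta_data`, the input functions, and the FLAGS names are assumed
# from the hunks above; they are not quoted from the full run scripts.
with tf.device(get_primary_cpu_task(use_remote_tpu)):
  model = training_utils.train(
      strategy=strategy,
      model_fn=model_fn,
      input_meta_data=input_meta_data,
      eval_fn=eval_fn,
      metric_fn=get_metric_fn,
      # logits_init_fn is gone: the trainer no longer needs a dummy tensor.
      train_input_fn=train_input_fn,
      test_input_fn=test_input_fn,
      init_checkpoint=FLAGS.init_checkpoint,
      total_training_steps=FLAGS.train_steps,   # illustrative flag name
      steps_per_epoch=FLAGS.steps_per_epoch,    # illustrative flag name
      steps_per_loop=FLAGS.steps_per_loop,      # illustrative flag name
      optimizer=optimizer,
      learning_rate_fn=learning_rate_fn,
      model_dir=FLAGS.model_dir)
```

Note that `train` now returns the trained model (see the training-utility hunks below), so a call site can bind the result where it previously received the last step's logits.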
@@ -66,7 +66,6 @@ def train(
     strategy: tf.distribute.Strategy,
     model_fn: Callable,
     input_meta_data: Dict,
-    logits_init_fn: Callable[[], tf.Tensor],
     train_input_fn: Callable,
     total_training_steps: int,
     steps_per_epoch: int,
@@ -79,7 +78,8 @@ def train(
     test_input_fn: Optional[Callable] = None,
     init_checkpoint: Optional[Text] = None,
     model_dir: Optional[Text] = None,
-    save_steps: Optional[int] = None):
+    save_steps: Optional[int] = None,
+    run_eagerly: Optional[bool] = False):
   """Runs customized training.
 
   Args:
@@ -87,7 +87,6 @@ def train(
     model_fn: The function returns a keras.Model.
     input_meta_data: A dictionary of params: `mem_len`, `lr_layer_decay_rate`,
       `n_layer`, `batch_size_per_core` and `d_model`.
-    logits_init_fn: Function creates a dummy logits tensor.
     train_input_fn: Function returns a tf.data.Dataset used for training.
     total_training_steps: Number of steps to train in total.
     steps_per_epoch: Number of steps to run per epoch. At the end of each
@@ -110,6 +109,7 @@ def train(
     model_dir: The directory of model (checkpoints, summaries).
     save_steps: The frequency to save checkpoints. Every save_steps, we save a
       model checkpoint.
+    run_eagerly: Whether to run training eagerly.
 
   Returns:
     Last training step logits if training happens, otherwise returns None.
@@ -117,14 +117,13 @@ def train(
     TypeError: if model directory is not specified.
   """
   required_arguments = [
-      logits_init_fn, train_input_fn, total_training_steps, steps_per_epoch,
-      steps_per_loop, optimizer, learning_rate_fn
+      train_input_fn, total_training_steps, steps_per_epoch, steps_per_loop,
+      optimizer, learning_rate_fn
   ]
   if [arg for arg in required_arguments if arg is None]:
-    raise ValueError(
-        "`logits_init_fn`, `train_input_fn`, `total_training_steps`, "
-        "`steps_per_epoch`, `steps_per_loop`, `optimizer` and "
-        "`learning_rate_fn` are required parameters.")
+    raise ValueError("`train_input_fn`, `total_training_steps`, "
+                     "`steps_per_epoch`, `steps_per_loop`, `optimizer` and "
+                     "`learning_rate_fn` are required parameters.")
   if not model_dir:
     raise TypeError("Model directory must be specified.")
   # pylint: disable=protected-access
@@ -198,11 +197,8 @@ def train(
       optimizer.apply_gradients(zip(clipped, tvars))
 
       if input_meta_data["mem_len"] > 0:
-        return mem, logits
-      else:
-        return logits
+        return mem
 
-  @tf.function
   def train_steps(iterator, steps):
     """Performs distributed training steps in a loop.
@@ -235,20 +231,20 @@ def train(
           mems.append(zeros)
         return mems
 
-    logits = strategy.experimental_run_v2(logits_init_fn)
     if input_meta_data["mem_len"] > 0:
       mem = strategy.experimental_run_v2(cache_fn)
       for _ in tf.range(steps):
-        mem, logits = strategy.experimental_run_v2(
+        mem = strategy.experimental_run_v2(
             _replicated_step, args=(
                 next(iterator),
                 mem,
            ))
     else:
       for _ in tf.range(steps):
-        logits = strategy.experimental_run_v2(
-            _replicated_step, args=(next(iterator),))
-    return logits
+        strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
+
+  if not run_eagerly:
+    train_steps = tf.function(train_steps)
 
   logging.info("Start training...")
   checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
@@ -261,15 +257,14 @@ def train(
   current_step = optimizer.iterations.numpy()
   checkpoint_name = "xlnet_step_{step}.ckpt"
-  logits = None
+
   while current_step < total_training_steps:
     train_loss_metric.reset_states()
     if train_metric:
       train_metric.reset_states()
 
     steps = _steps_to_run(current_step, steps_per_epoch, steps_per_loop)
-    logits = train_steps(train_iterator,
-                         tf.convert_to_tensor(steps, dtype=tf.int32))
+    train_steps(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32))
     current_step += steps
     train_loss = _float_metric_value(train_loss_metric)
     log_stream = "Train step: %d/%d / lr = %.9f / loss = %.7f" % (
@@ -311,4 +306,4 @@ def train(
     logging.info("Running final evaluation after training is complete.")
     eval_fn(model, current_step, eval_summary_writer)
 
-  return logits
+  return model
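Two behavioral points in this file are worth noting: `train_steps` is no longer unconditionally decorated with `@tf.function` but is wrapped only when `run_eagerly` is false, and `train` now returns the model instead of the last step's logits. Here is a self-contained sketch of the conditional-compilation pattern on a toy model; it is not the XLNet trainer, and every name in it is illustrative.

```python
import tensorflow as tf

def make_train_steps(model, optimizer, run_eagerly=False):
  """Builds a step loop; compiled with tf.function unless run_eagerly."""

  def train_steps(iterator, steps):
    for _ in tf.range(steps):
      x, y = next(iterator)
      with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
      grads = tape.gradient(loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))

  # The same trick the patch uses: wrap only when not debugging eagerly.
  if not run_eagerly:
    train_steps = tf.function(train_steps)
  return train_steps

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model(tf.zeros([1, 4]))  # build the variables eagerly, before tracing the loop
optimizer = tf.keras.optimizers.SGD(0.1)
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([128, 4]),
     tf.random.normal([128, 1]))).batch(32).repeat()

train_steps = make_train_steps(model, optimizer, run_eagerly=False)
train_steps(iter(dataset), tf.convert_to_tensor(10, dtype=tf.int32))
```

With `run_eagerly=True` the loop stays plain Python, so breakpoints and prints work inside the step; the default path traces the loop once into a graph, as the decorator previously did unconditionally.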
@@ -1028,7 +1028,6 @@ class PretrainingXLNetModel(tf.keras.Model):
     seg_ids = tf.transpose(features['seg_id'], [1, 0])
 
-    input_mask = None
     perm_mask = tf.transpose(features['perm_mask'], [1, 2, 0])
 
     target_mapping = tf.transpose(features['target_mapping'], [1, 2, 0])
@@ -1038,23 +1037,24 @@ class PretrainingXLNetModel(tf.keras.Model):
     # target mask for LM loss
     tgt_mask = tf.transpose(features['target_mask'], [1, 0])
 
-    mems = features['mems']
-    self.transformerxl_output, self.new_mems, self.lookup_table = self.transformerxl_model(
+    mems = features.get('mems', None)
+
+    transformerxl_output, self.new_mems, self.lookup_table = self.transformerxl_model(
         inp_k=input_ids,
         seg_id=seg_ids,
-        input_mask=input_mask,
+        input_mask=None,
         mems=mems,
         perm_mask=perm_mask,
         target_mapping=target_mapping,
         inp_q=inp_q)
+
     lm_loss = self.lmloss_layer(
-        hidden=self.transformerxl_output,
+        hidden=transformerxl_output,
         target=target,
         lookup_table=self.transformerxl_model.embedding_lookup.lookup_table,
         target_mask=tgt_mask)
     self.add_loss(lm_loss)
-    return self.new_mems, self.transformerxl_output
+    return self.new_mems, transformerxl_output
 
 
 class ClassificationXLNetModel(tf.keras.Model):
@@ -1117,17 +1117,17 @@ class ClassificationXLNetModel(tf.keras.Model):
     label = tf.reshape(features['label_ids'], [bsz_per_core])
 
-    mems = features['mems']
+    mems = features.get('mems', None)
 
-    self.transformerxl_output, self.new_mems, self.lookup_table = (
+    transformerxl_output, new_mems, self.lookup_table = (
         self.transformerxl_model(
            inp_k=input_ids, seg_id=seg_ids, input_mask=input_mask, mems=mems))
-    self.summary = self.summarization_layer(self.transformerxl_output)
+    self.summary = self.summarization_layer(transformerxl_output)
     per_example_loss, logits = self.cl_loss_layer(
         hidden=self.summary, labels=label)
     self.add_loss(tf.keras.backend.mean(per_example_loss))
-    return self.new_mems, logits
+    return new_mems, logits
 
 
 class LMLossLayer(tf.keras.layers.Layer):
@@ -1349,23 +1349,23 @@ class QAXLNetModel(tf.keras.Model):
     cls_index = tf.reshape(features['cls_index'], [-1])
     p_mask = features['p_mask']
 
-    self.transformerxl_output, self.new_mems, self.lookup_table = (
+    transformerxl_output, new_mems, self.lookup_table = (
         self.transformerxl_model(
            inp_k=input_ids, seg_id=seg_ids, input_mask=input_mask))
 
     if training:
       loss, logits = self.qa_loss_layer(
-          hidden=self.transformerxl_output,
+          hidden=transformerxl_output,
          p_mask=p_mask,
          cls_index=cls_index,
          start_positions=features['start_positions'],
          end_positions=features['end_positions'],
          is_impossible=features['is_impossible'])
       self.add_loss(loss)
-      return self.new_mems, logits
+      return new_mems, logits
     else:
       results = self.qa_loss_layer(
-          hidden=self.transformerxl_output, p_mask=p_mask, cls_index=cls_index)
+          hidden=transformerxl_output, p_mask=p_mask, cls_index=cls_index)
       return results
......
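The modeling hunks share two small cleanups: optional memories are read with `features.get('mems', None)` so datasets without a `mems` entry no longer raise `KeyError`, and per-call outputs (`transformerxl_output`, `new_mems`) become locals rather than attributes on `self`, so a forward pass no longer pins the last batch's tensors to the model object. A toy sketch of both patterns follows; the layer and feature names are illustrative, not XLNet's.

```python
import tensorflow as tf

class ToyMemModel(tf.keras.Model):
  """Toy model showing the two patterns from the hunks above."""

  def __init__(self):
    super(ToyMemModel, self).__init__()
    self.dense = tf.keras.layers.Dense(4)

  def call(self, features):
    # Optional input: a missing key falls back to None instead of KeyError.
    mems = features.get('mems', None)
    hidden = self.dense(features['inputs'])
    if mems is not None:
      hidden = hidden + mems
    # Keep per-call tensors local; assigning them to `self` would keep the
    # last batch's tensors alive on the model object between calls.
    return hidden

model = ToyMemModel()
out = model({'inputs': tf.ones([2, 3])})                           # no 'mems'
out = model({'inputs': tf.ones([2, 3]), 'mems': tf.zeros([2, 4])})
```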