Commit a1fca621 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 270313909
parent 4fa82ae1
@@ -175,13 +175,6 @@ def main(unused_argv):
   input_meta_data["n_layer"] = FLAGS.n_layer
   input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
   input_meta_data["n_class"] = FLAGS.n_class
-  print("DEBUG: ", str(input_meta_data))
-
-  def logits_init_fn():
-    return tf.zeros(
-        shape=(input_meta_data["batch_size_per_core"],
-               input_meta_data["n_class"]),
-        dtype=tf.float32)

   with tf.device(get_primary_cpu_task(use_remote_tpu)):
     training_utils.train(
@@ -190,7 +183,6 @@ def main(unused_argv):
         input_meta_data=input_meta_data,
         eval_fn=eval_fn,
         metric_fn=get_metric_fn,
-        logits_init_fn=logits_init_fn,
         train_input_fn=train_input_fn,
         test_input_fn=test_input_fn,
         init_checkpoint=FLAGS.init_checkpoint,
...
@@ -111,12 +111,6 @@ def main(unused_argv):
   model_fn = functools.partial(get_pretrainxlnet_model, model_config,
                                run_config)
-
-  def logits_init_fn():
-    return tf.zeros(
-        shape=(FLAGS.num_predict, input_meta_data["batch_size_per_core"],
-               FLAGS.d_model),
-        dtype=tf.float32)

   with tf.device(get_primary_cpu_task(use_remote_tpu)):
     training_utils.train(
         strategy=strategy,
@@ -124,7 +118,6 @@ def main(unused_argv):
         input_meta_data=input_meta_data,
         eval_fn=None,
         metric_fn=None,
-        logits_init_fn=logits_init_fn,
         train_input_fn=train_input_fn,
         test_input_fn=None,
         init_checkpoint=FLAGS.init_checkpoint,
...
@@ -275,10 +275,6 @@ def main(unused_argv):
   model_fn = functools.partial(get_qaxlnet_model, model_config, run_config,
                                FLAGS.start_n_top, FLAGS.end_n_top)
-
-  def logits_init_fn():
-    return tf.zeros(
-        shape=(input_meta_data["batch_size_per_core"]), dtype=tf.float32)

   logging.info("start reading pickle file...")
   with tf.io.gfile.GFile(input_meta_data["test_feature_path"], "rb") as f:
     eval_features = pickle.load(f)
@@ -295,7 +291,6 @@ def main(unused_argv):
         input_meta_data=input_meta_data,
         eval_fn=eval_fn,
         metric_fn=None,
-        logits_init_fn=logits_init_fn,
         train_input_fn=train_input_fn,
         test_input_fn=test_input_fn,
         init_checkpoint=FLAGS.init_checkpoint,
...
@@ -66,7 +66,6 @@ def train(
     strategy: tf.distribute.Strategy,
     model_fn: Callable,
     input_meta_data: Dict,
-    logits_init_fn: Callable[[], tf.Tensor],
     train_input_fn: Callable,
     total_training_steps: int,
     steps_per_epoch: int,
@@ -79,7 +78,8 @@ def train(
     test_input_fn: Optional[Callable] = None,
     init_checkpoint: Optional[Text] = None,
     model_dir: Optional[Text] = None,
-    save_steps: Optional[int] = None):
+    save_steps: Optional[int] = None,
+    run_eagerly: Optional[bool] = False):
   """Runs customized training.

   Args:
@@ -87,7 +87,6 @@ def train(
     model_fn: The function returns a keras.Model.
     input_meta_data: A dictionary of params: `mem_len`, `lr_layer_decay_rate`,
       `n_layer`, `batch_size_per_core` and `d_model`.
-    logits_init_fn: Function creates a dummy logits tensor.
     train_input_fn: Function returns a tf.data.Dataset used for training.
     total_training_steps: Number of steps to train in total.
     steps_per_epoch: Number of steps to run per epoch. At the end of each
@@ -110,6 +109,7 @@ def train(
     model_dir: The directory of model (checkpoints, summaries).
     save_steps: The frequency to save checkpoints. Every save_steps, we save a
       model checkpoint.
+    run_eagerly: Whether to run training eagerly.

   Returns:
     Last training step logits if training happens, otherwise returns None.
@@ -117,12 +117,11 @@ def train(
     TypeError: if model directory is not specified.
   """
   required_arguments = [
-      logits_init_fn, train_input_fn, total_training_steps, steps_per_epoch,
-      steps_per_loop, optimizer, learning_rate_fn
+      train_input_fn, total_training_steps, steps_per_epoch, steps_per_loop,
+      optimizer, learning_rate_fn
   ]
   if [arg for arg in required_arguments if arg is None]:
-    raise ValueError(
-        "`logits_init_fn`, `train_input_fn`, `total_training_steps`, "
+    raise ValueError("`train_input_fn`, `total_training_steps`, "
                      "`steps_per_epoch`, `steps_per_loop`, `optimizer` and "
                      "`learning_rate_fn` are required parameters.")
   if not model_dir:
@@ -198,11 +197,8 @@ def train(
     optimizer.apply_gradients(zip(clipped, tvars))

     if input_meta_data["mem_len"] > 0:
-      return mem, logits
-    else:
-      return logits
+      return mem

-  @tf.function
   def train_steps(iterator, steps):
     """Performs distributed training steps in a loop.
@@ -235,20 +231,20 @@ def train(
         mems.append(zeros)
       return mems

-    logits = strategy.experimental_run_v2(logits_init_fn)
     if input_meta_data["mem_len"] > 0:
       mem = strategy.experimental_run_v2(cache_fn)
       for _ in tf.range(steps):
-        mem, logits = strategy.experimental_run_v2(
+        mem = strategy.experimental_run_v2(
             _replicated_step, args=(
                 next(iterator),
                 mem,
             ))
     else:
       for _ in tf.range(steps):
-        logits = strategy.experimental_run_v2(
-            _replicated_step, args=(next(iterator),))
-    return logits
+        strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
+
+  if not run_eagerly:
+    train_steps = tf.function(train_steps)

   logging.info("Start training...")
   checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
@@ -261,15 +257,14 @@ def train(
   current_step = optimizer.iterations.numpy()
   checkpoint_name = "xlnet_step_{step}.ckpt"

-  logits = None
   while current_step < total_training_steps:
     train_loss_metric.reset_states()
     if train_metric:
       train_metric.reset_states()

     steps = _steps_to_run(current_step, steps_per_epoch, steps_per_loop)
-    logits = train_steps(train_iterator,
-                         tf.convert_to_tensor(steps, dtype=tf.int32))
+    train_steps(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32))
     current_step += steps

     train_loss = _float_metric_value(train_loss_metric)
     log_stream = "Train step: %d/%d / lr = %.9f / loss = %.7f" % (
@@ -311,4 +306,4 @@ def train(
     logging.info("Running final evaluation after training is complete.")
     eval_fn(model, current_step, eval_summary_writer)

-  return logits
+  return model
@@ -1028,7 +1028,6 @@ class PretrainingXLNetModel(tf.keras.Model):
     seg_ids = tf.transpose(features['seg_id'], [1, 0])

-    input_mask = None
     perm_mask = tf.transpose(features['perm_mask'], [1, 2, 0])
     target_mapping = tf.transpose(features['target_mapping'], [1, 2, 0])
@@ -1038,23 +1037,24 @@ class PretrainingXLNetModel(tf.keras.Model):
     # target mask for LM loss
     tgt_mask = tf.transpose(features['target_mask'], [1, 0])

-    mems = features['mems']
-    self.transformerxl_output, self.new_mems, self.lookup_table = self.transformerxl_model(
+    mems = features.get('mems', None)
+    transformerxl_output, self.new_mems, self.lookup_table = self.transformerxl_model(
         inp_k=input_ids,
         seg_id=seg_ids,
-        input_mask=input_mask,
+        input_mask=None,
         mems=mems,
         perm_mask=perm_mask,
         target_mapping=target_mapping,
         inp_q=inp_q)

     lm_loss = self.lmloss_layer(
-        hidden=self.transformerxl_output,
+        hidden=transformerxl_output,
         target=target,
         lookup_table=self.transformerxl_model.embedding_lookup.lookup_table,
         target_mask=tgt_mask)
     self.add_loss(lm_loss)
-    return self.new_mems, self.transformerxl_output
+    return self.new_mems, transformerxl_output

 class ClassificationXLNetModel(tf.keras.Model):
@@ -1117,17 +1117,17 @@ class ClassificationXLNetModel(tf.keras.Model):
     label = tf.reshape(features['label_ids'], [bsz_per_core])

-    mems = features['mems']
-    self.transformerxl_output, self.new_mems, self.lookup_table = (
+    mems = features.get('mems', None)
+    transformerxl_output, new_mems, self.lookup_table = (
         self.transformerxl_model(
             inp_k=input_ids, seg_id=seg_ids, input_mask=input_mask, mems=mems))

-    self.summary = self.summarization_layer(self.transformerxl_output)
+    self.summary = self.summarization_layer(transformerxl_output)
     per_example_loss, logits = self.cl_loss_layer(
         hidden=self.summary, labels=label)
     self.add_loss(tf.keras.backend.mean(per_example_loss))
-    return self.new_mems, logits
+    return new_mems, logits

 class LMLossLayer(tf.keras.layers.Layer):
@@ -1349,23 +1349,23 @@ class QAXLNetModel(tf.keras.Model):
     cls_index = tf.reshape(features['cls_index'], [-1])
     p_mask = features['p_mask']

-    self.transformerxl_output, self.new_mems, self.lookup_table = (
+    transformerxl_output, new_mems, self.lookup_table = (
         self.transformerxl_model(
             inp_k=input_ids, seg_id=seg_ids, input_mask=input_mask))

     if training:
       loss, logits = self.qa_loss_layer(
-          hidden=self.transformerxl_output,
+          hidden=transformerxl_output,
           p_mask=p_mask,
           cls_index=cls_index,
           start_positions=features['start_positions'],
           end_positions=features['end_positions'],
           is_impossible=features['is_impossible'])
       self.add_loss(loss)
-      return self.new_mems, logits
+      return new_mems, logits
     else:
       results = self.qa_loss_layer(
-          hidden=self.transformerxl_output, p_mask=p_mask, cls_index=cls_index)
+          hidden=transformerxl_output, p_mask=p_mask, cls_index=cls_index)
       return results
...
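The call sites above all follow the same shape after this change: `logits_init_fn` is no longer passed, `run_eagerly` is a new optional flag, and `training_utils.train` now returns the trained model rather than the last logits. Below is a minimal, hedged sketch of such a call (not code from this commit); the setup objects referenced here are assumed to be built the same way the run scripts in this diff build them.

# Hypothetical usage sketch of the updated training_utils.train(...) signature.
# strategy, model_fn, input_meta_data, eval_fn, get_metric_fn, the *_input_fn
# callables, optimizer, learning_rate_fn and the step counts are assumed to be
# constructed by the surrounding run script, as in the diffed call sites.
model = training_utils.train(
    strategy=strategy,
    model_fn=model_fn,
    input_meta_data=input_meta_data,
    eval_fn=eval_fn,
    metric_fn=get_metric_fn,
    train_input_fn=train_input_fn,
    test_input_fn=test_input_fn,
    init_checkpoint=FLAGS.init_checkpoint,
    total_training_steps=total_training_steps,
    steps_per_epoch=steps_per_epoch,
    steps_per_loop=steps_per_loop,
    optimizer=optimizer,
    learning_rate_fn=learning_rate_fn,
    model_dir=model_dir,
    save_steps=save_steps,
    run_eagerly=False)  # new optional argument; defaults to False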