Commit e170a8ba authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 266413847
parent 765da424
@@ -402,7 +402,7 @@ class SequenceBeamSearch(object):
     topk_ids = topk_indices % self.vocab_size
     if self.padded_decode:
       topk_seq = tf.transpose(topk_seq, perm=[2, 0, 1])
-      topk_seq = tf.tensor_scatter_update(topk_seq, [i + 1], topk_ids)
+      topk_seq = tf.tensor_scatter_nd_update(topk_seq, [i + 1], topk_ids)
       topk_seq = tf.transpose(topk_seq, perm=[1, 2, 0])
     else:
       topk_ids = tf.expand_dims(topk_ids, axis=2)
......
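Note: this hunk tracks the TF 2.0 API rename from `tf.tensor_scatter_update` to `tf.tensor_scatter_nd_update`; the call's semantics are unchanged. A minimal sketch of the call pattern with hypothetical shapes (the real code operates on the `[length, batch, beams]` layout produced by the transpose):

```python
import tensorflow as tf

# Hypothetical shapes: topk_seq is [max_length, batch, beams] after the
# transpose; topk_ids holds the ids decoded at step i for every beam.
topk_seq = tf.zeros([8, 2, 4], dtype=tf.int32)
topk_ids = tf.ones([2, 4], dtype=tf.int32)
i = 3
# Indices of shape [1] select row i+1 along axis 0; the updates tensor must
# then match tensor.shape[1:], i.e. [batch, beams].
topk_seq = tf.tensor_scatter_nd_update(topk_seq, [i + 1], topk_ids)
```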
@@ -54,6 +54,7 @@ from __future__ import print_function
 import math
 import os

+from absl import logging
 import tensorflow as tf

 from official.transformer.v2 import misc
@@ -193,7 +194,7 @@ def _batch_examples(dataset, batch_size, max_length):

 def _read_and_batch_from_files(
     file_pattern, batch_size, max_length, num_parallel_calls, shuffle, repeat,
-    static_batch=False, num_replicas=1):
+    static_batch=False, num_replicas=1, ctx=None):
   """Create dataset where each item is a dict of "inputs" and "targets".

   Args:
@@ -219,12 +220,17 @@ def _read_and_batch_from_files(
       batches, and each global batch is equally divisible by number of replicas.
       Currently it is only effective when static_batch==True. TODO: make it
       effective when static_batch=False.
+    ctx: Input context.

   Returns:
     tf.data.Dataset object containing examples loaded from the files.
   """
   dataset = tf.data.Dataset.list_files(file_pattern, shuffle=shuffle)
+  if ctx and ctx.num_input_pipelines > 1:
+    logging.info("Shard %d of the dataset.", ctx.input_pipeline_id)
+    dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
+
   # Read files and interleave results. When training, the order of the examples
   # will be non-deterministic.
   options = tf.data.Options()
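For context, `Dataset.shard(n, i)` keeps every n-th element starting at element i, so with one input pipeline per worker each worker reads a disjoint subset of the input files. A small self-contained illustration (the file names are made up):

```python
import tensorflow as tf

files = tf.data.Dataset.from_tensor_slices(
    ["train-00", "train-01", "train-02", "train-03"])
# Pipeline 1 of 2 keeps every 2nd element starting at index 1.
shard = files.shard(num_shards=2, index=1)
for f in shard:
  print(f.numpy())  # b'train-01', then b'train-03'
```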
@@ -247,7 +253,7 @@ def _read_and_batch_from_files(
         # First calculate batch size (token number) per worker, then divide it
         # into sentences, and finally expand to a global batch. This keeps the
         # global batch divisible for the distribution strategy.
-        ((batch_size // num_replicas) // max_length) * num_replicas,
+        int(batch_size // num_replicas // max_length * num_replicas),
         ([max_length], [max_length]), drop_remainder=True)
   else:
     # Group and batch such that each batch has examples of similar length.
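Since `//` is left-associative, the new expression computes the same value as the old one; the `int()` wrapper just guarantees a plain Python int for downstream code. Worked through with hypothetical numbers:

```python
batch_size, num_replicas, max_length = 4100, 8, 64   # hypothetical values
per_replica_tokens = batch_size // num_replicas      # 512
sentences_per_replica = per_replica_tokens // max_length  # 8
global_batch = int(sentences_per_replica * num_replicas)  # 64
# 64 sentences of max_length tokens, evenly divisible across 8 replicas.
assert global_batch % num_replicas == 0
```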
@@ -276,7 +282,7 @@ def _generate_synthetic_data(params):
   return dataset.batch(batch, drop_remainder=True)


-def train_input_fn(params):
+def train_input_fn(params, ctx=None):
   """Load and return dataset of batched examples for use during training."""
   file_pattern = os.path.join(params["data_dir"] or "", "*train*")
   if params["use_synthetic_data"]:
@@ -285,10 +291,10 @@ def train_input_fn(params):
       file_pattern, params["batch_size"], params["max_length"],
       params["num_parallel_calls"], shuffle=True,
       repeat=params["repeat_dataset"], static_batch=params["static_batch"],
-      num_replicas=params["num_gpus"])
+      num_replicas=params["num_gpus"], ctx=ctx)


-def eval_input_fn(params):
+def eval_input_fn(params, ctx=None):
   """Load and return dataset of batched examples for use during evaluation."""
   file_pattern = os.path.join(params["data_dir"] or "", "*dev*")
   if params["use_synthetic_data"]:
@@ -296,7 +302,8 @@ def eval_input_fn(params):
   return _read_and_batch_from_files(
       file_pattern, params["batch_size"], params["max_length"],
       params["num_parallel_calls"], shuffle=False, repeat=1,
-      static_batch=params["static_batch"], num_replicas=params["num_gpus"])
+      static_batch=params["static_batch"], num_replicas=params["num_gpus"],
+      ctx=ctx)


 def map_data_for_transformer_fn(x, y):
......
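The extra `ctx` parameter on both input functions exists so that `tf.distribute` can hand each worker its `InputContext`; a later hunk in transformer_main.py fixes the call site that previously dropped it. A sketch of the wiring, assuming a strategy object and the `params` dict used above:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
# The strategy invokes the lambda once per input pipeline, passing an
# InputContext; forwarding it lets _read_and_batch_from_files shard the files.
train_ds = strategy.experimental_distribute_datasets_from_function(
    lambda ctx: train_input_fn(params, ctx))
```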
@@ -182,10 +182,6 @@ def define_transformer_flags():
       default=False,
       help=flags_core.help_wrap(
           'Whether the model runs with custom training loop.'))
-  flags.DEFINE_bool(
-      name='is_tpu_pod',
-      default=False,
-      help=flags_core.help_wrap('Whether the model runs on a TPU pod.'))
   flags.DEFINE_bool(
       name='use_tpu_2vm_config',
       default=False,
......
@@ -146,7 +146,6 @@ class TransformerTask(object):
     params["num_gpus"] = num_gpus
     params["use_ctl"] = flags_obj.use_ctl
-    params["is_tpu_pod"] = flags_obj.is_tpu_pod
     params["data_dir"] = flags_obj.data_dir
     params["model_dir"] = flags_obj.model_dir
     params["static_batch"] = flags_obj.static_batch
@@ -210,6 +209,15 @@ class TransformerTask(object):
     with distribution_utils.get_strategy_scope(self.distribution_strategy):
       model = transformer.create_model(params, is_train=True)
       opt = self._create_optimizer()

+      current_step = 0
+      checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
+      latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir)
+      if latest_checkpoint:
+        checkpoint.restore(latest_checkpoint)
+        logging.info("Loaded checkpoint %s", latest_checkpoint)
+        current_step = opt.iterations.numpy()
+
       if params["use_ctl"]:
         train_loss_metric = tf.keras.metrics.Mean(
             "training_loss", dtype=tf.float32)
@@ -226,7 +234,7 @@ class TransformerTask(object):
       train_ds = (
           self.distribution_strategy
           .experimental_distribute_datasets_from_function(
-              lambda ctx: data_pipeline.train_input_fn(params)))
+              lambda ctx: data_pipeline.train_input_fn(params, ctx)))
     else:
       train_ds = data_pipeline.train_input_fn(params)
       map_data_fn = data_pipeline.map_data_for_transformer_fn
@@ -275,40 +283,33 @@ class TransformerTask(object):
           self.distribution_strategy.experimental_run_v2(
               _step_fn, args=(next(iterator),))

-    if self.use_tpu:
-      checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
-      latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir)
-      if latest_checkpoint:
-        checkpoint.restore(latest_checkpoint)
-        logging.info("Loaded checkpoint %s", latest_checkpoint)
-
-    if flags_obj.train_steps < flags_obj.steps_between_evals:
-      flags_obj.steps_between_evals = flags_obj.train_steps
-    iterations = flags_obj.train_steps // flags_obj.steps_between_evals
     cased_score, uncased_score = None, None
     cased_score_history, uncased_score_history = [], []
-    for i in range(1, iterations + 1):
-      print("Start train iteration:{}/{}".format(i, iterations))
+    while current_step < flags_obj.train_steps:
+      remaining_steps = flags_obj.train_steps - current_step
+      train_steps_per_eval = (
+          remaining_steps if remaining_steps < flags_obj.steps_between_evals
+          else flags_obj.steps_between_evals)
+      current_iteration = current_step // flags_obj.steps_between_evals
+
+      print("Start train iteration at global step:{}".format(current_step))
       history = None
       if params["use_ctl"]:
         if not self.use_tpu:
           raise NotImplementedError(
               "Custom training loop on GPUs is not implemented.")
-        train_steps_per_eval = tf.convert_to_tensor(
-            flags_obj.steps_between_evals, dtype=tf.int32)
-
         # Runs training steps.
-        train_steps(train_ds_iterator, train_steps_per_eval)
+        train_steps(train_ds_iterator,
+                    tf.convert_to_tensor(train_steps_per_eval, dtype=tf.int32))
+        current_step += train_steps_per_eval
         train_loss = train_loss_metric.result().numpy().astype(float)
         logging.info("Train Step: %d/%d / loss = %s",
-                     i * flags_obj.steps_between_evals, flags_obj.train_steps,
-                     train_loss)
+                     current_step, flags_obj.train_steps, train_loss)
+
         checkpoint_name = checkpoint.save(
             os.path.join(
                 flags_obj.model_dir,
-                "ctl_step_{}.ckpt".format(i * flags_obj.steps_between_evals)))
+                "ctl_step_{}.ckpt".format(current_step)))
         logging.info("Saved checkpoint to %s", checkpoint_name)
       else:
         if self.use_tpu:
@@ -316,24 +317,22 @@ class TransformerTask(object):
               "Keras model.fit on TPUs is not implemented.")
         history = model.fit(
             train_ds,
-            initial_epoch=i - 1,
-            epochs=i,
-            steps_per_epoch=flags_obj.steps_between_evals,
+            initial_epoch=current_iteration,
+            epochs=current_iteration + 1,
+            steps_per_epoch=train_steps_per_eval,
             callbacks=callbacks,
             # If TimeHistory is enabled, progress bar would be messy. Increase
             # the verbose level to get rid of it.
             verbose=(2 if flags_obj.enable_time_history else 1))
+        current_step += train_steps_per_eval
         logging.info("Train history: {}".format(history.history))

-      print("End train iteration:{}/{} global step:{}".format(
-          i,
-          iterations,
-          i*flags_obj.steps_between_evals))
+      print("End train iteration at global step:{}".format(current_step))

       if (flags_obj.bleu_source and flags_obj.bleu_ref):
         uncased_score, cased_score = self.eval()
-        cased_score_history.append([i, cased_score])
-        uncased_score_history.append([i, uncased_score])
+        cased_score_history.append([current_iteration + 1, cased_score])
+        uncased_score_history.append([current_iteration + 1, uncased_score])

     stats = ({
         "loss": train_loss
@@ -347,12 +346,13 @@ class TransformerTask(object):
   def eval(self):
     """Evaluates the model."""
-    if not self.predict_model:
-      self.predict_model = transformer.create_model(self.params, False)
-    self._load_weights_if_possible(
-        self.predict_model,
-        tf.train.latest_checkpoint(self.flags_obj.model_dir))
-    self.predict_model.summary()
+    with distribution_utils.get_strategy_scope(self.distribution_strategy):
+      if not self.predict_model:
+        self.predict_model = transformer.create_model(self.params, False)
+      self._load_weights_if_possible(
+          self.predict_model,
+          tf.train.latest_checkpoint(self.flags_obj.model_dir))
+      self.predict_model.summary()
     return evaluate_and_log_bleu(
         self.predict_model, self.params, self.flags_obj.bleu_source,
         self.flags_obj.bleu_ref, self.flags_obj.vocab_file,
@@ -430,7 +430,7 @@ class TransformerTask(object):
       # which will ensure tf.keras.mixed_precision and
       # tf.train.experimental.enable_mixed_precision_graph_rewrite do not
       # double up.
       opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)

     return opt
......
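Taken together, the training-loop hunks replace the fixed `for i in range(1, iterations + 1)` schedule with one driven by the restored global step, so the eval cadence survives restarts. A worked example with hypothetical flag values:

```python
train_steps, steps_between_evals = 300, 100  # hypothetical flag values
current_step = 250                           # restored from a checkpoint

remaining_steps = train_steps - current_step
train_steps_per_eval = (
    remaining_steps if remaining_steps < steps_between_evals
    else steps_between_evals)
current_iteration = current_step // steps_between_evals
# Resuming at step 250 trains 50 more steps as iteration 2 (fit epochs 2->3).
assert (train_steps_per_eval, current_iteration) == (50, 2)
```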
@@ -128,8 +128,10 @@ def translate_file(model,
     def _step_fn(inputs):
       """Per replica step function."""
-      val_outputs, _ = model([inputs], training=False)
-      return val_outputs
+      tag = inputs[0]
+      val_inputs = inputs[1]
+      val_outputs, _ = model([val_inputs], training=False)
+      return tag, val_outputs

     return distribution_strategy.experimental_run_v2(_step_fn, args=(inputs,))
@@ -140,17 +142,25 @@ def translate_file(model,
   for i, text in enumerate(input_generator()):
     if distribution_strategy:
       text = np.reshape(text, [num_replicas, local_batch_size, -1])
+      # Add tag to the input of each replica with the reordering logic after
+      # outputs, to ensure the output order matches the input order.
       text = [
-          tf.convert_to_tensor(per_replica_text) for per_replica_text in text
+          [tf.convert_to_tensor(tag), tf.convert_to_tensor(per_replica_text)]
+          for tag, per_replica_text in enumerate(text)
       ]
       # pylint: disable=protected-access
       text = values.PerReplica(distribution_strategy.extended._device_map, text)
-      # pylint: enable=protected-access
-      val_outputs = distribution_strategy.experimental_local_results(
+      outputs = distribution_strategy.experimental_local_results(
           predict_step(text))
-      val_outputs = np.reshape(
-          [val_output.numpy() for val_output in val_outputs],
-          [params["decode_batch_size"], -1])
+
+      tags, unordered_val_outputs = outputs[0]
+      tags = [tag.numpy() for tag in tags._values]
+      unordered_val_outputs = [
+          val_output.numpy() for val_output in unordered_val_outputs._values]
+      # pylint: enable=protected-access
+
+      val_outputs = [None] * len(tags)
+      for k in range(len(tags)):
+        val_outputs[tags[k]] = unordered_val_outputs[k]
+      val_outputs = np.reshape(val_outputs, [params["decode_batch_size"], -1])
     else:
       val_outputs, _ = model.predict(text)
......
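The change above pairs each replica's input with an integer tag and uses the tags returned by `predict_step` to put per-replica outputs back in input order. The core of that reorder, sketched with plain lists standing in for the `PerReplica._values` unpacking:

```python
import numpy as np

# Hypothetical per-replica results, returned in arbitrary device order.
tags = [2, 0, 1]
unordered_val_outputs = [np.array([30]), np.array([10]), np.array([20])]

# Slot each output back at the position its input came from.
val_outputs = [None] * len(tags)
for k in range(len(tags)):
  val_outputs[tags[k]] = unordered_val_outputs[k]
print(val_outputs)  # [array([10]), array([20]), array([30])]
```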