Commit e170a8ba authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 266413847
parent 765da424
@@ -402,7 +402,7 @@ class SequenceBeamSearch(object):
       topk_ids = topk_indices % self.vocab_size
       if self.padded_decode:
         topk_seq = tf.transpose(topk_seq, perm=[2, 0, 1])
-        topk_seq = tf.tensor_scatter_update(topk_seq, [i + 1], topk_ids)
+        topk_seq = tf.tensor_scatter_nd_update(topk_seq, [i + 1], topk_ids)
         topk_seq = tf.transpose(topk_seq, perm=[1, 2, 0])
       else:
         topk_ids = tf.expand_dims(topk_ids, axis=2)
...
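The only functional change in this hunk is the rename to `tf.tensor_scatter_nd_update`, the name this op carries in TF 2.x. A minimal sketch of the transpose/scatter/transpose pattern used above, with hypothetical shapes (length 4, batch 2, beam 3):

```python
import tensorflow as tf

# Stand-in for topk_seq after the first transpose: [length, batch, beam].
seq = tf.zeros([4, 2, 3], dtype=tf.int32)
new_ids = tf.ones([2, 3], dtype=tf.int32)  # ids decoded at this step

# Tensors are immutable, so the scatter returns a new tensor with row 1
# (position i + 1 for i = 0) replaced by new_ids.
seq = tf.tensor_scatter_nd_update(seq, [[1]], [new_ids])
seq = tf.transpose(seq, perm=[1, 2, 0])  # back to [batch, beam, length]
```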
@@ -54,6 +54,7 @@ from __future__ import print_function
 import math
 import os
 
+from absl import logging
 import tensorflow as tf
 
 from official.transformer.v2 import misc
@@ -193,7 +194,7 @@ def _batch_examples(dataset, batch_size, max_length):
 def _read_and_batch_from_files(
     file_pattern, batch_size, max_length, num_parallel_calls, shuffle, repeat,
-    static_batch=False, num_replicas=1):
+    static_batch=False, num_replicas=1, ctx=None):
   """Create dataset where each item is a dict of "inputs" and "targets".
 
   Args:
@@ -219,12 +220,17 @@ def _read_and_batch_from_files(
       batches, and each global batch is equally divisible by number of replicas.
       Currently it is only effective when static_batch==True. TODO: make it
       effective when static_batch=False.
+    ctx: Input context.
 
   Returns:
     tf.data.Dataset object containing examples loaded from the files.
   """
   dataset = tf.data.Dataset.list_files(file_pattern, shuffle=shuffle)
+
+  if ctx and ctx.num_input_pipelines > 1:
+    logging.info("Shard %d of the dataset.", ctx.input_pipeline_id)
+    dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
 
   # Read files and interleave results. When training, the order of the examples
   # will be non-deterministic.
   options = tf.data.Options()
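The new `ctx` argument is a `tf.distribute.InputContext`, which a distribution strategy supplies when it builds one input pipeline per worker; sharding the file list keeps each pipeline reading a disjoint subset. A minimal sketch of how the two fit together (the file pattern and batch size below are made up for illustration):

```python
import tensorflow as tf

def dataset_fn(ctx):
  # Same file-level sharding logic as _read_and_batch_from_files above.
  dataset = tf.data.Dataset.list_files("/tmp/data/*train*", shuffle=True)
  if ctx and ctx.num_input_pipelines > 1:
    dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
  dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=4)
  # The context also converts a global batch size into a per-replica one.
  return dataset.batch(ctx.get_per_replica_batch_size(64))

strategy = tf.distribute.MirroredStrategy()
train_ds = strategy.experimental_distribute_datasets_from_function(dataset_fn)
```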
@@ -247,7 +253,7 @@ def _read_and_batch_from_files(
     # First calculate batch size (token number) per worker, then divide it
     # into sentences, and finally expand to a global batch. It could prove
     # the global batch divisible for distribution strategy.
-        ((batch_size // num_replicas) // max_length) * num_replicas,
+        int(batch_size // num_replicas // max_length * num_replicas),
         ([max_length], [max_length]), drop_remainder=True)
   else:
     # Group and batch such that each batch has examples of similar length.
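The rewritten expression keeps the same intent (the `int(...)` wrapper presumably guards against a float `batch_size`): compute whole sentences per replica first, then scale back up, so the global sentence count always divides evenly across replicas. A worked example with hypothetical numbers:

```python
# batch_size is in tokens; max_length caps tokens per sentence.
batch_size, max_length, num_replicas = 4096, 64, 3

per_replica_sentences = batch_size // num_replicas // max_length  # 21
global_sentences = per_replica_sentences * num_replicas           # 63
assert global_sentences % num_replicas == 0  # always divides evenly
```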
@@ -276,7 +282,7 @@ def _generate_synthetic_data(params):
   return dataset.batch(batch, drop_remainder=True)
 
 
-def train_input_fn(params):
+def train_input_fn(params, ctx=None):
   """Load and return dataset of batched examples for use during training."""
   file_pattern = os.path.join(params["data_dir"] or "", "*train*")
   if params["use_synthetic_data"]:
@@ -285,10 +291,10 @@ def train_input_fn(params):
       file_pattern, params["batch_size"], params["max_length"],
       params["num_parallel_calls"], shuffle=True,
       repeat=params["repeat_dataset"], static_batch=params["static_batch"],
-      num_replicas=params["num_gpus"])
+      num_replicas=params["num_gpus"], ctx=ctx)
 
 
-def eval_input_fn(params):
+def eval_input_fn(params, ctx=None):
   """Load and return dataset of batched examples for use during evaluation."""
   file_pattern = os.path.join(params["data_dir"] or "", "*dev*")
   if params["use_synthetic_data"]:
@@ -296,7 +302,8 @@ def eval_input_fn(params):
   return _read_and_batch_from_files(
       file_pattern, params["batch_size"], params["max_length"],
       params["num_parallel_calls"], shuffle=False, repeat=1,
-      static_batch=params["static_batch"], num_replicas=params["num_gpus"])
+      static_batch=params["static_batch"], num_replicas=params["num_gpus"],
+      ctx=ctx)
 
 
 def map_data_for_transformer_fn(x, y):
...
@@ -182,10 +182,6 @@ def define_transformer_flags():
       default=False,
       help=flags_core.help_wrap(
           'Whether the model runs with custom training loop.'))
-  flags.DEFINE_bool(
-      name='is_tpu_pod',
-      default=False,
-      help=flags_core.help_wrap('Whether the model runs on a TPU pod.'))
   flags.DEFINE_bool(
       name='use_tpu_2vm_config',
       default=False,
...
@@ -146,7 +146,6 @@ class TransformerTask(object):
params["num_gpus"] = num_gpus params["num_gpus"] = num_gpus
params["use_ctl"] = flags_obj.use_ctl params["use_ctl"] = flags_obj.use_ctl
params["is_tpu_pod"] = flags_obj.is_tpu_pod
params["data_dir"] = flags_obj.data_dir params["data_dir"] = flags_obj.data_dir
params["model_dir"] = flags_obj.model_dir params["model_dir"] = flags_obj.model_dir
params["static_batch"] = flags_obj.static_batch params["static_batch"] = flags_obj.static_batch
@@ -210,6 +209,15 @@ class TransformerTask(object):
     with distribution_utils.get_strategy_scope(self.distribution_strategy):
       model = transformer.create_model(params, is_train=True)
       opt = self._create_optimizer()
+
+      current_step = 0
+      checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
+      latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir)
+      if latest_checkpoint:
+        checkpoint.restore(latest_checkpoint)
+        logging.info("Loaded checkpoint %s", latest_checkpoint)
+        current_step = opt.iterations.numpy()
+
       if params["use_ctl"]:
         train_loss_metric = tf.keras.metrics.Mean(
             "training_loss", dtype=tf.float32)
@@ -226,7 +234,7 @@ class TransformerTask(object):
         train_ds = (
             self.distribution_strategy
             .experimental_distribute_datasets_from_function(
-                lambda ctx: data_pipeline.train_input_fn(params)))
+                lambda ctx: data_pipeline.train_input_fn(params, ctx)))
       else:
         train_ds = data_pipeline.train_input_fn(params)
         map_data_fn = data_pipeline.map_data_for_transformer_fn
@@ -275,40 +283,33 @@ class TransformerTask(object):
         self.distribution_strategy.experimental_run_v2(
             _step_fn, args=(next(iterator),))
 
-    if self.use_tpu:
-      checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
-      latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir)
-      if latest_checkpoint:
-        checkpoint.restore(latest_checkpoint)
-        logging.info("Loaded checkpoint %s", latest_checkpoint)
-
-    if flags_obj.train_steps < flags_obj.steps_between_evals:
-      flags_obj.steps_between_evals = flags_obj.train_steps
-    iterations = flags_obj.train_steps // flags_obj.steps_between_evals
-
     cased_score, uncased_score = None, None
     cased_score_history, uncased_score_history = [], []
-    for i in range(1, iterations + 1):
-      print("Start train iteration:{}/{}".format(i, iterations))
+    while current_step < flags_obj.train_steps:
+      remaining_steps = flags_obj.train_steps - current_step
+      train_steps_per_eval = (
+          remaining_steps if remaining_steps < flags_obj.steps_between_evals
+          else flags_obj.steps_between_evals)
+      current_iteration = current_step // flags_obj.steps_between_evals
+
+      print("Start train iteration at global step:{}".format(current_step))
       history = None
       if params["use_ctl"]:
         if not self.use_tpu:
           raise NotImplementedError(
               "Custom training loop on GPUs is not implemented.")
-        train_steps_per_eval = tf.convert_to_tensor(
-            flags_obj.steps_between_evals, dtype=tf.int32)
-
         # Runs training steps.
-        train_steps(train_ds_iterator, train_steps_per_eval)
+        train_steps(train_ds_iterator,
+                    tf.convert_to_tensor(train_steps_per_eval, dtype=tf.int32))
+        current_step += train_steps_per_eval
         train_loss = train_loss_metric.result().numpy().astype(float)
         logging.info("Train Step: %d/%d / loss = %s",
-                     i * flags_obj.steps_between_evals, flags_obj.train_steps,
-                     train_loss)
+                     current_step, flags_obj.train_steps, train_loss)
         checkpoint_name = checkpoint.save(
             os.path.join(
                 flags_obj.model_dir,
-                "ctl_step_{}.ckpt".format(i * flags_obj.steps_between_evals)))
+                "ctl_step_{}.ckpt".format(current_step)))
         logging.info("Saved checkpoint to %s", checkpoint_name)
       else:
         if self.use_tpu:
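Replacing the fixed for-loop with a while-loop over `current_step` has two effects: a run restored mid-way does not repeat finished work, and a final chunk shorter than `steps_between_evals` still runs instead of being truncated by the old integer division. A small trace with made-up flag values:

```python
train_steps, steps_between_evals = 1000, 300
current_step = 250  # e.g. recovered from a checkpoint

chunks = []
while current_step < train_steps:
  remaining_steps = train_steps - current_step
  train_steps_per_eval = min(remaining_steps, steps_between_evals)
  chunks.append(train_steps_per_eval)
  current_step += train_steps_per_eval

print(chunks)  # [300, 300, 150] -- the old loop always ran 3 x 300 from step 0
```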
@@ -316,24 +317,22 @@ class TransformerTask(object):
"Keras model.fit on TPUs is not implemented.") "Keras model.fit on TPUs is not implemented.")
history = model.fit( history = model.fit(
train_ds, train_ds,
initial_epoch=i - 1, initial_epoch=current_iteration,
epochs=i, epochs=current_iteration + 1,
steps_per_epoch=flags_obj.steps_between_evals, steps_per_epoch=train_steps_per_eval,
callbacks=callbacks, callbacks=callbacks,
# If TimeHistory is enabled, progress bar would be messy. Increase # If TimeHistory is enabled, progress bar would be messy. Increase
# the verbose level to get rid of it. # the verbose level to get rid of it.
verbose=(2 if flags_obj.enable_time_history else 1)) verbose=(2 if flags_obj.enable_time_history else 1))
current_step += train_steps_per_eval
logging.info("Train history: {}".format(history.history)) logging.info("Train history: {}".format(history.history))
print("End train iteration:{}/{} global step:{}".format( print("End train iteration at global step:{}".format(current_step))
i,
iterations,
i*flags_obj.steps_between_evals))
if (flags_obj.bleu_source and flags_obj.bleu_ref): if (flags_obj.bleu_source and flags_obj.bleu_ref):
uncased_score, cased_score = self.eval() uncased_score, cased_score = self.eval()
cased_score_history.append([i, cased_score]) cased_score_history.append([current_iteration + 1, cased_score])
uncased_score_history.append([i, uncased_score]) uncased_score_history.append([current_iteration + 1, uncased_score])
stats = ({ stats = ({
"loss": train_loss "loss": train_loss
@@ -347,12 +346,13 @@ class TransformerTask(object):
   def eval(self):
     """Evaluates the model."""
-    if not self.predict_model:
-      self.predict_model = transformer.create_model(self.params, False)
-    self._load_weights_if_possible(
-        self.predict_model,
-        tf.train.latest_checkpoint(self.flags_obj.model_dir))
-    self.predict_model.summary()
+    with distribution_utils.get_strategy_scope(self.distribution_strategy):
+      if not self.predict_model:
+        self.predict_model = transformer.create_model(self.params, False)
+      self._load_weights_if_possible(
+          self.predict_model,
+          tf.train.latest_checkpoint(self.flags_obj.model_dir))
+      self.predict_model.summary()
     return evaluate_and_log_bleu(
         self.predict_model, self.params, self.flags_obj.bleu_source,
         self.flags_obj.bleu_ref, self.flags_obj.vocab_file,
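Building the inference model inside the same strategy scope as training means its variables are created under the same placement policy as the training variables. A bare-bones sketch of the idea, with a toy model and a hypothetical checkpoint directory:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # Variables created inside the scope follow the strategy's placement,
  # mirroring how eval() now builds predict_model above.
  predict_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])

latest = tf.train.latest_checkpoint("/tmp/model_dir")
if latest:
  predict_model.load_weights(latest)
```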
@@ -430,7 +430,7 @@ class TransformerTask(object):
       # which will ensure tf.keras.mixed_precision and
       # tf.train.experimental.enable_mixed_precision_graph_rewrite
       # do not double up.
       opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
     return opt
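For context: the graph rewrite named in this hunk wraps an optimizer so eligible ops run in float16 with automatic loss scaling, and applying it alongside `tf.keras.mixed_precision` would perform loss scaling twice, which is what the comment warns about. A sketch of the single-application pattern (this API exists in TF 1.14 through early TF 2.x):

```python
import tensorflow as tf

opt = tf.keras.optimizers.Adam()
# Apply the rewrite exactly once; do not also set a Keras mixed precision
# policy, or loss scaling doubles up.
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
```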
...
@@ -128,8 +128,10 @@ def translate_file(model,
     def _step_fn(inputs):
       """Per replica step function."""
-      val_outputs, _ = model([inputs], training=False)
-      return val_outputs
+      tag = inputs[0]
+      val_inputs = inputs[1]
+      val_outputs, _ = model([val_inputs], training=False)
+      return tag, val_outputs
 
     return distribution_strategy.experimental_run_v2(_step_fn, args=(inputs,))
@@ -140,17 +142,25 @@ def translate_file(model,
   for i, text in enumerate(input_generator()):
     if distribution_strategy:
       text = np.reshape(text, [num_replicas, local_batch_size, -1])
+      # Add tag to the input of each replica with the reordering logic after
+      # outputs, to ensure the output order matches the input order.
       text = [
-          tf.convert_to_tensor(per_replica_text) for per_replica_text in text
+          [tf.convert_to_tensor(tag), tf.convert_to_tensor(per_replica_text)]
+          for tag, per_replica_text in enumerate(text)
       ]
       # pylint: disable=protected-access
       text = values.PerReplica(distribution_strategy.extended._device_map, text)
-      # pylint: enable=protected-access
-      val_outputs = distribution_strategy.experimental_local_results(
+      outputs = distribution_strategy.experimental_local_results(
           predict_step(text))
-      val_outputs = np.reshape(
-          [val_output.numpy() for val_output in val_outputs],
-          [params["decode_batch_size"], -1])
+      tags, unordered_val_outputs = outputs[0]
+      tags = [tag.numpy() for tag in tags._values]
+      unordered_val_outputs = [
+          val_output.numpy() for val_output in unordered_val_outputs._values]
+      # pylint: enable=protected-access
+      val_outputs = [None] * len(tags)
+      for k in range(len(tags)):
+        val_outputs[tags[k]] = unordered_val_outputs[k]
+      val_outputs = np.reshape(val_outputs, [params["decode_batch_size"], -1])
     else:
       val_outputs, _ = model.predict(text)
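The reordering step is the counterpart of the tagging above: each replica returns a (tag, output) pair, and the tags are used to put the locally gathered outputs back into the original input order, as the new comment states. The core logic in isolation, with made-up arrays standing in for replica outputs:

```python
import numpy as np

# Suppose replica results came back as (tag, output) pairs out of order.
tags = [1, 0]
unordered_val_outputs = [np.array([11, 12]), np.array([1, 2])]

# Place each output at the slot named by its tag.
val_outputs = [None] * len(tags)
for k in range(len(tags)):
  val_outputs[tags[k]] = unordered_val_outputs[k]

val_outputs = np.reshape(val_outputs, [2, -1])
print(val_outputs)  # [[ 1  2] [11 12]] -- original input order restored
```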
...