Commit bf748370 authored by Nimit Nigania's avatar Nimit Nigania
Browse files

Merge remote-tracking branch 'upstream/master'

parents 7c732da7 0d2c2e01
......@@ -65,8 +65,8 @@ def prepare_raw_data(flag_obj):
data_processing_params = {
"train_epochs": flag_obj.num_train_epochs,
"batch_size": flag_obj.prebatch_size,
"eval_batch_size": flag_obj.prebatch_size,
"batch_size": flag_obj.train_prebatch_size,
"eval_batch_size": flag_obj.eval_prebatch_size,
"batches_per_step": 1,
"stream_files": True,
"num_neg": flag_obj.num_negative_samples,
......
......@@ -117,7 +117,10 @@ def create_dataset_from_data_producer(producer, params):
return train_input_dataset, eval_input_dataset
def create_ncf_input_data(params, producer=None, input_meta_data=None):
def create_ncf_input_data(params,
producer=None,
input_meta_data=None,
strategy=None):
"""Creates NCF training/evaluation dataset.
Args:
......@@ -128,6 +131,9 @@ def create_ncf_input_data(params, producer=None, input_meta_data=None):
input_meta_data: A dictionary of input metadata to be used when reading data
from tf record files. Must be specified when params["train_input_dataset"]
is specified.
strategy: Distribution strategy used for distributed training. If specified,
used to assert that evaluation batch size is correctly a multiple of
total number of devices used.
Returns:
(training dataset, evaluation dataset, train steps per epoch,
......@@ -136,6 +142,17 @@ def create_ncf_input_data(params, producer=None, input_meta_data=None):
Raises:
ValueError: If data is being generated online for when using TPU's.
"""
# NCF evaluation metric calculation logic assumes that evaluation data
# sample size are in multiples of (1 + number of negative samples in
# evaluation) for each device. As so, evaluation batch size must be a
# multiple of (number of replicas * (1 + number of negative samples)).
num_devices = strategy.num_replicas_in_sync if strategy else 1
if (params["eval_batch_size"] % (num_devices *
(1 + rconst.NUM_EVAL_NEGATIVES))):
raise ValueError("Evaluation batch size must be divisible by {} "
"times {}".format(num_devices,
(1 + rconst.NUM_EVAL_NEGATIVES)))
if params["train_dataset_path"]:
assert params["eval_dataset_path"]
......
......@@ -121,7 +121,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
"""
self._run_and_report_benchmark(hr_at_10_min=0.61)
def _run_and_report_benchmark(self, hr_at_10_min=0.630, hr_at_10_max=0.640):
def _run_and_report_benchmark(self, hr_at_10_min=0.630, hr_at_10_max=0.645):
"""Run test and report results.
Note: Target is 0.635, but some runs are below that level. Until we have
......@@ -203,6 +203,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
self._setup()
FLAGS.early_stopping = True
FLAGS.num_gpus = 2
FLAGS.eval_batch_size = 160000
self._run_and_report_benchmark()
def benchmark_2_gpus_ctl_early_stop(self):
......@@ -211,6 +212,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
FLAGS.keras_use_ctl = True
FLAGS.early_stopping = True
FLAGS.num_gpus = 2
FLAGS.eval_batch_size = 160000
self._run_and_report_benchmark()
#############################################
......@@ -287,6 +289,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
FLAGS.num_gpus = 8
FLAGS.train_epochs = 17
FLAGS.batch_size = 1048576
FLAGS.eval_batch_size = 160000
FLAGS.learning_rate = 0.0045
FLAGS.beta1 = 0.25
FLAGS.beta2 = 0.5
......@@ -299,6 +302,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
FLAGS.num_gpus = 8
FLAGS.train_epochs = 17
FLAGS.batch_size = 1048576
FLAGS.eval_batch_size = 160000
FLAGS.learning_rate = 0.0045
FLAGS.beta1 = 0.25
FLAGS.beta2 = 0.5
......@@ -306,19 +310,6 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
FLAGS.force_v2_in_keras_compile = False
self._run_and_report_benchmark_mlperf_like()
def benchmark_xla_8_gpu_mlperf_like(self):
"""8 GPU using keras fit/compile with XLA."""
self._setup()
FLAGS.num_gpus = 8
FLAGS.enable_xla = True
FLAGS.train_epochs = 17
FLAGS.batch_size = 1048576
FLAGS.learning_rate = 0.0045
FLAGS.beta1 = 0.25
FLAGS.beta2 = 0.5
FLAGS.epsilon = 1e-8
self._run_and_report_benchmark_mlperf_like()
def benchmark_8_gpu_ctl_mlperf_like(self):
"""8 GPU using CTL."""
self._setup()
......@@ -326,20 +317,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
FLAGS.num_gpus = 8
FLAGS.train_epochs = 17
FLAGS.batch_size = 1048576
FLAGS.learning_rate = 0.0045
FLAGS.beta1 = 0.25
FLAGS.beta2 = 0.5
FLAGS.epsilon = 1e-8
self._run_and_report_benchmark_mlperf_like()
def benchmark_xla_8_gpu_ctl_mlperf_like(self):
"""8 GPU using CTL with XLA."""
self._setup()
FLAGS.keras_use_ctl = True
FLAGS.enable_xla = True
FLAGS.num_gpus = 8
FLAGS.train_epochs = 17
FLAGS.batch_size = 1048576
FLAGS.eval_batch_size = 160000
FLAGS.learning_rate = 0.0045
FLAGS.beta1 = 0.25
FLAGS.beta2 = 0.5
......@@ -360,6 +338,7 @@ class NCFKerasSynth(NCFKerasBenchmarkBase):
default_flags['num_gpus'] = 1
default_flags['train_epochs'] = 8
default_flags['batch_size'] = 99000
default_flags['eval_batch_size'] = 160000
default_flags['learning_rate'] = 0.00382059
default_flags['beta1'] = 0.783529
default_flags['beta2'] = 0.909003
......
......@@ -64,12 +64,20 @@ class MetricLayer(tf.keras.layers.Layer):
def __init__(self, params):
super(MetricLayer, self).__init__()
self.params = params
self.metric = tf.keras.metrics.Mean(name=rconst.HR_METRIC_NAME)
def call(self, inputs):
def call(self, inputs, training=False):
logits, dup_mask = inputs
in_top_k, metric_weights = metric_fn(logits, dup_mask, self.params)
self.add_metric(self.metric(in_top_k, sample_weight=metric_weights))
if training:
hr_sum = 0.0
hr_count = 0.0
else:
metric, metric_weights = metric_fn(logits, dup_mask, self.params)
hr_sum = tf.reduce_sum(metric * metric_weights)
hr_count = tf.reduce_sum(metric_weights)
self.add_metric(hr_sum, name="hr_sum", aggregation="mean")
self.add_metric(hr_count, name="hr_count", aggregation="mean")
return logits
......@@ -249,7 +257,7 @@ def run_ncf(_):
(train_input_dataset, eval_input_dataset,
num_train_steps, num_eval_steps) = \
(ncf_input_pipeline.create_ncf_input_data(
params, producer, input_meta_data))
params, producer, input_meta_data, strategy))
steps_per_epoch = None if generate_input_online else num_train_steps
with distribution_utils.get_strategy_scope(strategy):
......@@ -295,11 +303,19 @@ def run_ncf(_):
logging.info("Training done. Start evaluating")
eval_results = keras_model.evaluate(
eval_loss_and_metrics = keras_model.evaluate(
eval_input_dataset, steps=num_eval_steps, verbose=2)
logging.info("Keras evaluation is done.")
# Keras evaluate() API returns scalar loss and metric values from
# evaluation as a list. Here, the returned list would contain
# [evaluation loss, hr sum, hr count].
eval_hit_rate = eval_loss_and_metrics[1] / eval_loss_and_metrics[2]
# Format evaluation result into [eval loss, eval hit accuracy].
eval_results = [eval_loss_and_metrics[0], eval_hit_rate]
if history and history.history:
train_history = history.history
train_loss = train_history["loss"][-1]
......
......@@ -195,20 +195,20 @@ class NcfTest(tf.test.TestCase):
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_estimator(self):
integration.run_synthetic(
ncf_estimator_main.main, tmp_root=self.get_temp_dir(), max_train=None,
ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS)
@unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_estimator_mlperf(self):
integration.run_synthetic(
ncf_estimator_main.main, tmp_root=self.get_temp_dir(), max_train=None,
ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-ml_perf', 'True'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_keras_no_dist_strat(self):
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS +
['-distribution_strategy', 'off'])
......@@ -216,7 +216,7 @@ class NcfTest(tf.test.TestCase):
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_dist_strat(self):
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '0'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
......@@ -226,7 +226,7 @@ class NcfTest(tf.test.TestCase):
['-num_gpus', '0'] +
['-keras_use_ctl', 'True'])
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=flags)
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
......@@ -238,7 +238,7 @@ class NcfTest(tf.test.TestCase):
format(1, context.num_gpus()))
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '1'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
......@@ -250,7 +250,7 @@ class NcfTest(tf.test.TestCase):
format(2, context.num_gpus()))
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '2'])
if __name__ == "__main__":
......
# ResNet in TensorFlow
* For the Keras version of the ResNet model, see
[`official/resnet/keras`](keras).
[`official/vision/image_classification`](../vision/image_classification).
* For the Keras custom training loop version, see
[`official/resnet/ctl`](ctl).
* For the Estimator version, see [`official/r1/resnet`](../r1/resnet).
\ No newline at end of file
* For the Estimator version, see [`official/r1/resnet`](../r1/resnet).
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
......@@ -27,3 +27,6 @@ def define_ctl_flags():
flags.DEFINE_boolean(name='use_tf_function', default=True,
help='Wrap the train and test step inside a '
'tf.function.')
flags.DEFINE_boolean(name='single_l2_loss_op', default=False,
help='Calculate L2_loss on concatenated weights, '
'instead of using Keras per-layer L2 loss.')
......@@ -22,7 +22,7 @@ import time
from absl import flags
import tensorflow as tf
from official.resnet.keras import keras_common
from official.vision.image_classification import common
from official.resnet.ctl import ctl_imagenet_main
from official.resnet.ctl import ctl_common
from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
......@@ -118,7 +118,7 @@ class Resnet50CtlAccuracy(CtlBenchmark):
flag_methods = [
ctl_common.define_ctl_flags,
keras_common.define_keras_flags
common.define_keras_flags
]
self.data_dir = os.path.join(root_data_dir, 'imagenet')
......@@ -162,7 +162,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
def __init__(self, output_dir=None, default_flags=None):
flag_methods = [
ctl_common.define_ctl_flags,
keras_common.define_keras_flags
common.define_keras_flags
]
super(Resnet50CtlBenchmarkBase, self).__init__(
......@@ -215,6 +215,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_eager')
FLAGS.batch_size = 64
FLAGS.use_tf_function = False
FLAGS.single_l2_loss_op = True
self._run_and_report_benchmark()
def benchmark_8_gpu(self):
......
......@@ -24,10 +24,10 @@ from absl import logging
import tensorflow as tf
from official.resnet.ctl import ctl_common
from official.resnet.keras import imagenet_preprocessing
from official.resnet.keras import keras_common
from official.resnet.keras import keras_imagenet_main
from official.resnet.keras import resnet_model
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import common
from official.vision.image_classification import resnet_imagenet_main
from official.vision.image_classification import resnet_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
......@@ -73,7 +73,7 @@ def get_input_dataset(flags_obj, strategy):
"""Returns the test and train input datasets."""
dtype = flags_core.get_tf_dtype(flags_obj)
if flags_obj.use_synthetic_data:
input_fn = keras_common.get_synth_input_fn(
input_fn = common.get_synth_input_fn(
height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
num_channels=imagenet_preprocessing.NUM_CHANNELS,
......@@ -137,6 +137,10 @@ def run(flags_obj):
Returns:
Dictionary of training and eval stats.
"""
keras_utils.set_session_config(
enable_eager=flags_obj.enable_eager,
enable_xla=flags_obj.enable_xla)
dtype = flags_core.get_tf_dtype(flags_obj)
# TODO(anj-s): Set data_format without using Keras.
......@@ -163,10 +167,11 @@ def run(flags_obj):
with strategy_scope:
model = resnet_model.resnet50(
num_classes=imagenet_preprocessing.NUM_CLASSES,
dtype=dtype, batch_size=flags_obj.batch_size)
dtype=dtype, batch_size=flags_obj.batch_size,
use_l2_regularizer=not flags_obj.single_l2_loss_op)
optimizer = tf.keras.optimizers.SGD(
learning_rate=keras_common.BASE_LEARNING_RATE, momentum=0.9,
learning_rate=common.BASE_LEARNING_RATE, momentum=0.9,
nesterov=True)
training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
......@@ -175,6 +180,8 @@ def run(flags_obj):
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
'test_accuracy', dtype=tf.float32)
trainable_variables = model.trainable_variables
def train_step(train_ds_inputs):
"""Training StepFn."""
def step_fn(inputs):
......@@ -185,13 +192,22 @@ def run(flags_obj):
prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
labels, logits)
loss1 = tf.reduce_sum(prediction_loss) * (1.0/ flags_obj.batch_size)
loss2 = (tf.reduce_sum(model.losses) /
tf.distribute.get_strategy().num_replicas_in_sync)
loss = loss1 + loss2
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
loss = tf.reduce_sum(prediction_loss) * (1.0/ flags_obj.batch_size)
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
if flags_obj.single_l2_loss_op:
filtered_variables = [
tf.reshape(v, (-1,))
for v in trainable_variables
if 'bn' not in v.name
]
l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
tf.concat(filtered_variables, axis=0))
loss += (l2_loss / num_replicas)
else:
loss += (tf.reduce_sum(model.losses) / num_replicas)
grads = tape.gradient(loss, trainable_variables)
optimizer.apply_gradients(zip(grads, trainable_variables))
training_accuracy.update_state(labels, logits)
return loss
......@@ -232,7 +248,7 @@ def run(flags_obj):
training_accuracy.reset_states()
for step in range(train_steps):
optimizer.lr = keras_imagenet_main.learning_rate_schedule(
optimizer.lr = resnet_imagenet_main.learning_rate_schedule(
epoch, step, train_steps, flags_obj.batch_size)
time_callback.on_batch_begin(step+epoch*train_steps)
......@@ -281,6 +297,8 @@ def main(_):
if __name__ == '__main__':
logging.set_verbosity(logging.INFO)
keras_common.define_keras_flags()
common.define_keras_flags()
ctl_common.define_ctl_flags()
flags.adopt_module_key_flags(keras_common)
flags.adopt_module_key_flags(ctl_common)
absl_app.run(main)
......@@ -25,8 +25,8 @@ from tensorflow.python.eager import context
from tensorflow.python.platform import googletest
from official.resnet.ctl import ctl_common
from official.resnet.ctl import ctl_imagenet_main
from official.resnet.keras import imagenet_preprocessing
from official.resnet.keras import keras_common
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import common
from official.utils.misc import keras_utils
from official.utils.testing import integration
......@@ -49,7 +49,7 @@ class CtlImagenetTest(googletest.TestCase):
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(CtlImagenetTest, cls).setUpClass()
keras_common.define_keras_flags()
common.define_keras_flags()
ctl_common.define_ctl_flags()
def setUp(self):
......
......@@ -41,8 +41,8 @@ class ShakespeareBenchmarkBase(PerfZeroBenchmark):
flag_methods=[shakespeare_main.define_flags])
def _run_and_report_benchmark(self,
top_1_train_min=0.923,
top_1_train_max=0.93,
top_1_train_min=0.91,
top_1_train_max=0.94,
warmup=1,
log_steps=100):
"""Report benchmark results by writing to local protobuf file.
......@@ -208,21 +208,6 @@ class ShakespeareAccuracy(ShakespeareBenchmarkBase):
FLAGS.model_dir = ''
self._run_and_report_benchmark()
def benchmark_xla_8_gpu(self):
"""Benchmark 8 gpu w/xla.
This is test is for accuracy not scaling. The batch-size is not scaled to
the number of gpus.
"""
self._setup()
FLAGS.num_gpus = 8
FLAGS.training_data = self.train_data
FLAGS.batch_size = 64
FLAGS.train_epochs = 43
FLAGS.model_dir = ''
FLAGS.enable_xla = True
self._run_and_report_benchmark()
class ShakespeareKerasBenchmarkReal(ShakespeareBenchmarkBase):
"""Benchmark accuracy tests."""
......
......@@ -79,8 +79,41 @@ class _StateKeys(object):
class SequenceBeamSearch(object):
"""Implementation of beam search loop."""
def __init__(self, symbols_to_logits_fn, vocab_size, batch_size,
beam_size, alpha, max_decode_length, eos_id, dtype=tf.float32):
def __init__(self,
symbols_to_logits_fn,
vocab_size,
batch_size,
beam_size,
alpha,
max_decode_length,
eos_id,
padded_decode,
dtype=tf.float32):
"""Initialize sequence beam search.
Args:
symbols_to_logits_fn: A function to provide logits, which is the
interface to the Transformer model. The passed in arguments are:
ids -> A tensor with shape [batch_size * beam_size, index].
index -> A scalar.
cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
The function must return a tuple of logits and the updated cache:
logits -> A tensor with shape [batch * beam_size, vocab_size].
updated cache -> A nested dictionary with the same structure as the
input cache.
vocab_size: An integer, the size of the vocabulary, used for topk
computation.
batch_size: An integer, the decode batch size.
beam_size: An integer, number of beams for beam search.
alpha: A float, defining the strength of length normalization.
max_decode_length: An integer, the maximum number of steps to decode
a sequence.
eos_id: An integer. ID of end of sentence token.
padded_decode: A bool, indicating if max_sequence_length padding is used
for beam search.
dtype: A tensorflow data type used for score computation. The default is
tf.float32.
"""
self.symbols_to_logits_fn = symbols_to_logits_fn
self.vocab_size = vocab_size
self.batch_size = batch_size
......@@ -88,6 +121,7 @@ class SequenceBeamSearch(object):
self.alpha = alpha
self.max_decode_length = max_decode_length
self.eos_id = eos_id
self.padded_decode = padded_decode
self.dtype = tf.as_dtype(dtype)
def search(self, initial_ids, initial_cache):
......@@ -140,6 +174,8 @@ class SequenceBeamSearch(object):
# Create alive sequence with shape [batch_size, beam_size, 1]
alive_seq = _expand_to_beam_size(initial_ids, self.beam_size)
alive_seq = tf.expand_dims(alive_seq, axis=2)
if self.padded_decode:
alive_seq = tf.tile(alive_seq, [1, 1, self.max_decode_length + 1])
# Create tensor for storing initial log probabilities.
# Assume initial_ids are prob 1.0
......@@ -178,16 +214,44 @@ class SequenceBeamSearch(object):
# 1) the dimension's value is a tensor that remains the same but may
# depend on the input sequence to the model (e.g. batch size).
# 2) the dimension may have different values on different iterations.
state_shape_invariants = {
_StateKeys.CUR_INDEX: tf.TensorShape([]),
_StateKeys.ALIVE_SEQ: tf.TensorShape([None, self.beam_size, None]),
_StateKeys.ALIVE_LOG_PROBS: tf.TensorShape([None, self.beam_size]),
_StateKeys.ALIVE_CACHE: nest.map_structure(
_get_shape_keep_last_dim, alive_cache),
_StateKeys.FINISHED_SEQ: tf.TensorShape([None, self.beam_size, None]),
_StateKeys.FINISHED_SCORES: tf.TensorShape([None, self.beam_size]),
_StateKeys.FINISHED_FLAGS: tf.TensorShape([None, self.beam_size])
}
if self.padded_decode:
state_shape_invariants = {
_StateKeys.CUR_INDEX:
tf.TensorShape([]),
_StateKeys.ALIVE_SEQ:
tf.TensorShape(
[self.batch_size, self.beam_size,
self.max_decode_length + 1]),
_StateKeys.ALIVE_LOG_PROBS:
tf.TensorShape([self.batch_size, self.beam_size]),
_StateKeys.ALIVE_CACHE:
nest.map_structure(_get_shape, alive_cache),
_StateKeys.FINISHED_SEQ:
tf.TensorShape(
[self.batch_size, self.beam_size,
self.max_decode_length + 1]),
_StateKeys.FINISHED_SCORES:
tf.TensorShape([self.batch_size, self.beam_size]),
_StateKeys.FINISHED_FLAGS:
tf.TensorShape([self.batch_size, self.beam_size])
}
else:
state_shape_invariants = {
_StateKeys.CUR_INDEX:
tf.TensorShape([]),
_StateKeys.ALIVE_SEQ:
tf.TensorShape([None, self.beam_size, None]),
_StateKeys.ALIVE_LOG_PROBS:
tf.TensorShape([None, self.beam_size]),
_StateKeys.ALIVE_CACHE:
nest.map_structure(_get_shape_keep_last_dim, alive_cache),
_StateKeys.FINISHED_SEQ:
tf.TensorShape([None, self.beam_size, None]),
_StateKeys.FINISHED_SCORES:
tf.TensorShape([None, self.beam_size]),
_StateKeys.FINISHED_FLAGS:
tf.TensorShape([None, self.beam_size])
}
return state, state_shape_invariants
......@@ -297,7 +361,12 @@ class SequenceBeamSearch(object):
# Get logits for the next candidate IDs for the alive sequences. Get the new
# cache values at the same time.
flat_ids = _flatten_beam_dim(alive_seq) # [batch_size * beam_size]
if self.padded_decode:
flat_ids = tf.reshape(
tf.slice(alive_seq, [0, 0, i], [self.batch_size, self.beam_size, 1]),
[self.batch_size * self.beam_size, -1])
else:
flat_ids = _flatten_beam_dim(alive_seq) # [batch_size * beam_size]
flat_cache = nest.map_structure(_flatten_beam_dim, alive_cache)
flat_logits, flat_cache = self.symbols_to_logits_fn(flat_ids, i, flat_cache)
......@@ -331,8 +400,13 @@ class SequenceBeamSearch(object):
# Append the most probable IDs to the topk sequences
topk_ids = topk_indices % self.vocab_size
topk_ids = tf.expand_dims(topk_ids, axis=2)
topk_seq = tf.concat([topk_seq, topk_ids], axis=2)
if self.padded_decode:
topk_seq = tf.transpose(topk_seq, perm=[2, 0, 1])
topk_seq = tf.tensor_scatter_update(topk_seq, [i + 1], topk_ids)
topk_seq = tf.transpose(topk_seq, perm=[1, 2, 0])
else:
topk_ids = tf.expand_dims(topk_ids, axis=2)
topk_seq = tf.concat([topk_seq, topk_ids], axis=2)
return topk_seq, topk_log_probs, new_cache
def _get_new_alive_state(self, new_seq, new_log_probs, new_cache):
......@@ -388,9 +462,12 @@ class SequenceBeamSearch(object):
# First append a column of 0-ids to finished_seq to increment the length.
# New shape of finished_seq: [batch_size, beam_size, i + 1]
finished_seq = tf.concat(
[finished_seq,
tf.zeros([self.batch_size, self.beam_size, 1], tf.int32)], axis=2)
if not self.padded_decode:
finished_seq = tf.concat([
finished_seq,
tf.zeros([self.batch_size, self.beam_size, 1], tf.int32)
],
axis=2)
# Calculate new seq scores from log probabilities.
length_norm = _length_normalization(self.alpha, i + 1, dtype=self.dtype)
......@@ -420,34 +497,43 @@ class SequenceBeamSearch(object):
def sequence_beam_search(
symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size,
alpha, max_decode_length, eos_id):
alpha, max_decode_length, eos_id, padded_decode=False):
"""Search for sequence of subtoken ids with the largest probability.
Args:
symbols_to_logits_fn: A function that takes in ids, index, and cache as
arguments. The passed in arguments will have shape:
ids -> [batch_size * beam_size, index]
index -> [] (scalar)
cache -> nested dictionary of tensors [batch_size * beam_size, ...]
The function must return logits and new cache.
logits -> [batch * beam_size, vocab_size]
new cache -> same shape/structure as inputted cache
initial_ids: Starting ids for each batch item.
int32 tensor with shape [batch_size]
initial_cache: dict containing starting decoder variables information
vocab_size: int size of tokens
beam_size: int number of beams
alpha: float defining the strength of length normalization
max_decode_length: maximum length to decoded sequence
eos_id: int id of eos token, used to determine when a sequence has finished
ids -> A tensor with shape [batch_size * beam_size, index].
index -> A scalar.
cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
The function must return a tuple of logits and new cache:
logits -> A tensor with shape [batch * beam_size, vocab_size].
new cache -> A nested dictionary with the same shape/structure as the
inputted cache.
initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
each batch item.
initial_cache: A dictionary, containing starting decoder variables
information.
vocab_size: An integer, the size of the vocabulary, used for topk
computation.
beam_size: An integer, the number of beams.
alpha: A float, defining the strength of length normalization.
max_decode_length: An integer, the maximum length to decoded a sequence.
eos_id: An integer, ID of eos token, used to determine when a sequence has
finished.
padded_decode: A bool, indicating if max_sequence_length padding is used
for beam search.
Returns:
Top decoded sequences [batch_size, beam_size, max_decode_length]
sequence scores [batch_size, beam_size]
"""
batch_size = tf.shape(initial_ids)[0]
batch_size = (
initial_ids.shape.as_list()[0] if padded_decode else
tf.shape(initial_ids)[0])
sbs = SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
beam_size, alpha, max_decode_length, eos_id)
beam_size, alpha, max_decode_length, eos_id,
padded_decode)
return sbs.search(initial_ids, initial_cache)
......@@ -502,6 +588,11 @@ def _get_shape_keep_last_dim(tensor):
return tf.TensorShape(shape_list)
def _get_shape(tensor):
"""Return the shape of the input tensor."""
return tf.TensorShape(_shape_list(tensor))
def _flatten_beam_dim(tensor):
"""Reshapes first two dimensions in to single dimension.
......
......@@ -32,6 +32,7 @@ from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.r1.utils import export
from official.transformer import compute_bleu
from official.transformer import translate
from official.transformer.model import model_params
......@@ -41,7 +42,6 @@ from official.transformer.utils import metrics
from official.transformer.utils import schedule
from official.transformer.utils import tokenizer
from official.utils.accelerator import tpu as tpu_util
from official.utils.export import export
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.logs import logger
......@@ -56,7 +56,7 @@ PARAMS_MAP = {
DEFAULT_TRAIN_EPOCHS = 10
INF = int(1e9)
INF = 1000000000 # 1e9
BLEU_DIR = "bleu"
# Dictionary containing tensors that are logged by the logging hooks. Each item
......
......@@ -102,51 +102,67 @@ class Attention(tf.keras.layers.Layer):
x = tf.transpose(x, [0, 2, 1, 3]) # --> [batch, length, num_heads, depth]
return tf.reshape(x, [batch_size, length, self.hidden_size])
def call(self, x, y, bias, training, cache=None):
def call(self, x, y, bias, training, cache=None, decode_loop_step=None):
"""Apply attention mechanism to x and y.
Args:
x: a tensor with shape [batch_size, length_x, hidden_size]
y: a tensor with shape [batch_size, length_y, hidden_size]
bias: attention bias that will be added to the result of the dot product.
training: boolean, whether in training mode or not.
cache: (Used during prediction) dictionary with tensors containing results
of previous attentions. The dictionary must have the items:
x: A tensor with shape [batch_size, length_x, hidden_size].
y: A tensor with shape [batch_size, length_y, hidden_size].
bias: A bool, the attention bias that will be added to the result of the
dot product.
training: A bool, whether in training mode or not.
cache: (Used during prediction) A dictionary with tensors containing
results of previous attentions. The dictionary must have the items:
{"k": tensor with shape [batch_size, i, key_channels],
"v": tensor with shape [batch_size, i, value_channels]}
where i is the current decoded length.
decode_loop_step: An integer, step number of the decoding loop. Used only
for autoregressive inference on TPU.
Returns:
Attention layer output with shape [batch_size, length_x, hidden_size]
"""
# Linearly project the query (q), key (k) and value (v) using different
# learned projections. This is in preparation of splitting them into
# multiple heads. Multi-head attention uses multiple queries, keys, and
# values rather than regular attention (which uses a single q, k, v).
q = self.q_dense_layer(x)
k = self.k_dense_layer(y)
v = self.v_dense_layer(y)
# Linearly project the query, key and value using different learned
# projections. This is in preparation of splitting them into multiple
# heads. Multi-head attention uses multiple queries, keys, and values
# rather than regular attention (which uses a single query, key, value).
query = self.q_dense_layer(x)
key = self.k_dense_layer(y)
value = self.v_dense_layer(y)
if cache is not None:
# Combine cached keys and values with new keys and values.
k = tf.concat([tf.cast(cache["k"], k.dtype), k], axis=1)
v = tf.concat([tf.cast(cache["v"], k.dtype), v], axis=1)
if decode_loop_step is not None:
cache_k_shape = cache["k"].shape.as_list()
indices = tf.reshape(
tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=key.dtype),
[1, cache_k_shape[1], 1])
key = cache["k"] + key * indices
cache_v_shape = cache["v"].shape.as_list()
indices = tf.reshape(
tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=value.dtype),
[1, cache_v_shape[1], 1])
value = cache["v"] + value * indices
else:
key = tf.concat([tf.cast(cache["k"], key.dtype), key], axis=1)
value = tf.concat([tf.cast(cache["v"], value.dtype), value], axis=1)
# Update cache
cache["k"] = k
cache["v"] = v
cache["k"] = key
cache["v"] = value
# Split q, k, v into heads.
q = self.split_heads(q)
k = self.split_heads(k)
v = self.split_heads(v)
# Split query, key, value into heads.
query = self.split_heads(query)
key = self.split_heads(key)
value = self.split_heads(value)
# Scale q to prevent the dot product between q and k from growing too large.
# Scale query to prevent the dot product between query and key from growing
# too large.
depth = (self.hidden_size // self.num_heads)
q *= depth ** -0.5
query *= depth ** -0.5
# Calculate dot product attention
logits = tf.matmul(q, k, transpose_b=True)
logits = tf.matmul(query, key, transpose_b=True)
logits += bias
# Note that softmax internally performs math operations using float32
# for numeric stability. When training with float16, we keep the input
......@@ -154,7 +170,7 @@ class Attention(tf.keras.layers.Layer):
weights = tf.nn.softmax(logits, name="attention_weights")
if training:
weights = tf.nn.dropout(weights, rate=self.attention_dropout)
attention_output = tf.matmul(weights, v)
attention_output = tf.matmul(weights, value)
# Recombine heads --> [batch_size, length, hidden_size]
attention_output = self.combine_heads(attention_output)
......@@ -167,5 +183,6 @@ class Attention(tf.keras.layers.Layer):
class SelfAttention(Attention):
"""Multiheaded self-attention layer."""
def call(self, x, bias, training, cache=None):
return super(SelfAttention, self).call(x, x, bias, training, cache)
def call(self, x, bias, training, cache=None, decode_loop_step=None):
return super(SelfAttention, self).call(x, x, bias, training, cache,
decode_loop_step)
......@@ -55,43 +55,58 @@ class SequenceBeamSearchV2(v1.SequenceBeamSearch):
return finished_seq, finished_scores
def sequence_beam_search(
symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size,
alpha, max_decode_length, eos_id, dtype="float32"):
def sequence_beam_search(symbols_to_logits_fn,
initial_ids,
initial_cache,
vocab_size,
beam_size,
alpha,
max_decode_length,
eos_id,
padded_decode=False,
dtype="float32"):
"""Search for sequence of subtoken ids with the largest probability.
Args:
symbols_to_logits_fn: A function that takes in ids, index, and cache as
arguments. The passed in arguments will have shape:
ids -> [batch_size * beam_size, index]
index -> [] (scalar)
cache -> nested dictionary of tensors [batch_size * beam_size, ...]
The function must return logits and new cache.
logits -> [batch * beam_size, vocab_size]
new cache -> same shape/structure as inputted cache
initial_ids: Starting ids for each batch item.
int32 tensor with shape [batch_size]
initial_cache: dict containing starting decoder variables information
vocab_size: int size of tokens
beam_size: int number of beams
alpha: float defining the strength of length normalization
max_decode_length: maximum length to decoded sequence
eos_id: int id of eos token, used to determine when a sequence has finished,
dtype: The dtype to use.
ids -> A tensor with shape [batch_size * beam_size, index].
index -> A scalar.
cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
The function must return a tuple of logits and new cache:
logits -> A tensor with shape [batch * beam_size, vocab_size].
new cache -> A nested dictionary with the same shape/structure as the
inputted cache.
initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
each batch item.
initial_cache: A dictionary, containing starting decoder variables
information.
vocab_size: An integer, the size of tokens.
beam_size: An integer, the number of beams.
alpha: A float, defining the strength of length normalization.
max_decode_length: An integer, the maximum length to decoded a sequence.
eos_id: An integer, ID of eos token, used to determine when a sequence has
finished.
padded_decode: A bool, indicating if max_sequence_length padding is used
for beam search.
dtype: A tensorflow data type used for score computation. The default is
tf.float32.
Returns:
Top decoded sequences [batch_size, beam_size, max_decode_length]
sequence scores [batch_size, beam_size]
"""
batch_size = tf.shape(initial_ids)[0]
batch_size = (
initial_ids.shape.as_list()[0] if padded_decode else
tf.shape(initial_ids)[0])
if misc.is_v2():
sbs = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size, batch_size,
beam_size, alpha, max_decode_length, eos_id,
dtype)
padded_decode, dtype)
else:
sbs = v1.SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
beam_size, alpha, max_decode_length, eos_id,
dtype)
padded_decode, dtype)
return sbs.search(initial_ids, initial_cache)
......
......@@ -273,7 +273,7 @@ def _generate_synthetic_data(params):
label_value=1,
label_dtype=tf.int64,
)
return dataset.batch(batch)
return dataset.batch(batch, drop_remainder=True)
def train_input_fn(params):
......
......@@ -176,6 +176,44 @@ def define_transformer_flags():
flags.DEFINE_string(
name='mode', default='train',
help=flags_core.help_wrap('mode: train, eval, or predict'))
flags.DEFINE_bool(
name='use_ctl',
default=False,
help=flags_core.help_wrap(
'Whether the model runs with custom training loop.'))
flags.DEFINE_bool(
name='is_tpu_pod',
default=False,
help=flags_core.help_wrap('Whether the model runs on a TPU pod.'))
flags.DEFINE_bool(
name='use_tpu_2vm_config',
default=False,
help=flags_core.help_wrap(
'Whether the model runs in 2VM mode, Headless server and unit test '
'all use 1VM config.'))
flags.DEFINE_integer(
name='decode_batch_size',
default=32,
help=flags_core.help_wrap(
'Global batch size used for Transformer autoregressive decoding on '
'TPU.'))
flags.DEFINE_integer(
name='decode_max_length',
default=97,
help=flags_core.help_wrap(
'Max sequence length of the decode/eval data. This is used by '
'Transformer autoregressive decoding on TPU to have minimum '
'paddings.'))
flags.DEFINE_bool(
name='padded_decode',
default=False,
help=flags_core.help_wrap(
'Whether the autoregressive decoding runs with input data padded to '
'the decode_max_length. For TPU/XLA-GPU runs, this flag has to be '
'set due the static shape requirement. Although CPU/GPU could also '
'use padded_decode, it has not been tested. In addition, this method '
'will introduce unnecessary overheads which grow quadratically with '
'the max sequence length.'))
flags_core.set_defaults(data_dir='/tmp/translate_ende',
model_dir='/tmp/transformer_model',
......@@ -216,8 +254,6 @@ def define_transformer_flags():
return True
# pylint: enable=unused-variable
flags_core.require_cloud_storage(['data_dir', 'model_dir', 'export_dir'])
def get_callbacks():
"""Returns common callbacks."""
......
......@@ -23,6 +23,51 @@ import tensorflow as tf
K = tf.keras.backend
class LearningRateSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Learning rate schedule."""
def __init__(self, initial_learning_rate, hidden_size, warmup_steps):
"""Initialize configuration of the learning rate schedule.
Args:
initial_learning_rate: A float, the initial learning rate.
hidden_size: An integer, the model dimension in the hidden layers.
warmup_steps: An integer, the number of steps required for linear warmup.
"""
super(LearningRateSchedule, self).__init__()
self.initial_learning_rate = initial_learning_rate
self.hidden_size = hidden_size
self.warmup_steps = tf.cast(warmup_steps, tf.float32)
def __call__(self, global_step):
"""Calculate learning rate with linear warmup and rsqrt decay.
Args:
global_step: An integer, the current global step used for learning rate
calculation.
Returns:
A float, the learning rate needs to be used for current global step.
"""
with tf.name_scope('learning_rate_schedule'):
global_step = tf.cast(global_step, tf.float32)
learning_rate = self.initial_learning_rate
learning_rate *= (self.hidden_size**-0.5)
# Apply linear warmup
learning_rate *= tf.minimum(1.0, global_step / self.warmup_steps)
# Apply rsqrt decay
learning_rate /= tf.sqrt(tf.maximum(global_step, self.warmup_steps))
return learning_rate
def get_config(self):
"""Get the configuration of the learning rate schedule."""
return {
'initial_learning_rate': self.initial_learning_rate,
'hidden_size': self.hidden_size,
'warmup_steps': self.warmup_steps,
}
class LearningRateFn(object):
"""Creates learning rate function."""
......
......@@ -112,11 +112,22 @@ class Transformer(tf.keras.Model):
outputs: [batch_size, decoded length]
scores: [batch_size, float]}
Even when float16 is used, the output tensor(s) are always float32.
Raises:
NotImplementedError: If try to use padded decode method on CPU/GPUs.
"""
if len(inputs) == 2:
inputs, targets = inputs[0], inputs[1]
else:
inputs, targets = inputs[0], None
if self.params["padded_decode"]:
if not self.params["num_replicas"]:
raise NotImplementedError(
"Padded decoding on CPU/GPUs is not supported.")
decode_batch_size = int(self.params["decode_batch_size"] /
self.params["num_replicas"])
inputs = tf.reshape(
inputs, [decode_batch_size, self.params["decode_max_length"]])
# Variance scaling is used here because it seems to work in many problems.
# Other reasonable initializers may also work just as well.
......@@ -225,13 +236,14 @@ class Transformer(tf.keras.Model):
decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
max_decode_length, dtype=self.params["dtype"])
# TODO(b/139770046): Refactor code with better naming of i.
def symbols_to_logits_fn(ids, i, cache):
"""Generate logits for next potential IDs.
Args:
ids: Current decoded sequences. int tensor with shape [batch_size *
beam_size, i + 1]
i: Loop index
beam_size, i + 1].
i: Loop index.
cache: dictionary of values storing the encoder output, encoder-decoder
attention bias, and previous decoder attention values.
......@@ -245,16 +257,29 @@ class Transformer(tf.keras.Model):
# Preprocess decoder input by getting embeddings and adding timing signal.
decoder_input = self.embedding_softmax_layer(decoder_input)
decoder_input += timing_signal[i:i + 1]
self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
if self.params["padded_decode"]:
timing_signal_shape = timing_signal.shape.as_list()
decoder_input += tf.slice(timing_signal, [i, 0],
[1, timing_signal_shape[1]])
bias_shape = decoder_self_attention_bias.shape.as_list()
self_attention_bias = tf.slice(
decoder_self_attention_bias, [0, 0, i, 0],
[bias_shape[0], bias_shape[1], 1, bias_shape[3]])
else:
decoder_input += timing_signal[i:i + 1]
self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
decoder_outputs = self.decoder_stack(
decoder_input,
cache.get("encoder_outputs"),
self_attention_bias,
cache.get("encoder_decoder_attention_bias"),
training=training,
cache=cache)
cache=cache,
decode_loop_step=i if self.params["padded_decode"] else None)
logits = self.embedding_softmax_layer(decoder_outputs, mode="linear")
logits = tf.squeeze(logits, axis=[1])
return logits, cache
......@@ -263,8 +288,12 @@ class Transformer(tf.keras.Model):
def predict(self, encoder_outputs, encoder_decoder_attention_bias, training):
"""Return predicted sequence."""
batch_size = tf.shape(encoder_outputs)[0]
input_length = tf.shape(encoder_outputs)[1]
if self.params["padded_decode"]:
batch_size = encoder_outputs.shape.as_list()[0]
input_length = encoder_outputs.shape.as_list()[1]
else:
batch_size = tf.shape(encoder_outputs)[0]
input_length = tf.shape(encoder_outputs)[1]
max_decode_length = input_length + self.params["extra_decode_length"]
encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
self.params["dtype"])
......@@ -277,12 +306,20 @@ class Transformer(tf.keras.Model):
# Create cache storing decoder attention values for each layer.
# pylint: disable=g-complex-comprehension
init_decode_length = (
max_decode_length if self.params["padded_decode"] else 0)
cache = {
"layer_%d" % layer: {
"k": tf.zeros([batch_size, 0, self.params["hidden_size"]],
dtype=self.params["dtype"]),
"v": tf.zeros([batch_size, 0, self.params["hidden_size"]],
dtype=self.params["dtype"])
"k":
tf.zeros([
batch_size, init_decode_length, self.params["hidden_size"]
],
dtype=self.params["dtype"]),
"v":
tf.zeros([
batch_size, init_decode_length, self.params["hidden_size"]
],
dtype=self.params["dtype"])
} for layer in range(self.params["num_hidden_layers"])
}
# pylint: enable=g-complex-comprehension
......@@ -301,6 +338,7 @@ class Transformer(tf.keras.Model):
alpha=self.params["alpha"],
max_decode_length=max_decode_length,
eos_id=EOS_ID,
padded_decode=self.params["padded_decode"],
dtype=self.params["dtype"])
# Get the top sequence for each batch element
......@@ -505,22 +543,28 @@ class DecoderStack(tf.keras.layers.Layer):
decoder_self_attention_bias,
attention_bias,
training,
cache=None):
cache=None,
decode_loop_step=None):
"""Return the output of the decoder layer stacks.
Args:
decoder_inputs: tensor with shape [batch_size, target_length, hidden_size]
encoder_outputs: tensor with shape [batch_size, input_length, hidden_size]
decoder_self_attention_bias: bias for decoder self-attention layer. [1, 1,
target_len, target_length]
attention_bias: bias for encoder-decoder attention layer. [batch_size, 1,
1, input_length]
training: boolean, whether in training mode or not.
decoder_inputs: A tensor with shape
[batch_size, target_length, hidden_size].
encoder_outputs: A tensor with shape
[batch_size, input_length, hidden_size]
decoder_self_attention_bias: A tensor with shape
[1, 1, target_len, target_length], the bias for decoder self-attention
layer.
attention_bias: A tensor with shape [batch_size, 1, 1, input_length],
the bias for encoder-decoder attention layer.
training: A bool, whether in training mode or not.
cache: (Used for fast decoding) A nested dictionary storing previous
decoder self-attention values. The items are:
{layer_n: {"k": tensor with shape [batch_size, i, key_channels],
"v": tensor with shape [batch_size, i, value_channels]},
{layer_n: {"k": A tensor with shape [batch_size, i, key_channels],
"v": A tensor with shape [batch_size, i, value_channels]},
...}
decode_loop_step: An integer, the step number of the decoding loop. Used
only for autoregressive inference on TPU.
Returns:
Output of decoder layer stack.
......@@ -540,7 +584,8 @@ class DecoderStack(tf.keras.layers.Layer):
decoder_inputs,
decoder_self_attention_bias,
training=training,
cache=layer_cache)
cache=layer_cache,
decode_loop_step=decode_loop_step)
with tf.name_scope("encdec_attention"):
decoder_inputs = enc_dec_attention_layer(
decoder_inputs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment