Commit bf748370 authored by Nimit Nigania

Merge remote-tracking branch 'upstream/master'

parents 7c732da7 0d2c2e01
@@ -65,8 +65,8 @@ def prepare_raw_data(flag_obj):
   data_processing_params = {
       "train_epochs": flag_obj.num_train_epochs,
-      "batch_size": flag_obj.prebatch_size,
-      "eval_batch_size": flag_obj.prebatch_size,
+      "batch_size": flag_obj.train_prebatch_size,
+      "eval_batch_size": flag_obj.eval_prebatch_size,
       "batches_per_step": 1,
       "stream_files": True,
       "num_neg": flag_obj.num_negative_samples,
......
@@ -117,7 +117,10 @@ def create_dataset_from_data_producer(producer, params):
   return train_input_dataset, eval_input_dataset


-def create_ncf_input_data(params, producer=None, input_meta_data=None):
+def create_ncf_input_data(params,
+                          producer=None,
+                          input_meta_data=None,
+                          strategy=None):
   """Creates NCF training/evaluation dataset.

   Args:
@@ -128,6 +131,9 @@ def create_ncf_input_data(params, producer=None, input_meta_data=None):
     input_meta_data: A dictionary of input metadata to be used when reading data
       from tf record files. Must be specified when params["train_input_dataset"]
       is specified.
+    strategy: Distribution strategy used for distributed training. If specified,
+      used to assert that evaluation batch size is correctly a multiple of
+      total number of devices used.

   Returns:
     (training dataset, evaluation dataset, train steps per epoch,
@@ -136,6 +142,17 @@ def create_ncf_input_data(params, producer=None, input_meta_data=None):
   Raises:
     ValueError: If data is being generated online for when using TPU's.
   """
+  # NCF evaluation metric calculation logic assumes that evaluation data
+  # sample size are in multiples of (1 + number of negative samples in
+  # evaluation) for each device. As so, evaluation batch size must be a
+  # multiple of (number of replicas * (1 + number of negative samples)).
+  num_devices = strategy.num_replicas_in_sync if strategy else 1
+  if (params["eval_batch_size"] % (num_devices *
+                                   (1 + rconst.NUM_EVAL_NEGATIVES))):
+    raise ValueError("Evaluation batch size must be divisible by {} "
+                     "times {}".format(num_devices,
+                                       (1 + rconst.NUM_EVAL_NEGATIVES)))
+
   if params["train_dataset_path"]:
     assert params["eval_dataset_path"]
......
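A note on the batch-size check introduced above: it is what forces the benchmark configurations later in this commit to pin `eval_batch_size`. The sketch below restates the rule outside the input pipeline; the `NUM_EVAL_NEGATIVES = 999` constant and the example batch sizes are illustrative assumptions, not values taken from this diff.

```python
# Minimal sketch of the divisibility rule enforced by create_ncf_input_data.
# Assumption: 999 stands in for rconst.NUM_EVAL_NEGATIVES; the batch sizes
# below are made up for illustration.
NUM_EVAL_NEGATIVES = 999


def check_eval_batch_size(eval_batch_size, num_replicas):
  """Raises ValueError unless every replica gets whole evaluation users."""
  samples_per_user = 1 + NUM_EVAL_NEGATIVES  # one positive + the negatives
  if eval_batch_size % (num_replicas * samples_per_user):
    raise ValueError("Evaluation batch size must be divisible by {} "
                     "times {}".format(num_replicas, samples_per_user))


check_eval_batch_size(160000, num_replicas=8)  # OK: 160000 = 8 * 1000 * 20
check_eval_batch_size(100000, num_replicas=8)  # raises ValueError
```

If `NUM_EVAL_NEGATIVES` really is 999, a value such as 160000 divides evenly for 1, 2, or 8 replicas, which would explain the `FLAGS.eval_batch_size = 160000` settings added in the benchmarks below.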
@@ -121,7 +121,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     """
     self._run_and_report_benchmark(hr_at_10_min=0.61)

-  def _run_and_report_benchmark(self, hr_at_10_min=0.630, hr_at_10_max=0.640):
+  def _run_and_report_benchmark(self, hr_at_10_min=0.630, hr_at_10_max=0.645):
     """Run test and report results.

     Note: Target is 0.635, but some runs are below that level. Until we have
@@ -203,6 +203,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     self._setup()
     FLAGS.early_stopping = True
     FLAGS.num_gpus = 2
+    FLAGS.eval_batch_size = 160000
     self._run_and_report_benchmark()

   def benchmark_2_gpus_ctl_early_stop(self):
@@ -211,6 +212,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     FLAGS.keras_use_ctl = True
     FLAGS.early_stopping = True
     FLAGS.num_gpus = 2
+    FLAGS.eval_batch_size = 160000
     self._run_and_report_benchmark()

   #############################################
@@ -287,6 +289,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     FLAGS.num_gpus = 8
     FLAGS.train_epochs = 17
     FLAGS.batch_size = 1048576
+    FLAGS.eval_batch_size = 160000
     FLAGS.learning_rate = 0.0045
     FLAGS.beta1 = 0.25
     FLAGS.beta2 = 0.5
@@ -299,6 +302,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     FLAGS.num_gpus = 8
     FLAGS.train_epochs = 17
     FLAGS.batch_size = 1048576
+    FLAGS.eval_batch_size = 160000
     FLAGS.learning_rate = 0.0045
     FLAGS.beta1 = 0.25
     FLAGS.beta2 = 0.5
@@ -306,19 +310,6 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     FLAGS.force_v2_in_keras_compile = False
     self._run_and_report_benchmark_mlperf_like()

-  def benchmark_xla_8_gpu_mlperf_like(self):
-    """8 GPU using keras fit/compile with XLA."""
-    self._setup()
-    FLAGS.num_gpus = 8
-    FLAGS.enable_xla = True
-    FLAGS.train_epochs = 17
-    FLAGS.batch_size = 1048576
-    FLAGS.learning_rate = 0.0045
-    FLAGS.beta1 = 0.25
-    FLAGS.beta2 = 0.5
-    FLAGS.epsilon = 1e-8
-    self._run_and_report_benchmark_mlperf_like()
-
   def benchmark_8_gpu_ctl_mlperf_like(self):
     """8 GPU using CTL."""
     self._setup()
@@ -326,20 +317,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     FLAGS.num_gpus = 8
     FLAGS.train_epochs = 17
     FLAGS.batch_size = 1048576
+    FLAGS.eval_batch_size = 160000
-    FLAGS.learning_rate = 0.0045
-    FLAGS.beta1 = 0.25
-    FLAGS.beta2 = 0.5
-    FLAGS.epsilon = 1e-8
-    self._run_and_report_benchmark_mlperf_like()
-
-  def benchmark_xla_8_gpu_ctl_mlperf_like(self):
-    """8 GPU using CTL with XLA."""
-    self._setup()
-    FLAGS.keras_use_ctl = True
-    FLAGS.enable_xla = True
-    FLAGS.num_gpus = 8
-    FLAGS.train_epochs = 17
-    FLAGS.batch_size = 1048576
     FLAGS.learning_rate = 0.0045
     FLAGS.beta1 = 0.25
     FLAGS.beta2 = 0.5
@@ -360,6 +338,7 @@ class NCFKerasSynth(NCFKerasBenchmarkBase):
     default_flags['num_gpus'] = 1
     default_flags['train_epochs'] = 8
     default_flags['batch_size'] = 99000
+    default_flags['eval_batch_size'] = 160000
     default_flags['learning_rate'] = 0.00382059
     default_flags['beta1'] = 0.783529
     default_flags['beta2'] = 0.909003
......
@@ -64,12 +64,20 @@ class MetricLayer(tf.keras.layers.Layer):

   def __init__(self, params):
     super(MetricLayer, self).__init__()
     self.params = params
-    self.metric = tf.keras.metrics.Mean(name=rconst.HR_METRIC_NAME)

-  def call(self, inputs):
+  def call(self, inputs, training=False):
     logits, dup_mask = inputs
-    in_top_k, metric_weights = metric_fn(logits, dup_mask, self.params)
-    self.add_metric(self.metric(in_top_k, sample_weight=metric_weights))
+
+    if training:
+      hr_sum = 0.0
+      hr_count = 0.0
+    else:
+      metric, metric_weights = metric_fn(logits, dup_mask, self.params)
+      hr_sum = tf.reduce_sum(metric * metric_weights)
+      hr_count = tf.reduce_sum(metric_weights)
+
+    self.add_metric(hr_sum, name="hr_sum", aggregation="mean")
+    self.add_metric(hr_count, name="hr_count", aggregation="mean")
     return logits
@@ -249,7 +257,7 @@ def run_ncf(_):
     (train_input_dataset, eval_input_dataset,
      num_train_steps, num_eval_steps) = \
       (ncf_input_pipeline.create_ncf_input_data(
-          params, producer, input_meta_data))
+          params, producer, input_meta_data, strategy))
     steps_per_epoch = None if generate_input_online else num_train_steps

   with distribution_utils.get_strategy_scope(strategy):
@@ -295,11 +303,19 @@ def run_ncf(_):
       logging.info("Training done. Start evaluating")

-      eval_results = keras_model.evaluate(
+      eval_loss_and_metrics = keras_model.evaluate(
           eval_input_dataset, steps=num_eval_steps, verbose=2)

       logging.info("Keras evaluation is done.")

+      # Keras evaluate() API returns scalar loss and metric values from
+      # evaluation as a list. Here, the returned list would contain
+      # [evaluation loss, hr sum, hr count].
+      eval_hit_rate = eval_loss_and_metrics[1] / eval_loss_and_metrics[2]
+
+      # Format evaluation result into [eval loss, eval hit accuracy].
+      eval_results = [eval_loss_and_metrics[0], eval_hit_rate]
+
       if history and history.history:
         train_history = history.history
         train_loss = train_history["loss"][-1]
......
@@ -195,20 +195,20 @@ class NcfTest(tf.test.TestCase):
   @mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
   def test_end_to_end_estimator(self):
     integration.run_synthetic(
-        ncf_estimator_main.main, tmp_root=self.get_temp_dir(), max_train=None,
+        ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
         extra_flags=self._BASE_END_TO_END_FLAGS)

   @unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
   @mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
   def test_end_to_end_estimator_mlperf(self):
     integration.run_synthetic(
-        ncf_estimator_main.main, tmp_root=self.get_temp_dir(), max_train=None,
+        ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
         extra_flags=self._BASE_END_TO_END_FLAGS + ['-ml_perf', 'True'])

   @mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
   def test_end_to_end_keras_no_dist_strat(self):
     integration.run_synthetic(
-        ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
+        ncf_keras_main.main, tmp_root=self.get_temp_dir(),
         extra_flags=self._BASE_END_TO_END_FLAGS +
         ['-distribution_strategy', 'off'])
@@ -216,7 +216,7 @@ class NcfTest(tf.test.TestCase):
   @unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
   def test_end_to_end_keras_dist_strat(self):
     integration.run_synthetic(
-        ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
+        ncf_keras_main.main, tmp_root=self.get_temp_dir(),
         extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '0'])

   @mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@@ -226,7 +226,7 @@ class NcfTest(tf.test.TestCase):
              ['-num_gpus', '0'] +
              ['-keras_use_ctl', 'True'])
     integration.run_synthetic(
-        ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
+        ncf_keras_main.main, tmp_root=self.get_temp_dir(),
         extra_flags=flags)

   @mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@@ -238,7 +238,7 @@ class NcfTest(tf.test.TestCase):
          format(1, context.num_gpus()))
     integration.run_synthetic(
-        ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
+        ncf_keras_main.main, tmp_root=self.get_temp_dir(),
         extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '1'])

   @mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
@@ -250,7 +250,7 @@ class NcfTest(tf.test.TestCase):
          format(2, context.num_gpus()))
     integration.run_synthetic(
-        ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
+        ncf_keras_main.main, tmp_root=self.get_temp_dir(),
         extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '2'])


if __name__ == "__main__":
......
# ResNet in TensorFlow

* For the Keras version of the ResNet model, see
-  [`official/resnet/keras`](keras).
+  [`official/vision/image_classification`](../vision/image_classification).
* For the Keras custom training loop version, see
  [`official/resnet/ctl`](ctl).
* For the Estimator version, see [`official/r1/resnet`](../r1/resnet).
\ No newline at end of file
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
@@ -27,3 +27,6 @@ def define_ctl_flags():
   flags.DEFINE_boolean(name='use_tf_function', default=True,
                        help='Wrap the train and test step inside a '
                        'tf.function.')
+  flags.DEFINE_boolean(name='single_l2_loss_op', default=False,
+                       help='Calculate L2_loss on concatenated weights, '
+                       'instead of using Keras per-layer L2 loss.')
@@ -22,7 +22,7 @@ import time
from absl import flags
import tensorflow as tf

-from official.resnet.keras import keras_common
+from official.vision.image_classification import common
from official.resnet.ctl import ctl_imagenet_main
from official.resnet.ctl import ctl_common
from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
@@ -118,7 +118,7 @@ class Resnet50CtlAccuracy(CtlBenchmark):
     flag_methods = [
         ctl_common.define_ctl_flags,
-        keras_common.define_keras_flags
+        common.define_keras_flags
     ]
     self.data_dir = os.path.join(root_data_dir, 'imagenet')
@@ -162,7 +162,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
   def __init__(self, output_dir=None, default_flags=None):
     flag_methods = [
         ctl_common.define_ctl_flags,
-        keras_common.define_keras_flags
+        common.define_keras_flags
     ]

     super(Resnet50CtlBenchmarkBase, self).__init__(
@@ -215,6 +215,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_eager')
     FLAGS.batch_size = 64
     FLAGS.use_tf_function = False
+    FLAGS.single_l2_loss_op = True
     self._run_and_report_benchmark()

   def benchmark_8_gpu(self):
......
@@ -24,10 +24,10 @@ from absl import logging
import tensorflow as tf

from official.resnet.ctl import ctl_common
-from official.resnet.keras import imagenet_preprocessing
-from official.resnet.keras import keras_common
-from official.resnet.keras import keras_imagenet_main
-from official.resnet.keras import resnet_model
+from official.vision.image_classification import imagenet_preprocessing
+from official.vision.image_classification import common
+from official.vision.image_classification import resnet_imagenet_main
+from official.vision.image_classification import resnet_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
@@ -73,7 +73,7 @@ def get_input_dataset(flags_obj, strategy):
  """Returns the test and train input datasets."""
  dtype = flags_core.get_tf_dtype(flags_obj)
  if flags_obj.use_synthetic_data:
-    input_fn = keras_common.get_synth_input_fn(
+    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,
@@ -137,6 +137,10 @@ def run(flags_obj):
  Returns:
    Dictionary of training and eval stats.
  """
+  keras_utils.set_session_config(
+      enable_eager=flags_obj.enable_eager,
+      enable_xla=flags_obj.enable_xla)
+
  dtype = flags_core.get_tf_dtype(flags_obj)

  # TODO(anj-s): Set data_format without using Keras.
@@ -163,10 +167,11 @@ def run(flags_obj):
  with strategy_scope:
    model = resnet_model.resnet50(
        num_classes=imagenet_preprocessing.NUM_CLASSES,
-        dtype=dtype, batch_size=flags_obj.batch_size)
+        dtype=dtype, batch_size=flags_obj.batch_size,
+        use_l2_regularizer=not flags_obj.single_l2_loss_op)

    optimizer = tf.keras.optimizers.SGD(
-        learning_rate=keras_common.BASE_LEARNING_RATE, momentum=0.9,
+        learning_rate=common.BASE_LEARNING_RATE, momentum=0.9,
        nesterov=True)

    training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
@@ -175,6 +180,8 @@ def run(flags_obj):
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'test_accuracy', dtype=tf.float32)

+    trainable_variables = model.trainable_variables
+
    def train_step(train_ds_inputs):
      """Training StepFn."""
      def step_fn(inputs):
@@ -185,13 +192,22 @@ def run(flags_obj):
        prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits)
-        loss1 = tf.reduce_sum(prediction_loss) * (1.0/ flags_obj.batch_size)
-        loss2 = (tf.reduce_sum(model.losses) /
-                 tf.distribute.get_strategy().num_replicas_in_sync)
-        loss = loss1 + loss2
-
-        grads = tape.gradient(loss, model.trainable_variables)
-        optimizer.apply_gradients(zip(grads, model.trainable_variables))
+        loss = tf.reduce_sum(prediction_loss) * (1.0/ flags_obj.batch_size)
+        num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
+
+        if flags_obj.single_l2_loss_op:
+          filtered_variables = [
+              tf.reshape(v, (-1,))
+              for v in trainable_variables
+              if 'bn' not in v.name
+          ]
+          l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
+              tf.concat(filtered_variables, axis=0))
+          loss += (l2_loss / num_replicas)
+        else:
+          loss += (tf.reduce_sum(model.losses) / num_replicas)
+
+        grads = tape.gradient(loss, trainable_variables)
+        optimizer.apply_gradients(zip(grads, trainable_variables))
        training_accuracy.update_state(labels, logits)
        return loss
@@ -232,7 +248,7 @@ def run(flags_obj):
      training_accuracy.reset_states()

      for step in range(train_steps):
-        optimizer.lr = keras_imagenet_main.learning_rate_schedule(
+        optimizer.lr = resnet_imagenet_main.learning_rate_schedule(
            epoch, step, train_steps, flags_obj.batch_size)
        time_callback.on_batch_begin(step+epoch*train_steps)
@@ -281,6 +297,8 @@ def main(_):

if __name__ == '__main__':
  logging.set_verbosity(logging.INFO)
-  keras_common.define_keras_flags()
+  common.define_keras_flags()
  ctl_common.define_ctl_flags()
+  flags.adopt_module_key_flags(keras_common)
+  flags.adopt_module_key_flags(ctl_common)
  absl_app.run(main)
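One remark on the `single_l2_loss_op` path added in the training step above: applying weight decay once to a concatenation of the non-BatchNorm weights yields the same penalty as summing Keras-style per-layer L2 terms, it just issues one op instead of many. A self-contained sketch of that equivalence follows; the `1e-4` decay value and the toy variables are assumptions for illustration only.

```python
import tensorflow as tf

L2_WEIGHT_DECAY = 1e-4  # illustrative; the real constant lives in resnet_model

# Toy stand-ins for model.trainable_variables; BatchNorm variables are
# skipped, mirroring the `if 'bn' not in v.name` filter above.
weights = [
    tf.Variable(tf.ones([3, 3]), name="conv1/kernel"),
    tf.Variable(tf.ones([3]), name="bn1/gamma"),
    tf.Variable(tf.ones([3, 2]), name="fc/kernel"),
]

# Per-layer style: one L2 penalty per variable, then summed.
per_layer = tf.add_n([
    L2_WEIGHT_DECAY * tf.reduce_sum(tf.square(v))
    for v in weights if "bn" not in v.name
])

# Single-op style: flatten, concatenate, one l2_loss call.
# tf.nn.l2_loss computes sum(x ** 2) / 2, hence the factor of 2.
flat = [tf.reshape(v, (-1,)) for v in weights if "bn" not in v.name]
single_op = L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(tf.concat(flat, axis=0))

tf.debugging.assert_near(per_layer, single_op)  # identical penalties
```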
@@ -25,8 +25,8 @@ from tensorflow.python.eager import context
from tensorflow.python.platform import googletest
from official.resnet.ctl import ctl_common
from official.resnet.ctl import ctl_imagenet_main
-from official.resnet.keras import imagenet_preprocessing
-from official.resnet.keras import keras_common
+from official.vision.image_classification import imagenet_preprocessing
+from official.vision.image_classification import common
from official.utils.misc import keras_utils
from official.utils.testing import integration
@@ -49,7 +49,7 @@ class CtlImagenetTest(googletest.TestCase):
  @classmethod
  def setUpClass(cls):  # pylint: disable=invalid-name
    super(CtlImagenetTest, cls).setUpClass()
-    keras_common.define_keras_flags()
+    common.define_keras_flags()
    ctl_common.define_ctl_flags()

  def setUp(self):
......
@@ -41,8 +41,8 @@ class ShakespeareBenchmarkBase(PerfZeroBenchmark):
        flag_methods=[shakespeare_main.define_flags])

  def _run_and_report_benchmark(self,
-                                top_1_train_min=0.923,
-                                top_1_train_max=0.93,
+                                top_1_train_min=0.91,
+                                top_1_train_max=0.94,
                                warmup=1,
                                log_steps=100):
    """Report benchmark results by writing to local protobuf file.
@@ -208,21 +208,6 @@ class ShakespeareAccuracy(ShakespeareBenchmarkBase):
    FLAGS.model_dir = ''
    self._run_and_report_benchmark()

-  def benchmark_xla_8_gpu(self):
-    """Benchmark 8 gpu w/xla.
-
-    This is test is for accuracy not scaling. The batch-size is not scaled to
-    the number of gpus.
-    """
-    self._setup()
-    FLAGS.num_gpus = 8
-    FLAGS.training_data = self.train_data
-    FLAGS.batch_size = 64
-    FLAGS.train_epochs = 43
-    FLAGS.model_dir = ''
-    FLAGS.enable_xla = True
-    self._run_and_report_benchmark()
-

class ShakespeareKerasBenchmarkReal(ShakespeareBenchmarkBase):
  """Benchmark accuracy tests."""
......
@@ -79,8 +79,41 @@ class _StateKeys(object):
class SequenceBeamSearch(object):
  """Implementation of beam search loop."""

-  def __init__(self, symbols_to_logits_fn, vocab_size, batch_size,
-               beam_size, alpha, max_decode_length, eos_id, dtype=tf.float32):
+  def __init__(self,
+               symbols_to_logits_fn,
+               vocab_size,
+               batch_size,
+               beam_size,
+               alpha,
+               max_decode_length,
+               eos_id,
+               padded_decode,
+               dtype=tf.float32):
"""Initialize sequence beam search.
Args:
symbols_to_logits_fn: A function to provide logits, which is the
interface to the Transformer model. The passed in arguments are:
ids -> A tensor with shape [batch_size * beam_size, index].
index -> A scalar.
cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
The function must return a tuple of logits and the updated cache:
logits -> A tensor with shape [batch * beam_size, vocab_size].
updated cache -> A nested dictionary with the same structure as the
input cache.
vocab_size: An integer, the size of the vocabulary, used for topk
computation.
batch_size: An integer, the decode batch size.
beam_size: An integer, number of beams for beam search.
alpha: A float, defining the strength of length normalization.
max_decode_length: An integer, the maximum number of steps to decode
a sequence.
eos_id: An integer. ID of end of sentence token.
padded_decode: A bool, indicating if max_sequence_length padding is used
for beam search.
dtype: A tensorflow data type used for score computation. The default is
tf.float32.
"""
    self.symbols_to_logits_fn = symbols_to_logits_fn
    self.vocab_size = vocab_size
    self.batch_size = batch_size
@@ -88,6 +121,7 @@ class SequenceBeamSearch(object):
    self.alpha = alpha
    self.max_decode_length = max_decode_length
    self.eos_id = eos_id
+    self.padded_decode = padded_decode
    self.dtype = tf.as_dtype(dtype)

  def search(self, initial_ids, initial_cache):
@@ -140,6 +174,8 @@ class SequenceBeamSearch(object):
    # Create alive sequence with shape [batch_size, beam_size, 1]
    alive_seq = _expand_to_beam_size(initial_ids, self.beam_size)
    alive_seq = tf.expand_dims(alive_seq, axis=2)
+    if self.padded_decode:
+      alive_seq = tf.tile(alive_seq, [1, 1, self.max_decode_length + 1])

    # Create tensor for storing initial log probabilities.
    # Assume initial_ids are prob 1.0
@@ -178,16 +214,44 @@ class SequenceBeamSearch(object):
    # 1) the dimension's value is a tensor that remains the same but may
    #    depend on the input sequence to the model (e.g. batch size).
    # 2) the dimension may have different values on different iterations.
-    state_shape_invariants = {
-        _StateKeys.CUR_INDEX: tf.TensorShape([]),
-        _StateKeys.ALIVE_SEQ: tf.TensorShape([None, self.beam_size, None]),
-        _StateKeys.ALIVE_LOG_PROBS: tf.TensorShape([None, self.beam_size]),
-        _StateKeys.ALIVE_CACHE: nest.map_structure(
-            _get_shape_keep_last_dim, alive_cache),
-        _StateKeys.FINISHED_SEQ: tf.TensorShape([None, self.beam_size, None]),
-        _StateKeys.FINISHED_SCORES: tf.TensorShape([None, self.beam_size]),
-        _StateKeys.FINISHED_FLAGS: tf.TensorShape([None, self.beam_size])
-    }
+    if self.padded_decode:
+      state_shape_invariants = {
+          _StateKeys.CUR_INDEX:
+              tf.TensorShape([]),
+          _StateKeys.ALIVE_SEQ:
+              tf.TensorShape(
+                  [self.batch_size, self.beam_size,
+                   self.max_decode_length + 1]),
+          _StateKeys.ALIVE_LOG_PROBS:
+              tf.TensorShape([self.batch_size, self.beam_size]),
+          _StateKeys.ALIVE_CACHE:
+              nest.map_structure(_get_shape, alive_cache),
+          _StateKeys.FINISHED_SEQ:
+              tf.TensorShape(
+                  [self.batch_size, self.beam_size,
+                   self.max_decode_length + 1]),
+          _StateKeys.FINISHED_SCORES:
+              tf.TensorShape([self.batch_size, self.beam_size]),
+          _StateKeys.FINISHED_FLAGS:
+              tf.TensorShape([self.batch_size, self.beam_size])
+      }
+    else:
+      state_shape_invariants = {
+          _StateKeys.CUR_INDEX:
+              tf.TensorShape([]),
+          _StateKeys.ALIVE_SEQ:
+              tf.TensorShape([None, self.beam_size, None]),
+          _StateKeys.ALIVE_LOG_PROBS:
+              tf.TensorShape([None, self.beam_size]),
+          _StateKeys.ALIVE_CACHE:
+              nest.map_structure(_get_shape_keep_last_dim, alive_cache),
+          _StateKeys.FINISHED_SEQ:
+              tf.TensorShape([None, self.beam_size, None]),
+          _StateKeys.FINISHED_SCORES:
+              tf.TensorShape([None, self.beam_size]),
+          _StateKeys.FINISHED_FLAGS:
+              tf.TensorShape([None, self.beam_size])
+      }

    return state, state_shape_invariants
@@ -297,7 +361,12 @@ class SequenceBeamSearch(object):
    # Get logits for the next candidate IDs for the alive sequences. Get the new
    # cache values at the same time.
-    flat_ids = _flatten_beam_dim(alive_seq)  # [batch_size * beam_size]
+    if self.padded_decode:
+      flat_ids = tf.reshape(
+          tf.slice(alive_seq, [0, 0, i], [self.batch_size, self.beam_size, 1]),
+          [self.batch_size * self.beam_size, -1])
+    else:
+      flat_ids = _flatten_beam_dim(alive_seq)  # [batch_size * beam_size]
    flat_cache = nest.map_structure(_flatten_beam_dim, alive_cache)

    flat_logits, flat_cache = self.symbols_to_logits_fn(flat_ids, i, flat_cache)
@@ -331,8 +400,13 @@ class SequenceBeamSearch(object):

    # Append the most probable IDs to the topk sequences
    topk_ids = topk_indices % self.vocab_size
-    topk_ids = tf.expand_dims(topk_ids, axis=2)
-    topk_seq = tf.concat([topk_seq, topk_ids], axis=2)
+    if self.padded_decode:
+      topk_seq = tf.transpose(topk_seq, perm=[2, 0, 1])
+      topk_seq = tf.tensor_scatter_update(topk_seq, [i + 1], topk_ids)
+      topk_seq = tf.transpose(topk_seq, perm=[1, 2, 0])
+    else:
+      topk_ids = tf.expand_dims(topk_ids, axis=2)
+      topk_seq = tf.concat([topk_seq, topk_ids], axis=2)
    return topk_seq, topk_log_probs, new_cache

  def _get_new_alive_state(self, new_seq, new_log_probs, new_cache):
@@ -388,9 +462,12 @@ class SequenceBeamSearch(object):
    # First append a column of 0-ids to finished_seq to increment the length.
    # New shape of finished_seq: [batch_size, beam_size, i + 1]
-    finished_seq = tf.concat(
-        [finished_seq,
-         tf.zeros([self.batch_size, self.beam_size, 1], tf.int32)], axis=2)
+    if not self.padded_decode:
+      finished_seq = tf.concat([
+          finished_seq,
+          tf.zeros([self.batch_size, self.beam_size, 1], tf.int32)
+      ],
+                               axis=2)

    # Calculate new seq scores from log probabilities.
    length_norm = _length_normalization(self.alpha, i + 1, dtype=self.dtype)
@@ -420,34 +497,43 @@

def sequence_beam_search(
    symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size,
-    alpha, max_decode_length, eos_id):
+    alpha, max_decode_length, eos_id, padded_decode=False):
  """Search for sequence of subtoken ids with the largest probability.

  Args:
    symbols_to_logits_fn: A function that takes in ids, index, and cache as
      arguments. The passed in arguments will have shape:
-        ids -> [batch_size * beam_size, index]
-        index -> [] (scalar)
-        cache -> nested dictionary of tensors [batch_size * beam_size, ...]
-      The function must return logits and new cache.
-        logits -> [batch * beam_size, vocab_size]
-        new cache -> same shape/structure as inputted cache
-    initial_ids: Starting ids for each batch item.
-      int32 tensor with shape [batch_size]
-    initial_cache: dict containing starting decoder variables information
-    vocab_size: int size of tokens
-    beam_size: int number of beams
-    alpha: float defining the strength of length normalization
-    max_decode_length: maximum length to decoded sequence
-    eos_id: int id of eos token, used to determine when a sequence has finished
+        ids -> A tensor with shape [batch_size * beam_size, index].
+        index -> A scalar.
+        cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
+      The function must return a tuple of logits and new cache:
+        logits -> A tensor with shape [batch * beam_size, vocab_size].
+        new cache -> A nested dictionary with the same shape/structure as the
+          inputted cache.
+    initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
+      each batch item.
+    initial_cache: A dictionary, containing starting decoder variables
+      information.
+    vocab_size: An integer, the size of the vocabulary, used for topk
+      computation.
+    beam_size: An integer, the number of beams.
+    alpha: A float, defining the strength of length normalization.
+    max_decode_length: An integer, the maximum length to decoded a sequence.
+    eos_id: An integer, ID of eos token, used to determine when a sequence has
+      finished.
+    padded_decode: A bool, indicating if max_sequence_length padding is used
+      for beam search.

  Returns:
    Top decoded sequences [batch_size, beam_size, max_decode_length]
    sequence scores [batch_size, beam_size]
  """
-  batch_size = tf.shape(initial_ids)[0]
+  batch_size = (
+      initial_ids.shape.as_list()[0] if padded_decode else
+      tf.shape(initial_ids)[0])
  sbs = SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
-                           beam_size, alpha, max_decode_length, eos_id)
+                           beam_size, alpha, max_decode_length, eos_id,
+                           padded_decode)
  return sbs.search(initial_ids, initial_cache)
@@ -502,6 +588,11 @@ def _get_shape_keep_last_dim(tensor):
  return tf.TensorShape(shape_list)


+def _get_shape(tensor):
+  """Return the shape of the input tensor."""
+  return tf.TensorShape(_shape_list(tensor))
+
+
def _flatten_beam_dim(tensor):
  """Reshapes first two dimensions in to single dimension.
......
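All of the `padded_decode` branches above serve one goal: keeping every tensor in the decoding loop at a static shape so the loop can be compiled for TPU. Instead of growing `alive_seq` by one column per step, the buffer is pre-tiled to `max_decode_length + 1` and each step's tokens are written into column `i + 1` in place. The following sketch illustrates just that write pattern outside the class; shapes and values are made up, and `tf.tensor_scatter_nd_update` is used here as the current public spelling of the scatter op called in the diff.

```python
import tensorflow as tf

batch_size, beam_size, max_decode_length = 2, 3, 5

# Preallocated sequence buffer, as with padded_decode=True above.
alive_seq = tf.zeros([batch_size, beam_size, max_decode_length + 1], tf.int32)

# Pretend these are the token ids chosen for every beam at loop step i.
i = 0
topk_ids = tf.constant([[11, 12, 13], [21, 22, 23]], tf.int32)

# Write the new column without changing the tensor's static shape:
# move the time axis to the front, scatter into row i + 1, move it back.
seq_t = tf.transpose(alive_seq, perm=[2, 0, 1])
seq_t = tf.tensor_scatter_nd_update(seq_t, [[i + 1]], [topk_ids])
alive_seq = tf.transpose(seq_t, perm=[1, 2, 0])

print(alive_seq.shape)  # (2, 3, 6): the shape is identical at every step
```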
@@ -32,6 +32,7 @@ from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order

+from official.r1.utils import export
from official.transformer import compute_bleu
from official.transformer import translate
from official.transformer.model import model_params
@@ -41,7 +42,6 @@ from official.transformer.utils import metrics
from official.transformer.utils import schedule
from official.transformer.utils import tokenizer
from official.utils.accelerator import tpu as tpu_util
-from official.utils.export import export
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.logs import logger
@@ -56,7 +56,7 @@ PARAMS_MAP = {
DEFAULT_TRAIN_EPOCHS = 10
-INF = int(1e9)
+INF = 1000000000  # 1e9
BLEU_DIR = "bleu"

# Dictionary containing tensors that are logged by the logging hooks. Each item
......
@@ -102,51 +102,67 @@ class Attention(tf.keras.layers.Layer):
    x = tf.transpose(x, [0, 2, 1, 3])  # --> [batch, length, num_heads, depth]
    return tf.reshape(x, [batch_size, length, self.hidden_size])

-  def call(self, x, y, bias, training, cache=None):
+  def call(self, x, y, bias, training, cache=None, decode_loop_step=None):
    """Apply attention mechanism to x and y.

    Args:
-      x: a tensor with shape [batch_size, length_x, hidden_size]
-      y: a tensor with shape [batch_size, length_y, hidden_size]
-      bias: attention bias that will be added to the result of the dot product.
-      training: boolean, whether in training mode or not.
-      cache: (Used during prediction) dictionary with tensors containing results
-        of previous attentions. The dictionary must have the items:
+      x: A tensor with shape [batch_size, length_x, hidden_size].
+      y: A tensor with shape [batch_size, length_y, hidden_size].
+      bias: A bool, the attention bias that will be added to the result of the
+        dot product.
+      training: A bool, whether in training mode or not.
+      cache: (Used during prediction) A dictionary with tensors containing
+        results of previous attentions. The dictionary must have the items:
            {"k": tensor with shape [batch_size, i, key_channels],
             "v": tensor with shape [batch_size, i, value_channels]}
        where i is the current decoded length.
+      decode_loop_step: An integer, step number of the decoding loop. Used only
+        for autoregressive inference on TPU.

    Returns:
      Attention layer output with shape [batch_size, length_x, hidden_size]
    """
-    # Linearly project the query (q), key (k) and value (v) using different
-    # learned projections. This is in preparation of splitting them into
-    # multiple heads. Multi-head attention uses multiple queries, keys, and
-    # values rather than regular attention (which uses a single q, k, v).
-    q = self.q_dense_layer(x)
-    k = self.k_dense_layer(y)
-    v = self.v_dense_layer(y)
+    # Linearly project the query, key and value using different learned
+    # projections. This is in preparation of splitting them into multiple
+    # heads. Multi-head attention uses multiple queries, keys, and values
+    # rather than regular attention (which uses a single query, key, value).
+    query = self.q_dense_layer(x)
+    key = self.k_dense_layer(y)
+    value = self.v_dense_layer(y)

    if cache is not None:
      # Combine cached keys and values with new keys and values.
-      k = tf.concat([tf.cast(cache["k"], k.dtype), k], axis=1)
-      v = tf.concat([tf.cast(cache["v"], k.dtype), v], axis=1)
+      if decode_loop_step is not None:
+        cache_k_shape = cache["k"].shape.as_list()
+        indices = tf.reshape(
+            tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=key.dtype),
+            [1, cache_k_shape[1], 1])
+        key = cache["k"] + key * indices
+        cache_v_shape = cache["v"].shape.as_list()
+        indices = tf.reshape(
+            tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=value.dtype),
+            [1, cache_v_shape[1], 1])
+        value = cache["v"] + value * indices
+      else:
+        key = tf.concat([tf.cast(cache["k"], key.dtype), key], axis=1)
+        value = tf.concat([tf.cast(cache["v"], value.dtype), value], axis=1)

      # Update cache
-      cache["k"] = k
-      cache["v"] = v
+      cache["k"] = key
+      cache["v"] = value

-    # Split q, k, v into heads.
-    q = self.split_heads(q)
-    k = self.split_heads(k)
-    v = self.split_heads(v)
+    # Split query, key, value into heads.
+    query = self.split_heads(query)
+    key = self.split_heads(key)
+    value = self.split_heads(value)

-    # Scale q to prevent the dot product between q and k from growing too large.
+    # Scale query to prevent the dot product between query and key from growing
+    # too large.
    depth = (self.hidden_size // self.num_heads)
-    q *= depth ** -0.5
+    query *= depth ** -0.5

    # Calculate dot product attention
-    logits = tf.matmul(q, k, transpose_b=True)
+    logits = tf.matmul(query, key, transpose_b=True)
    logits += bias
    # Note that softmax internally performs math operations using float32
    # for numeric stability. When training with float16, we keep the input
@@ -154,7 +170,7 @@ class Attention(tf.keras.layers.Layer):
    weights = tf.nn.softmax(logits, name="attention_weights")
    if training:
      weights = tf.nn.dropout(weights, rate=self.attention_dropout)
-    attention_output = tf.matmul(weights, v)
+    attention_output = tf.matmul(weights, value)

    # Recombine heads --> [batch_size, length, hidden_size]
    attention_output = self.combine_heads(attention_output)
@@ -167,5 +183,6 @@ class Attention(tf.keras.layers.Layer):
class SelfAttention(Attention):
  """Multiheaded self-attention layer."""

-  def call(self, x, bias, training, cache=None):
-    return super(SelfAttention, self).call(x, x, bias, training, cache)
+  def call(self, x, bias, training, cache=None, decode_loop_step=None):
+    return super(SelfAttention, self).call(x, x, bias, training, cache,
+                                           decode_loop_step)
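The `decode_loop_step` branch above sidesteps `tf.concat` (which would change the cache's length every step) by adding the new key/value into a zero-initialized, preallocated slot selected with a one-hot mask over the time axis. A small standalone sketch of that update; the sizes are illustrative.

```python
import tensorflow as tf

batch_beam, max_decode_length, channels = 2, 4, 3

# Preallocated key cache, zero everywhere, as set up for padded decoding.
cache_k = tf.zeros([batch_beam, max_decode_length, channels])

# Newly projected key at decoding step `step`: shape [batch_beam, 1, channels].
step = 2
new_k = tf.ones([batch_beam, 1, channels])

# A one-hot vector over the time axis picks the slot to write; because that
# slot was zero, the addition behaves like an in-place update at static shape.
indices = tf.reshape(
    tf.one_hot(step, max_decode_length, dtype=new_k.dtype),
    [1, max_decode_length, 1])
cache_k = cache_k + new_k * indices

print(cache_k[0, :, 0].numpy())  # [0. 0. 1. 0.] -> written at position 2
```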
@@ -55,43 +55,58 @@ class SequenceBeamSearchV2(v1.SequenceBeamSearch):
    return finished_seq, finished_scores


-def sequence_beam_search(
-    symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size,
-    alpha, max_decode_length, eos_id, dtype="float32"):
+def sequence_beam_search(symbols_to_logits_fn,
+                         initial_ids,
+                         initial_cache,
+                         vocab_size,
+                         beam_size,
+                         alpha,
+                         max_decode_length,
+                         eos_id,
+                         padded_decode=False,
+                         dtype="float32"):
  """Search for sequence of subtoken ids with the largest probability.

  Args:
    symbols_to_logits_fn: A function that takes in ids, index, and cache as
      arguments. The passed in arguments will have shape:
-        ids -> [batch_size * beam_size, index]
-        index -> [] (scalar)
-        cache -> nested dictionary of tensors [batch_size * beam_size, ...]
-      The function must return logits and new cache.
-        logits -> [batch * beam_size, vocab_size]
-        new cache -> same shape/structure as inputted cache
-    initial_ids: Starting ids for each batch item.
-      int32 tensor with shape [batch_size]
-    initial_cache: dict containing starting decoder variables information
-    vocab_size: int size of tokens
-    beam_size: int number of beams
-    alpha: float defining the strength of length normalization
-    max_decode_length: maximum length to decoded sequence
-    eos_id: int id of eos token, used to determine when a sequence has finished,
-    dtype: The dtype to use.
+        ids -> A tensor with shape [batch_size * beam_size, index].
+        index -> A scalar.
+        cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
+      The function must return a tuple of logits and new cache:
+        logits -> A tensor with shape [batch * beam_size, vocab_size].
+        new cache -> A nested dictionary with the same shape/structure as the
+          inputted cache.
+    initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
+      each batch item.
+    initial_cache: A dictionary, containing starting decoder variables
+      information.
+    vocab_size: An integer, the size of tokens.
+    beam_size: An integer, the number of beams.
+    alpha: A float, defining the strength of length normalization.
+    max_decode_length: An integer, the maximum length to decoded a sequence.
+    eos_id: An integer, ID of eos token, used to determine when a sequence has
+      finished.
+    padded_decode: A bool, indicating if max_sequence_length padding is used
+      for beam search.
+    dtype: A tensorflow data type used for score computation. The default is
+      tf.float32.

  Returns:
    Top decoded sequences [batch_size, beam_size, max_decode_length]
    sequence scores [batch_size, beam_size]
  """
-  batch_size = tf.shape(initial_ids)[0]
+  batch_size = (
+      initial_ids.shape.as_list()[0] if padded_decode else
+      tf.shape(initial_ids)[0])
  if misc.is_v2():
    sbs = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size, batch_size,
                               beam_size, alpha, max_decode_length, eos_id,
-                               dtype)
+                               padded_decode, dtype)
  else:
    sbs = v1.SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
                                beam_size, alpha, max_decode_length, eos_id,
-                                dtype)
+                                padded_decode, dtype)
  return sbs.search(initial_ids, initial_cache)
......
@@ -273,7 +273,7 @@ def _generate_synthetic_data(params):
      label_value=1,
      label_dtype=tf.int64,
  )
-  return dataset.batch(batch)
+  return dataset.batch(batch, drop_remainder=True)


def train_input_fn(params):
......
@@ -176,6 +176,44 @@ def define_transformer_flags():
  flags.DEFINE_string(
      name='mode', default='train',
      help=flags_core.help_wrap('mode: train, eval, or predict'))
flags.DEFINE_bool(
name='use_ctl',
default=False,
help=flags_core.help_wrap(
'Whether the model runs with custom training loop.'))
flags.DEFINE_bool(
name='is_tpu_pod',
default=False,
help=flags_core.help_wrap('Whether the model runs on a TPU pod.'))
flags.DEFINE_bool(
name='use_tpu_2vm_config',
default=False,
help=flags_core.help_wrap(
'Whether the model runs in 2VM mode, Headless server and unit test '
'all use 1VM config.'))
flags.DEFINE_integer(
name='decode_batch_size',
default=32,
help=flags_core.help_wrap(
'Global batch size used for Transformer autoregressive decoding on '
'TPU.'))
flags.DEFINE_integer(
name='decode_max_length',
default=97,
help=flags_core.help_wrap(
'Max sequence length of the decode/eval data. This is used by '
'Transformer autoregressive decoding on TPU to have minimum '
'paddings.'))
flags.DEFINE_bool(
name='padded_decode',
default=False,
help=flags_core.help_wrap(
'Whether the autoregressive decoding runs with input data padded to '
'the decode_max_length. For TPU/XLA-GPU runs, this flag has to be '
'set due the static shape requirement. Although CPU/GPU could also '
'use padded_decode, it has not been tested. In addition, this method '
'will introduce unnecessary overheads which grow quadratically with '
'the max sequence length.'))
  flags_core.set_defaults(data_dir='/tmp/translate_ende',
                          model_dir='/tmp/transformer_model',
@@ -216,8 +254,6 @@ def define_transformer_flags():
    return True
  # pylint: enable=unused-variable

-  flags_core.require_cloud_storage(['data_dir', 'model_dir', 'export_dir'])
-

def get_callbacks():
  """Returns common callbacks."""
......
@@ -23,6 +23,51 @@ import tensorflow as tf

K = tf.keras.backend

class LearningRateSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Learning rate schedule."""
def __init__(self, initial_learning_rate, hidden_size, warmup_steps):
"""Initialize configuration of the learning rate schedule.
Args:
initial_learning_rate: A float, the initial learning rate.
hidden_size: An integer, the model dimension in the hidden layers.
warmup_steps: An integer, the number of steps required for linear warmup.
"""
super(LearningRateSchedule, self).__init__()
self.initial_learning_rate = initial_learning_rate
self.hidden_size = hidden_size
self.warmup_steps = tf.cast(warmup_steps, tf.float32)
def __call__(self, global_step):
"""Calculate learning rate with linear warmup and rsqrt decay.
Args:
global_step: An integer, the current global step used for learning rate
calculation.
Returns:
A float, the learning rate needs to be used for current global step.
"""
with tf.name_scope('learning_rate_schedule'):
global_step = tf.cast(global_step, tf.float32)
learning_rate = self.initial_learning_rate
learning_rate *= (self.hidden_size**-0.5)
# Apply linear warmup
learning_rate *= tf.minimum(1.0, global_step / self.warmup_steps)
# Apply rsqrt decay
learning_rate /= tf.sqrt(tf.maximum(global_step, self.warmup_steps))
return learning_rate
def get_config(self):
"""Get the configuration of the learning rate schedule."""
return {
'initial_learning_rate': self.initial_learning_rate,
'hidden_size': self.hidden_size,
'warmup_steps': self.warmup_steps,
}

class LearningRateFn(object):
  """Creates learning rate function."""
......
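The `LearningRateSchedule` added above is the usual Transformer warmup-then-rsqrt-decay rule: at step t the rate is `initial_learning_rate * hidden_size**-0.5 * min(1, t / warmup_steps) / sqrt(max(t, warmup_steps))`. A plain-Python restatement for a quick sanity check; the constants in the example are illustrative, not defaults from this commit.

```python
import math


def transformer_lr(step, initial_learning_rate, hidden_size, warmup_steps):
  """Mirrors LearningRateSchedule.__call__ above, without TensorFlow."""
  lr = initial_learning_rate * hidden_size ** -0.5
  lr *= min(1.0, step / warmup_steps)       # linear warmup
  lr /= math.sqrt(max(step, warmup_steps))  # rsqrt decay afterwards
  return lr


# Illustrative constants: base rate 2.0, hidden_size 512, 16k warmup steps.
for step in (1, 4000, 16000, 64000):
  print(step, transformer_lr(step, 2.0, 512, 16000))
# The rate ramps linearly up to step 16000 and then decays as 1 / sqrt(step).
```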
@@ -112,11 +112,22 @@ class Transformer(tf.keras.Model):
          outputs: [batch_size, decoded length]
          scores: [batch_size, float]}
      Even when float16 is used, the output tensor(s) are always float32.
+
+    Raises:
+      NotImplementedError: If try to use padded decode method on CPU/GPUs.
    """
    if len(inputs) == 2:
      inputs, targets = inputs[0], inputs[1]
    else:
      inputs, targets = inputs[0], None
+
+    if self.params["padded_decode"]:
+      if not self.params["num_replicas"]:
+        raise NotImplementedError(
+            "Padded decoding on CPU/GPUs is not supported.")
+      decode_batch_size = int(self.params["decode_batch_size"] /
+                              self.params["num_replicas"])
+      inputs = tf.reshape(
+          inputs, [decode_batch_size, self.params["decode_max_length"]])
    # Variance scaling is used here because it seems to work in many problems.
    # Other reasonable initializers may also work just as well.
@@ -225,13 +236,14 @@ class Transformer(tf.keras.Model):
    decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
        max_decode_length, dtype=self.params["dtype"])

+    # TODO(b/139770046): Refactor code with better naming of i.
    def symbols_to_logits_fn(ids, i, cache):
      """Generate logits for next potential IDs.

      Args:
        ids: Current decoded sequences. int tensor with shape [batch_size *
-          beam_size, i + 1]
-        i: Loop index
+          beam_size, i + 1].
+        i: Loop index.
        cache: dictionary of values storing the encoder output, encoder-decoder
          attention bias, and previous decoder attention values.
@@ -245,16 +257,29 @@ class Transformer(tf.keras.Model):
      # Preprocess decoder input by getting embeddings and adding timing signal.
      decoder_input = self.embedding_softmax_layer(decoder_input)
-      decoder_input += timing_signal[i:i + 1]
-      self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
+
+      if self.params["padded_decode"]:
+        timing_signal_shape = timing_signal.shape.as_list()
+        decoder_input += tf.slice(timing_signal, [i, 0],
+                                  [1, timing_signal_shape[1]])
+        bias_shape = decoder_self_attention_bias.shape.as_list()
+        self_attention_bias = tf.slice(
+            decoder_self_attention_bias, [0, 0, i, 0],
+            [bias_shape[0], bias_shape[1], 1, bias_shape[3]])
+      else:
+        decoder_input += timing_signal[i:i + 1]
+        self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
      decoder_outputs = self.decoder_stack(
          decoder_input,
          cache.get("encoder_outputs"),
          self_attention_bias,
          cache.get("encoder_decoder_attention_bias"),
          training=training,
-          cache=cache)
+          cache=cache,
+          decode_loop_step=i if self.params["padded_decode"] else None)
      logits = self.embedding_softmax_layer(decoder_outputs, mode="linear")
      logits = tf.squeeze(logits, axis=[1])
      return logits, cache
@@ -263,8 +288,12 @@ class Transformer(tf.keras.Model):

  def predict(self, encoder_outputs, encoder_decoder_attention_bias, training):
    """Return predicted sequence."""
-    batch_size = tf.shape(encoder_outputs)[0]
-    input_length = tf.shape(encoder_outputs)[1]
+    if self.params["padded_decode"]:
+      batch_size = encoder_outputs.shape.as_list()[0]
+      input_length = encoder_outputs.shape.as_list()[1]
+    else:
+      batch_size = tf.shape(encoder_outputs)[0]
+      input_length = tf.shape(encoder_outputs)[1]
    max_decode_length = input_length + self.params["extra_decode_length"]
    encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
                                             self.params["dtype"])
@@ -277,12 +306,20 @@ class Transformer(tf.keras.Model):

    # Create cache storing decoder attention values for each layer.
    # pylint: disable=g-complex-comprehension
+    init_decode_length = (
+        max_decode_length if self.params["padded_decode"] else 0)
    cache = {
        "layer_%d" % layer: {
-            "k": tf.zeros([batch_size, 0, self.params["hidden_size"]],
-                          dtype=self.params["dtype"]),
-            "v": tf.zeros([batch_size, 0, self.params["hidden_size"]],
-                          dtype=self.params["dtype"])
+            "k":
+                tf.zeros([
+                    batch_size, init_decode_length, self.params["hidden_size"]
+                ],
+                         dtype=self.params["dtype"]),
+            "v":
+                tf.zeros([
+                    batch_size, init_decode_length, self.params["hidden_size"]
+                ],
+                         dtype=self.params["dtype"])
        } for layer in range(self.params["num_hidden_layers"])
    }
    # pylint: enable=g-complex-comprehension
@@ -301,6 +338,7 @@ class Transformer(tf.keras.Model):
        alpha=self.params["alpha"],
        max_decode_length=max_decode_length,
        eos_id=EOS_ID,
+        padded_decode=self.params["padded_decode"],
        dtype=self.params["dtype"])

    # Get the top sequence for each batch element
@@ -505,22 +543,28 @@ class DecoderStack(tf.keras.layers.Layer):
           decoder_self_attention_bias,
           attention_bias,
           training,
-           cache=None):
+           cache=None,
+           decode_loop_step=None):
    """Return the output of the decoder layer stacks.

    Args:
-      decoder_inputs: tensor with shape [batch_size, target_length, hidden_size]
-      encoder_outputs: tensor with shape [batch_size, input_length, hidden_size]
-      decoder_self_attention_bias: bias for decoder self-attention layer. [1, 1,
-        target_len, target_length]
-      attention_bias: bias for encoder-decoder attention layer. [batch_size, 1,
-        1, input_length]
-      training: boolean, whether in training mode or not.
+      decoder_inputs: A tensor with shape
+        [batch_size, target_length, hidden_size].
+      encoder_outputs: A tensor with shape
+        [batch_size, input_length, hidden_size]
+      decoder_self_attention_bias: A tensor with shape
+        [1, 1, target_len, target_length], the bias for decoder self-attention
+        layer.
+      attention_bias: A tensor with shape [batch_size, 1, 1, input_length],
+        the bias for encoder-decoder attention layer.
+      training: A bool, whether in training mode or not.
      cache: (Used for fast decoding) A nested dictionary storing previous
        decoder self-attention values. The items are:
-          {layer_n: {"k": tensor with shape [batch_size, i, key_channels],
-                     "v": tensor with shape [batch_size, i, value_channels]},
+          {layer_n: {"k": A tensor with shape [batch_size, i, key_channels],
+                     "v": A tensor with shape [batch_size, i, value_channels]},
                       ...}
+      decode_loop_step: An integer, the step number of the decoding loop. Used
+        only for autoregressive inference on TPU.

    Returns:
      Output of decoder layer stack.
@@ -540,7 +584,8 @@ class DecoderStack(tf.keras.layers.Layer):
              decoder_inputs,
              decoder_self_attention_bias,
              training=training,
-              cache=layer_cache)
+              cache=layer_cache,
+              decode_loop_step=decode_loop_step)
        with tf.name_scope("encdec_attention"):
          decoder_inputs = enc_dec_attention_layer(
              decoder_inputs,
......
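Tying the padded-decode plumbing in `Transformer.call` together: each replica handles a fixed per-core slice of `decode_batch_size` examples, every one padded to `decode_max_length`, which is what makes the fully static `tf.reshape` above legal. A toy check of that bookkeeping, using the flag defaults from this diff plus an assumed 8-replica setup.

```python
decode_batch_size = 32   # default of the --decode_batch_size flag above
decode_max_length = 97   # default of the --decode_max_length flag above
num_replicas = 8         # assumption: e.g. one TPU v3-8

assert decode_batch_size % num_replicas == 0, (
    "the global decode batch must split evenly across replicas")
per_replica_batch = decode_batch_size // num_replicas

# Each replica can therefore reshape its input to a fully static
# [per_replica_batch, decode_max_length] tensor, here [4, 97].
print(per_replica_batch, decode_max_length)
```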