Unverified commit a35e09d2, authored by Vinh Nguyen, committed by GitHub

Merge branch 'master' into amp_resnet50

parents d5722dcd 1f5a5e9d
@@ -20,7 +20,7 @@ from __future__ import print_function
 import tensorflow as tf  # pylint: disable=g-bad-import-order
-from official.utils.export import export
+from official.r1.utils import export
 class ExportUtilsTest(tf.test.TestCase):
...
@@ -29,7 +29,7 @@ import tensorflow as tf
 # pylint: enable=wrong-import-order
 from official.datasets import movielens
-from official.utils.data import file_io
+from official.r1.utils.data import file_io
 from official.utils.flags import core as flags_core
...
@@ -65,8 +65,8 @@ def prepare_raw_data(flag_obj):
   data_processing_params = {
       "train_epochs": flag_obj.num_train_epochs,
-      "batch_size": flag_obj.prebatch_size,
-      "eval_batch_size": flag_obj.prebatch_size,
+      "batch_size": flag_obj.train_prebatch_size,
+      "eval_batch_size": flag_obj.eval_prebatch_size,
       "batches_per_step": 1,
       "stream_files": True,
       "num_neg": flag_obj.num_negative_samples,
...
@@ -154,8 +154,10 @@ def define_ncf_flags():
       intra_op=False,
       synthetic_data=True,
       max_train_steps=False,
-      dtype=False,
+      dtype=True,
       all_reduce_alg=False,
+      loss_scale=True,
+      dynamic_loss_scale=True,
       enable_xla=True,
       force_v2_in_keras_compile=True
   )
...
@@ -21,7 +21,6 @@ from __future__ import print_function
 import functools
 # pylint: disable=g-bad-import-order
-import numpy as np
 import tensorflow.compat.v2 as tf
 # pylint: enable=g-bad-import-order
@@ -42,6 +41,9 @@ def create_dataset_from_tf_record_files(input_file_pattern,
   def make_dataset(files_dataset, shard_index):
     """Returns dataset for sharded tf record files."""
+    if pre_batch_size != batch_size:
+      raise ValueError("Pre-batch ({}) size is not equal to batch "
+                       "size ({})".format(pre_batch_size, batch_size))
     files_dataset = files_dataset.shard(NUM_SHARDS, shard_index)
     dataset = files_dataset.interleave(tf.data.TFRecordDataset)
     decode_fn = functools.partial(
@@ -50,8 +52,6 @@ def create_dataset_from_tf_record_files(input_file_pattern,
         is_training=is_training)
     dataset = dataset.map(
         decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
-    dataset = dataset.apply(tf.data.experimental.unbatch())
-    dataset = dataset.batch(batch_size, drop_remainder=True)
     return dataset
   dataset = tf.data.Dataset.range(NUM_SHARDS)
...
@@ -117,7 +117,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     """
     self._run_and_report_benchmark(hr_at_10_min=0.61)
-  def _run_and_report_benchmark(self, hr_at_10_min=0.630, hr_at_10_max=0.640):
+  def _run_and_report_benchmark(self, hr_at_10_min=0.630, hr_at_10_max=0.645):
     """Run test and report results.
     Note: Target is 0.635, but some runs are below that level. Until we have
@@ -263,6 +263,15 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     FLAGS.train_epochs = 7
     self._run_and_report_benchmark_mlperf_like()
+  def benchmark_1_gpu_ctl_fp16_mlperf_like(self):
+    """1 GPU using CTL."""
+    self._setup()
+    FLAGS.keras_use_ctl = True
+    FLAGS.train_epochs = 7
+    FLAGS.dtype = 'fp16'
+    FLAGS.loss_scale = 8192
+    self._run_and_report_benchmark_mlperf_like()
   def benchmark_1_gpu_ctl_run_eagerly_mlperf_like(self):
     """1 GPU using CTL with eager and distribution strategy."""
     self._setup()
@@ -279,6 +288,16 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
     FLAGS.train_epochs = 7
     self._run_and_report_benchmark_mlperf_like()
+  def benchmark_xla_1_gpu_ctl_fp16_mlperf_like(self):
+    """1 GPU using CTL with XLA."""
+    self._setup()
+    FLAGS.keras_use_ctl = True
+    FLAGS.enable_xla = True
+    FLAGS.train_epochs = 7
+    FLAGS.dtype = 'fp16'
+    FLAGS.loss_scale = 8192
+    self._run_and_report_benchmark_mlperf_like()
   def benchmark_8_gpu_mlperf_like(self):
     """8 GPU using keras fit/compile."""
     self._setup()
...
@@ -42,6 +42,7 @@ from official.utils.logs import mlperf_helper
 from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 from official.utils.misc import model_helpers
+from official.utils.flags import core as flags_core
 from official.utils.misc import tpu_lib
 FLAGS = flags.FLAGS
@@ -267,6 +268,12 @@ def run_ncf(_):
       beta_1=params["beta1"],
       beta_2=params["beta2"],
       epsilon=params["epsilon"])
+  if FLAGS.dtype == "fp16":
+    optimizer = \
+        tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
+            optimizer,
+            loss_scale=flags_core.get_loss_scale(FLAGS,
+                                                 default_for_fp16="dynamic"))
   if params["keras_use_ctl"]:
     train_loss, eval_results = run_ncf_custom_training(
@@ -371,8 +378,12 @@ def run_ncf_custom_training(params,
             softmax_logits,
             sample_weight=features[rconst.VALID_POINT_MASK])
         loss *= (1.0 / params["batch_size"])
+        if FLAGS.dtype == "fp16":
+          loss = optimizer.get_scaled_loss(loss)
       grads = tape.gradient(loss, keras_model.trainable_variables)
+      if FLAGS.dtype == "fp16":
+        grads = optimizer.get_unscaled_gradients(grads)
       # Converting gradients to dense form helps in perf on GPU for NCF
       grads = neumf_model.sparse_to_dense_grads(
          list(zip(grads, keras_model.trainable_variables)))
...
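Note: the fp16 branches above follow the standard Keras loss-scaling pattern: the optimizer is wrapped so the loss is scaled up before the backward pass and the gradients are scaled back down before the update. A minimal, self-contained sketch of that pattern, shown with the current tf.keras.mixed_precision.LossScaleOptimizer API rather than the graph-rewrite wrapper used in the diff (model, features, labels and loss_fn are placeholders):

import tensorflow as tf

# Illustrative only: the scaled-loss / unscaled-gradient pattern.
base_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
# Dynamic loss scaling by default; a fixed scale such as 8192 can be requested
# with LossScaleOptimizer(base_optimizer, dynamic=False, initial_scale=8192).
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(base_optimizer)

def train_step(model, features, labels, loss_fn):
  with tf.GradientTape() as tape:
    loss = loss_fn(labels, model(features, training=True))
    # Scale the loss so small fp16 gradients do not underflow.
    scaled_loss = optimizer.get_scaled_loss(loss)
  scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
  # Undo the scaling before applying the update.
  grads = optimizer.get_unscaled_gradients(scaled_grads)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss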
@@ -27,3 +27,6 @@ def define_ctl_flags():
   flags.DEFINE_boolean(name='use_tf_function', default=True,
                        help='Wrap the train and test step inside a '
                        'tf.function.')
+  flags.DEFINE_boolean(name='single_l2_loss_op', default=False,
+                       help='Calculate L2_loss on concatenated weights, '
+                       'instead of using Keras per-layer L2 loss.')
@@ -22,7 +22,7 @@ import time
 from absl import flags
 import tensorflow as tf
-from official.resnet.keras import keras_common
+from official.vision.image_classification import common
 from official.resnet.ctl import ctl_imagenet_main
 from official.resnet.ctl import ctl_common
 from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
@@ -118,7 +118,7 @@ class Resnet50CtlAccuracy(CtlBenchmark):
     flag_methods = [
         ctl_common.define_ctl_flags,
-        keras_common.define_keras_flags
+        common.define_keras_flags
     ]
     self.data_dir = os.path.join(root_data_dir, 'imagenet')
@@ -162,7 +162,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
   def __init__(self, output_dir=None, default_flags=None):
     flag_methods = [
         ctl_common.define_ctl_flags,
-        keras_common.define_keras_flags
+        common.define_keras_flags
     ]
     super(Resnet50CtlBenchmarkBase, self).__init__(
@@ -215,6 +215,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
     FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_eager')
     FLAGS.batch_size = 64
     FLAGS.use_tf_function = False
+    FLAGS.single_l2_loss_op = True
     self._run_and_report_benchmark()
   def benchmark_8_gpu(self):
...
@@ -24,10 +24,10 @@ from absl import logging
 import tensorflow as tf
 from official.resnet.ctl import ctl_common
-from official.resnet.keras import imagenet_preprocessing
-from official.resnet.keras import keras_common
-from official.resnet.keras import keras_imagenet_main
-from official.resnet.keras import resnet_model
+from official.vision.image_classification import imagenet_preprocessing
+from official.vision.image_classification import common
+from official.vision.image_classification import resnet_imagenet_main
+from official.vision.image_classification import resnet_model
 from official.utils.flags import core as flags_core
 from official.utils.logs import logger
 from official.utils.misc import distribution_utils
@@ -73,7 +73,7 @@ def get_input_dataset(flags_obj, strategy):
   """Returns the test and train input datasets."""
   dtype = flags_core.get_tf_dtype(flags_obj)
   if flags_obj.use_synthetic_data:
-    input_fn = keras_common.get_synth_input_fn(
+    input_fn = common.get_synth_input_fn(
         height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
         width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
         num_channels=imagenet_preprocessing.NUM_CHANNELS,
@@ -137,6 +137,10 @@ def run(flags_obj):
   Returns:
     Dictionary of training and eval stats.
   """
+  keras_utils.set_session_config(
+      enable_eager=flags_obj.enable_eager,
+      enable_xla=flags_obj.enable_xla)
   dtype = flags_core.get_tf_dtype(flags_obj)
   # TODO(anj-s): Set data_format without using Keras.
@@ -163,10 +167,11 @@ def run(flags_obj):
   with strategy_scope:
     model = resnet_model.resnet50(
         num_classes=imagenet_preprocessing.NUM_CLASSES,
-        dtype=dtype, batch_size=flags_obj.batch_size)
+        dtype=dtype, batch_size=flags_obj.batch_size,
+        use_l2_regularizer=not flags_obj.single_l2_loss_op)
     optimizer = tf.keras.optimizers.SGD(
-        learning_rate=keras_common.BASE_LEARNING_RATE, momentum=0.9,
+        learning_rate=common.BASE_LEARNING_RATE, momentum=0.9,
         nesterov=True)
     training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
@@ -175,6 +180,8 @@ def run(flags_obj):
     test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
         'test_accuracy', dtype=tf.float32)
+    trainable_variables = model.trainable_variables
   def train_step(train_ds_inputs):
     """Training StepFn."""
     def step_fn(inputs):
@@ -185,13 +192,22 @@ def run(flags_obj):
       prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
           labels, logits)
-      loss1 = tf.reduce_sum(prediction_loss) * (1.0/ flags_obj.batch_size)
-      loss2 = (tf.reduce_sum(model.losses) /
-               tf.distribute.get_strategy().num_replicas_in_sync)
-      loss = loss1 + loss2
-      grads = tape.gradient(loss, model.trainable_variables)
-      optimizer.apply_gradients(zip(grads, model.trainable_variables))
+      loss = tf.reduce_sum(prediction_loss) * (1.0/ flags_obj.batch_size)
+      num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
+      if flags_obj.single_l2_loss_op:
+        filtered_variables = [
+            tf.reshape(v, (-1,))
+            for v in trainable_variables
+            if 'bn' not in v.name
+        ]
+        l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
+            tf.concat(filtered_variables, axis=0))
+        loss += (l2_loss / num_replicas)
+      else:
+        loss += (tf.reduce_sum(model.losses) / num_replicas)
+      grads = tape.gradient(loss, trainable_variables)
+      optimizer.apply_gradients(zip(grads, trainable_variables))
       training_accuracy.update_state(labels, logits)
       return loss
@@ -232,7 +248,7 @@ def run(flags_obj):
     training_accuracy.reset_states()
     for step in range(train_steps):
-      optimizer.lr = keras_imagenet_main.learning_rate_schedule(
+      optimizer.lr = resnet_imagenet_main.learning_rate_schedule(
          epoch, step, train_steps, flags_obj.batch_size)
       time_callback.on_batch_begin(step+epoch*train_steps)
@@ -281,7 +297,7 @@ def main(_):
 if __name__ == '__main__':
   logging.set_verbosity(logging.INFO)
-  keras_common.define_keras_flags()
+  common.define_keras_flags()
   ctl_common.define_ctl_flags()
   flags.adopt_module_key_flags(keras_common)
   flags.adopt_module_key_flags(ctl_common)
...
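Note: the single_l2_loss_op path above replaces Keras per-layer kernel regularizers with one L2 term computed over all non-BatchNorm weights at once, which is cheaper on GPU. A standalone sketch of that computation; the weight_decay value and the 'bn' name filter are illustrative assumptions, not the library's API:

import tensorflow as tf

def concatenated_l2_loss(trainable_variables, weight_decay=1e-4):
  """Single L2 penalty over all non-BatchNorm weights (illustrative sketch)."""
  flattened = [
      tf.reshape(v, (-1,))
      for v in trainable_variables
      if 'bn' not in v.name  # skip BatchNorm parameters, as in the diff
  ]
  # tf.nn.l2_loss computes sum(x**2) / 2, hence the factor of 2 to recover
  # the usual weight_decay * sum(x**2) penalty.
  return weight_decay * 2 * tf.nn.l2_loss(tf.concat(flattened, axis=0))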
@@ -25,8 +25,8 @@ from tensorflow.python.eager import context
 from tensorflow.python.platform import googletest
 from official.resnet.ctl import ctl_common
 from official.resnet.ctl import ctl_imagenet_main
-from official.resnet.keras import imagenet_preprocessing
-from official.resnet.keras import keras_common
+from official.vision.image_classification import imagenet_preprocessing
+from official.vision.image_classification import common
 from official.utils.misc import keras_utils
 from official.utils.testing import integration
@@ -49,7 +49,7 @@ class CtlImagenetTest(googletest.TestCase):
   @classmethod
   def setUpClass(cls):  # pylint: disable=invalid-name
     super(CtlImagenetTest, cls).setUpClass()
-    keras_common.define_keras_flags()
+    common.define_keras_flags()
     ctl_common.define_ctl_flags()
   def setUp(self):
...
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bring in the shared Keras ResNet modules into this module.
The TensorFlow official Keras models are moved under
official/vision/image_classification
In order to be backward compatible with models that directly import its modules,
we import the Keras ResNet modules under official.resnet.keras.
New TF models should not depend on modules directly under this path.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from official.vision.image_classification import cifar_preprocessing
from official.vision.image_classification import common as keras_common
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_cifar_main as keras_cifar_main
from official.vision.image_classification import resnet_cifar_model
from official.vision.image_classification import resnet_imagenet_main as keras_imagenet_main
from official.vision.image_classification import resnet_model
del absolute_import
del division
del print_function
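Note: with this shim in place, code written against the old package path keeps working while new code targets official/vision/image_classification directly. An illustrative usage, assuming the shim is the package's __init__.py so the old attribute resolves to the relocated module:

from official.resnet.keras import resnet_model as legacy_resnet_model  # via the shim
from official.vision.image_classification import resnet_model          # new location

# Both names should refer to the same module object (illustrative check).
assert legacy_resnet_model is resnet_model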
@@ -41,8 +41,8 @@ class ShakespeareBenchmarkBase(PerfZeroBenchmark):
         flag_methods=[shakespeare_main.define_flags])
   def _run_and_report_benchmark(self,
-                                top_1_train_min=0.923,
-                                top_1_train_max=0.93,
+                                top_1_train_min=0.91,
+                                top_1_train_max=0.94,
                                 warmup=1,
                                 log_steps=100):
     """Report benchmark results by writing to local protobuf file.
...
@@ -79,8 +79,41 @@ class _StateKeys(object):
 class SequenceBeamSearch(object):
   """Implementation of beam search loop."""
-  def __init__(self, symbols_to_logits_fn, vocab_size, batch_size,
-               beam_size, alpha, max_decode_length, eos_id, dtype=tf.float32):
+  def __init__(self,
+               symbols_to_logits_fn,
+               vocab_size,
+               batch_size,
+               beam_size,
+               alpha,
+               max_decode_length,
+               eos_id,
+               padded_decode,
+               dtype=tf.float32):
+    """Initialize sequence beam search.
+
+    Args:
+      symbols_to_logits_fn: A function to provide logits, which is the
+        interface to the Transformer model. The passed in arguments are:
+          ids -> A tensor with shape [batch_size * beam_size, index].
+          index -> A scalar.
+          cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
+        The function must return a tuple of logits and the updated cache:
+          logits -> A tensor with shape [batch * beam_size, vocab_size].
+          updated cache -> A nested dictionary with the same structure as the
+            input cache.
+      vocab_size: An integer, the size of the vocabulary, used for topk
+        computation.
+      batch_size: An integer, the decode batch size.
+      beam_size: An integer, number of beams for beam search.
+      alpha: A float, defining the strength of length normalization.
+      max_decode_length: An integer, the maximum number of steps to decode
+        a sequence.
+      eos_id: An integer. ID of end of sentence token.
+      padded_decode: A bool, indicating if max_sequence_length padding is used
+        for beam search.
+      dtype: A tensorflow data type used for score computation. The default is
+        tf.float32.
+    """
     self.symbols_to_logits_fn = symbols_to_logits_fn
     self.vocab_size = vocab_size
     self.batch_size = batch_size
@@ -88,6 +121,7 @@ class SequenceBeamSearch(object):
     self.alpha = alpha
     self.max_decode_length = max_decode_length
     self.eos_id = eos_id
+    self.padded_decode = padded_decode
     self.dtype = tf.as_dtype(dtype)
   def search(self, initial_ids, initial_cache):
@@ -140,6 +174,8 @@ class SequenceBeamSearch(object):
     # Create alive sequence with shape [batch_size, beam_size, 1]
     alive_seq = _expand_to_beam_size(initial_ids, self.beam_size)
     alive_seq = tf.expand_dims(alive_seq, axis=2)
+    if self.padded_decode:
+      alive_seq = tf.tile(alive_seq, [1, 1, self.max_decode_length + 1])
     # Create tensor for storing initial log probabilities.
     # Assume initial_ids are prob 1.0
@@ -178,16 +214,44 @@ class SequenceBeamSearch(object):
     # 1) the dimension's value is a tensor that remains the same but may
     #    depend on the input sequence to the model (e.g. batch size).
     # 2) the dimension may have different values on different iterations.
-    state_shape_invariants = {
-        _StateKeys.CUR_INDEX: tf.TensorShape([]),
-        _StateKeys.ALIVE_SEQ: tf.TensorShape([None, self.beam_size, None]),
-        _StateKeys.ALIVE_LOG_PROBS: tf.TensorShape([None, self.beam_size]),
-        _StateKeys.ALIVE_CACHE: nest.map_structure(
-            _get_shape_keep_last_dim, alive_cache),
-        _StateKeys.FINISHED_SEQ: tf.TensorShape([None, self.beam_size, None]),
-        _StateKeys.FINISHED_SCORES: tf.TensorShape([None, self.beam_size]),
-        _StateKeys.FINISHED_FLAGS: tf.TensorShape([None, self.beam_size])
-    }
+    if self.padded_decode:
+      state_shape_invariants = {
+          _StateKeys.CUR_INDEX:
+              tf.TensorShape([]),
+          _StateKeys.ALIVE_SEQ:
+              tf.TensorShape(
+                  [self.batch_size, self.beam_size,
+                   self.max_decode_length + 1]),
+          _StateKeys.ALIVE_LOG_PROBS:
+              tf.TensorShape([self.batch_size, self.beam_size]),
+          _StateKeys.ALIVE_CACHE:
+              nest.map_structure(_get_shape, alive_cache),
+          _StateKeys.FINISHED_SEQ:
+              tf.TensorShape(
+                  [self.batch_size, self.beam_size,
+                   self.max_decode_length + 1]),
+          _StateKeys.FINISHED_SCORES:
+              tf.TensorShape([self.batch_size, self.beam_size]),
+          _StateKeys.FINISHED_FLAGS:
+              tf.TensorShape([self.batch_size, self.beam_size])
+      }
+    else:
+      state_shape_invariants = {
+          _StateKeys.CUR_INDEX:
+              tf.TensorShape([]),
+          _StateKeys.ALIVE_SEQ:
+              tf.TensorShape([None, self.beam_size, None]),
+          _StateKeys.ALIVE_LOG_PROBS:
+              tf.TensorShape([None, self.beam_size]),
+          _StateKeys.ALIVE_CACHE:
+              nest.map_structure(_get_shape_keep_last_dim, alive_cache),
+          _StateKeys.FINISHED_SEQ:
+              tf.TensorShape([None, self.beam_size, None]),
+          _StateKeys.FINISHED_SCORES:
+              tf.TensorShape([None, self.beam_size]),
+          _StateKeys.FINISHED_FLAGS:
+              tf.TensorShape([None, self.beam_size])
+      }
     return state, state_shape_invariants
@@ -297,7 +361,12 @@ class SequenceBeamSearch(object):
     # Get logits for the next candidate IDs for the alive sequences. Get the new
     # cache values at the same time.
-    flat_ids = _flatten_beam_dim(alive_seq)  # [batch_size * beam_size]
+    if self.padded_decode:
+      flat_ids = tf.reshape(
+          tf.slice(alive_seq, [0, 0, i], [self.batch_size, self.beam_size, 1]),
+          [self.batch_size * self.beam_size, -1])
+    else:
+      flat_ids = _flatten_beam_dim(alive_seq)  # [batch_size * beam_size]
     flat_cache = nest.map_structure(_flatten_beam_dim, alive_cache)
     flat_logits, flat_cache = self.symbols_to_logits_fn(flat_ids, i, flat_cache)
@@ -331,8 +400,13 @@ class SequenceBeamSearch(object):
     # Append the most probable IDs to the topk sequences
     topk_ids = topk_indices % self.vocab_size
-    topk_ids = tf.expand_dims(topk_ids, axis=2)
-    topk_seq = tf.concat([topk_seq, topk_ids], axis=2)
+    if self.padded_decode:
+      topk_seq = tf.transpose(topk_seq, perm=[2, 0, 1])
+      topk_seq = tf.tensor_scatter_update(topk_seq, [i + 1], topk_ids)
+      topk_seq = tf.transpose(topk_seq, perm=[1, 2, 0])
+    else:
+      topk_ids = tf.expand_dims(topk_ids, axis=2)
+      topk_seq = tf.concat([topk_seq, topk_ids], axis=2)
     return topk_seq, topk_log_probs, new_cache
   def _get_new_alive_state(self, new_seq, new_log_probs, new_cache):
@@ -388,9 +462,12 @@ class SequenceBeamSearch(object):
     # First append a column of 0-ids to finished_seq to increment the length.
     # New shape of finished_seq: [batch_size, beam_size, i + 1]
-    finished_seq = tf.concat(
-        [finished_seq,
-         tf.zeros([self.batch_size, self.beam_size, 1], tf.int32)], axis=2)
+    if not self.padded_decode:
+      finished_seq = tf.concat([
+          finished_seq,
+          tf.zeros([self.batch_size, self.beam_size, 1], tf.int32)
+      ],
+                               axis=2)
     # Calculate new seq scores from log probabilities.
     length_norm = _length_normalization(self.alpha, i + 1, dtype=self.dtype)
@@ -420,34 +497,43 @@ class SequenceBeamSearch(object):
 def sequence_beam_search(
     symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size,
-    alpha, max_decode_length, eos_id):
+    alpha, max_decode_length, eos_id, padded_decode=False):
   """Search for sequence of subtoken ids with the largest probability.
   Args:
     symbols_to_logits_fn: A function that takes in ids, index, and cache as
       arguments. The passed in arguments will have shape:
-        ids -> [batch_size * beam_size, index]
-        index -> [] (scalar)
-        cache -> nested dictionary of tensors [batch_size * beam_size, ...]
-      The function must return logits and new cache.
-        logits -> [batch * beam_size, vocab_size]
-        new cache -> same shape/structure as inputted cache
-    initial_ids: Starting ids for each batch item.
-      int32 tensor with shape [batch_size]
-    initial_cache: dict containing starting decoder variables information
-    vocab_size: int size of tokens
-    beam_size: int number of beams
-    alpha: float defining the strength of length normalization
-    max_decode_length: maximum length to decoded sequence
-    eos_id: int id of eos token, used to determine when a sequence has finished
+        ids -> A tensor with shape [batch_size * beam_size, index].
+        index -> A scalar.
+        cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
+      The function must return a tuple of logits and new cache:
+        logits -> A tensor with shape [batch * beam_size, vocab_size].
+        new cache -> A nested dictionary with the same shape/structure as the
+          inputted cache.
+    initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
+      each batch item.
+    initial_cache: A dictionary, containing starting decoder variables
+      information.
+    vocab_size: An integer, the size of the vocabulary, used for topk
+      computation.
+    beam_size: An integer, the number of beams.
+    alpha: A float, defining the strength of length normalization.
+    max_decode_length: An integer, the maximum length to decoded a sequence.
+    eos_id: An integer, ID of eos token, used to determine when a sequence has
+      finished.
+    padded_decode: A bool, indicating if max_sequence_length padding is used
+      for beam search.
   Returns:
     Top decoded sequences [batch_size, beam_size, max_decode_length]
     sequence scores [batch_size, beam_size]
   """
-  batch_size = tf.shape(initial_ids)[0]
+  batch_size = (
+      initial_ids.shape.as_list()[0] if padded_decode else
+      tf.shape(initial_ids)[0])
   sbs = SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
-                           beam_size, alpha, max_decode_length, eos_id)
+                           beam_size, alpha, max_decode_length, eos_id,
+                           padded_decode)
   return sbs.search(initial_ids, initial_cache)
@@ -502,6 +588,11 @@ def _get_shape_keep_last_dim(tensor):
   return tf.TensorShape(shape_list)
+def _get_shape(tensor):
+  """Return the shape of the input tensor."""
+  return tf.TensorShape(_shape_list(tensor))
 def _flatten_beam_dim(tensor):
   """Reshapes first two dimensions in to single dimension.
...
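Note: padded decoding exists because XLA/TPU compilation requires static shapes, so instead of growing the alive sequence with tf.concat each step, the loop preallocates a [batch, beam, max_decode_length + 1] tensor and writes the step's ids in place. A minimal sketch of that in-place write, using tf.tensor_scatter_nd_update (the diff uses the older tf.tensor_scatter_update name); shapes and names here are illustrative:

import tensorflow as tf

def write_step_ids(seq, step_ids, step):
  """Write step_ids into position `step` of a preallocated sequence tensor.

  seq:      int32 [batch, beam, max_len + 1], allocated up front.
  step_ids: int32 [batch, beam], the ids chosen at this decode step.
  step:     Python int (or scalar tensor), the slot to fill.
  """
  # Move the time axis to the front so a single time slice can be scattered.
  seq_t = tf.transpose(seq, perm=[2, 0, 1])            # [max_len + 1, batch, beam]
  seq_t = tf.tensor_scatter_nd_update(
      seq_t, indices=[[step]], updates=step_ids[tf.newaxis, ...])
  return tf.transpose(seq_t, perm=[1, 2, 0])            # back to [batch, beam, max_len + 1]

# Example: batch 2, beam 3, max length 4; write the ids chosen at step 1.
seq = tf.zeros([2, 3, 5], tf.int32)
ids = tf.ones([2, 3], tf.int32)
seq = write_step_ids(seq, ids, 1)  # shape stays [2, 3, 5]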
@@ -32,6 +32,7 @@ from absl import flags
 import tensorflow as tf
 # pylint: enable=g-bad-import-order
+from official.r1.utils import export
 from official.transformer import compute_bleu
 from official.transformer import translate
 from official.transformer.model import model_params
@@ -41,7 +42,6 @@ from official.transformer.utils import metrics
 from official.transformer.utils import schedule
 from official.transformer.utils import tokenizer
 from official.utils.accelerator import tpu as tpu_util
-from official.utils.export import export
 from official.utils.flags import core as flags_core
 from official.utils.logs import hooks_helper
 from official.utils.logs import logger
@@ -56,7 +56,7 @@ PARAMS_MAP = {
 DEFAULT_TRAIN_EPOCHS = 10
-INF = int(1e9)
+INF = 1000000000  # 1e9
 BLEU_DIR = "bleu"
 # Dictionary containing tensors that are logged by the logging hooks. Each item
...
@@ -102,51 +102,67 @@ class Attention(tf.keras.layers.Layer):
     x = tf.transpose(x, [0, 2, 1, 3])  # --> [batch, length, num_heads, depth]
     return tf.reshape(x, [batch_size, length, self.hidden_size])
-  def call(self, x, y, bias, training, cache=None):
+  def call(self, x, y, bias, training, cache=None, decode_loop_step=None):
     """Apply attention mechanism to x and y.
     Args:
-      x: a tensor with shape [batch_size, length_x, hidden_size]
-      y: a tensor with shape [batch_size, length_y, hidden_size]
-      bias: attention bias that will be added to the result of the dot product.
-      training: boolean, whether in training mode or not.
-      cache: (Used during prediction) dictionary with tensors containing results
-        of previous attentions. The dictionary must have the items:
+      x: A tensor with shape [batch_size, length_x, hidden_size].
+      y: A tensor with shape [batch_size, length_y, hidden_size].
+      bias: A bool, the attention bias that will be added to the result of the
+        dot product.
+      training: A bool, whether in training mode or not.
+      cache: (Used during prediction) A dictionary with tensors containing
+        results of previous attentions. The dictionary must have the items:
             {"k": tensor with shape [batch_size, i, key_channels],
              "v": tensor with shape [batch_size, i, value_channels]}
         where i is the current decoded length.
+      decode_loop_step: An integer, step number of the decoding loop. Used only
+        for autoregressive inference on TPU.
     Returns:
       Attention layer output with shape [batch_size, length_x, hidden_size]
     """
-    # Linearly project the query (q), key (k) and value (v) using different
-    # learned projections. This is in preparation of splitting them into
-    # multiple heads. Multi-head attention uses multiple queries, keys, and
-    # values rather than regular attention (which uses a single q, k, v).
-    q = self.q_dense_layer(x)
-    k = self.k_dense_layer(y)
-    v = self.v_dense_layer(y)
+    # Linearly project the query, key and value using different learned
+    # projections. This is in preparation of splitting them into multiple
+    # heads. Multi-head attention uses multiple queries, keys, and values
+    # rather than regular attention (which uses a single query, key, value).
+    query = self.q_dense_layer(x)
+    key = self.k_dense_layer(y)
+    value = self.v_dense_layer(y)
     if cache is not None:
       # Combine cached keys and values with new keys and values.
-      k = tf.concat([tf.cast(cache["k"], k.dtype), k], axis=1)
-      v = tf.concat([tf.cast(cache["v"], k.dtype), v], axis=1)
+      if decode_loop_step is not None:
+        cache_k_shape = cache["k"].shape.as_list()
+        indices = tf.reshape(
+            tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=key.dtype),
+            [1, cache_k_shape[1], 1])
+        key = cache["k"] + key * indices
+        cache_v_shape = cache["v"].shape.as_list()
+        indices = tf.reshape(
+            tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=value.dtype),
+            [1, cache_v_shape[1], 1])
+        value = cache["v"] + value * indices
+      else:
+        key = tf.concat([tf.cast(cache["k"], key.dtype), key], axis=1)
+        value = tf.concat([tf.cast(cache["v"], value.dtype), value], axis=1)
       # Update cache
-      cache["k"] = k
-      cache["v"] = v
+      cache["k"] = key
+      cache["v"] = value
-    # Split q, k, v into heads.
-    q = self.split_heads(q)
-    k = self.split_heads(k)
-    v = self.split_heads(v)
+    # Split query, key, value into heads.
+    query = self.split_heads(query)
+    key = self.split_heads(key)
+    value = self.split_heads(value)
-    # Scale q to prevent the dot product between q and k from growing too large.
+    # Scale query to prevent the dot product between query and key from growing
+    # too large.
     depth = (self.hidden_size // self.num_heads)
-    q *= depth ** -0.5
+    query *= depth ** -0.5
     # Calculate dot product attention
-    logits = tf.matmul(q, k, transpose_b=True)
+    logits = tf.matmul(query, key, transpose_b=True)
     logits += bias
     # Note that softmax internally performs math operations using float32
     # for numeric stability. When training with float16, we keep the input
@@ -154,7 +170,7 @@ class Attention(tf.keras.layers.Layer):
     weights = tf.nn.softmax(logits, name="attention_weights")
     if training:
       weights = tf.nn.dropout(weights, rate=self.attention_dropout)
-    attention_output = tf.matmul(weights, v)
+    attention_output = tf.matmul(weights, value)
     # Recombine heads --> [batch_size, length, hidden_size]
     attention_output = self.combine_heads(attention_output)
@@ -167,5 +183,6 @@ class Attention(tf.keras.layers.Layer):
 class SelfAttention(Attention):
   """Multiheaded self-attention layer."""
-  def call(self, x, bias, training, cache=None):
-    return super(SelfAttention, self).call(x, x, bias, training, cache)
+  def call(self, x, bias, training, cache=None, decode_loop_step=None):
+    return super(SelfAttention, self).call(x, x, bias, training, cache,
+                                           decode_loop_step)
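Note: the decode_loop_step branch above writes the current step's key/value into a preallocated cache by multiplying with a one-hot mask, so the cache keeps a static shape across the decoding loop (required for XLA/TPU). A tiny numeric sketch of that masked update; the shapes and names are illustrative, not the model's API:

import tensorflow as tf

# Preallocated cache for 4 decode steps: [batch=1, max_len=4, channels=2].
cache_k = tf.zeros([1, 4, 2], tf.float32)

def write_to_cache(cache, new_vec, step):
  """Place new_vec at position `step` without changing the cache's shape."""
  max_len = cache.shape.as_list()[1]
  # one_hot mask of shape [1, max_len, 1]; only row `step` is 1.
  mask = tf.reshape(tf.one_hot(step, max_len, dtype=new_vec.dtype),
                    [1, max_len, 1])
  # new_vec broadcasts across the length axis; the mask zeroes every row
  # except `step`, so the addition behaves like an in-place write.
  return cache + new_vec * mask

new_key = tf.constant([[[0.5, 1.5]]])   # [1, 1, 2], the key for this step
cache_k = write_to_cache(cache_k, new_key, step=2)  # shape stays [1, 4, 2]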
@@ -55,43 +55,58 @@ class SequenceBeamSearchV2(v1.SequenceBeamSearch):
     return finished_seq, finished_scores
-def sequence_beam_search(
-    symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size,
-    alpha, max_decode_length, eos_id, dtype="float32"):
+def sequence_beam_search(symbols_to_logits_fn,
+                         initial_ids,
+                         initial_cache,
+                         vocab_size,
+                         beam_size,
+                         alpha,
+                         max_decode_length,
+                         eos_id,
+                         padded_decode=False,
+                         dtype="float32"):
   """Search for sequence of subtoken ids with the largest probability.
   Args:
     symbols_to_logits_fn: A function that takes in ids, index, and cache as
       arguments. The passed in arguments will have shape:
-        ids -> [batch_size * beam_size, index]
-        index -> [] (scalar)
-        cache -> nested dictionary of tensors [batch_size * beam_size, ...]
-      The function must return logits and new cache.
-        logits -> [batch * beam_size, vocab_size]
-        new cache -> same shape/structure as inputted cache
-    initial_ids: Starting ids for each batch item.
-      int32 tensor with shape [batch_size]
-    initial_cache: dict containing starting decoder variables information
-    vocab_size: int size of tokens
-    beam_size: int number of beams
-    alpha: float defining the strength of length normalization
-    max_decode_length: maximum length to decoded sequence
-    eos_id: int id of eos token, used to determine when a sequence has finished,
-    dtype: The dtype to use.
+        ids -> A tensor with shape [batch_size * beam_size, index].
+        index -> A scalar.
+        cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
+      The function must return a tuple of logits and new cache:
+        logits -> A tensor with shape [batch * beam_size, vocab_size].
+        new cache -> A nested dictionary with the same shape/structure as the
+          inputted cache.
+    initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
+      each batch item.
+    initial_cache: A dictionary, containing starting decoder variables
+      information.
+    vocab_size: An integer, the size of tokens.
+    beam_size: An integer, the number of beams.
+    alpha: A float, defining the strength of length normalization.
+    max_decode_length: An integer, the maximum length to decoded a sequence.
+    eos_id: An integer, ID of eos token, used to determine when a sequence has
+      finished.
+    padded_decode: A bool, indicating if max_sequence_length padding is used
+      for beam search.
+    dtype: A tensorflow data type used for score computation. The default is
+      tf.float32.
   Returns:
     Top decoded sequences [batch_size, beam_size, max_decode_length]
     sequence scores [batch_size, beam_size]
   """
-  batch_size = tf.shape(initial_ids)[0]
+  batch_size = (
+      initial_ids.shape.as_list()[0] if padded_decode else
+      tf.shape(initial_ids)[0])
   if misc.is_v2():
     sbs = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size, batch_size,
                                beam_size, alpha, max_decode_length, eos_id,
-                               dtype)
+                               padded_decode, dtype)
   else:
     sbs = v1.SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
                                 beam_size, alpha, max_decode_length, eos_id,
-                                dtype)
+                                padded_decode, dtype)
   return sbs.search(initial_ids, initial_cache)
...
@@ -24,24 +24,14 @@ import tensorflow as tf
 class EmbeddingSharedWeights(tf.keras.layers.Layer):
   """Calculates input embeddings and pre-softmax linear with shared weights."""
-  def __init__(self, vocab_size, hidden_size, dtype=None):
+  def __init__(self, vocab_size, hidden_size):
     """Specify characteristic parameters of embedding layer.
     Args:
       vocab_size: Number of tokens in the embedding. (Typically ~32,000)
       hidden_size: Dimensionality of the embedding. (Typically 512 or 1024)
-      dtype: The dtype of the layer: float16 or float32.
     """
-    if dtype == tf.float16:
-      # We cannot rely on the global policy of "infer_with_float32_vars", as
-      # this layer is called on both int64 inputs and floating-point inputs.
-      # If "infer_with_float32_vars" is used, the dtype will be inferred to be
-      # int64, which means floating-point inputs would not be casted.
-      # TODO(b/138859351): Remove this logic once we stop using the deprecated
-      # "infer_with_float32_vars" policy
-      dtype = tf.keras.mixed_precision.experimental.Policy(
-          "float16_with_float32_vars")
-    super(EmbeddingSharedWeights, self).__init__(dtype=dtype)
+    super(EmbeddingSharedWeights, self).__init__()
     self.vocab_size = vocab_size
     self.hidden_size = hidden_size
@@ -53,7 +43,6 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
     self.shared_weights = self.add_weight(
         "weights",
         shape=[self.vocab_size, self.hidden_size],
-        dtype="float32",
         initializer=tf.random_normal_initializer(
             mean=0., stddev=self.hidden_size**-0.5))
     super(EmbeddingSharedWeights, self).build(input_shape)
...
@@ -192,6 +192,29 @@ def define_transformer_flags():
           help=flags_core.help_wrap(
               'Whether the model runs in 2VM mode, Headless server and unit test '
               'all use 1VM config.'))
+  flags.DEFINE_integer(
+      name='decode_batch_size',
+      default=32,
+      help=flags_core.help_wrap(
+          'Global batch size used for Transformer autoregressive decoding on '
+          'TPU.'))
+  flags.DEFINE_integer(
+      name='decode_max_length',
+      default=97,
+      help=flags_core.help_wrap(
+          'Max sequence length of the decode/eval data. This is used by '
+          'Transformer autoregressive decoding on TPU to have minimum '
+          'paddings.'))
+  flags.DEFINE_bool(
+      name='padded_decode',
+      default=False,
+      help=flags_core.help_wrap(
+          'Whether the autoregressive decoding runs with input data padded to '
+          'the decode_max_length. For TPU/XLA-GPU runs, this flag has to be '
+          'set due the static shape requirement. Although CPU/GPU could also '
+          'use padded_decode, it has not been tested. In addition, this method '
+          'will introduce unnecessary overheads which grow quadratically with '
+          'the max sequence length.'))
   flags_core.set_defaults(data_dir='/tmp/translate_ende',
                           model_dir='/tmp/transformer_model',
...
@@ -49,8 +49,10 @@ def create_model(params, is_train):
     label_smoothing = params["label_smoothing"]
     if params["enable_metrics_in_training"]:
       logits = metrics.MetricLayer(vocab_size)([logits, targets])
-    logits = tf.keras.layers.Lambda(lambda x: x, name="logits")(logits)
+    logits = tf.keras.layers.Lambda(lambda x: x, name="logits",
+                                    dtype=tf.float32)(logits)
     model = tf.keras.Model([inputs, targets], logits)
+    # TODO(reedwm): Can we do this loss in float16 instead of float32?
     loss = metrics.transformer_loss(
         logits, targets, label_smoothing, vocab_size)
     model.add_loss(loss)
@@ -85,7 +87,7 @@ class Transformer(tf.keras.Model):
     super(Transformer, self).__init__(name=name)
     self.params = params
     self.embedding_softmax_layer = embedding_layer.EmbeddingSharedWeights(
-        params["vocab_size"], params["hidden_size"], dtype=params["dtype"])
+        params["vocab_size"], params["hidden_size"])
     self.encoder_stack = EncoderStack(params)
     self.decoder_stack = DecoderStack(params)
@@ -112,11 +114,22 @@ class Transformer(tf.keras.Model):
           outputs: [batch_size, decoded length]
           scores: [batch_size, float]}
       Even when float16 is used, the output tensor(s) are always float32.
+    Raises:
+      NotImplementedError: If try to use padded decode method on CPU/GPUs.
     """
     if len(inputs) == 2:
       inputs, targets = inputs[0], inputs[1]
     else:
       inputs, targets = inputs[0], None
+      if self.params["padded_decode"]:
+        if not self.params["num_replicas"]:
+          raise NotImplementedError(
+              "Padded decoding on CPU/GPUs is not supported.")
+        decode_batch_size = int(self.params["decode_batch_size"] /
+                                self.params["num_replicas"])
+        inputs = tf.reshape(
+            inputs, [decode_batch_size, self.params["decode_max_length"]])
     # Variance scaling is used here because it seems to work in many problems.
     # Other reasonable initializers may also work just as well.
@@ -225,13 +238,14 @@ class Transformer(tf.keras.Model):
     decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
         max_decode_length, dtype=self.params["dtype"])
+    # TODO(b/139770046): Refactor code with better naming of i.
     def symbols_to_logits_fn(ids, i, cache):
       """Generate logits for next potential IDs.
       Args:
         ids: Current decoded sequences. int tensor with shape [batch_size *
-          beam_size, i + 1]
-        i: Loop index
+          beam_size, i + 1].
+        i: Loop index.
         cache: dictionary of values storing the encoder output, encoder-decoder
           attention bias, and previous decoder attention values.
@@ -245,16 +259,29 @@ class Transformer(tf.keras.Model):
       # Preprocess decoder input by getting embeddings and adding timing signal.
       decoder_input = self.embedding_softmax_layer(decoder_input)
-      decoder_input += timing_signal[i:i + 1]
-      self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
+      if self.params["padded_decode"]:
+        timing_signal_shape = timing_signal.shape.as_list()
+        decoder_input += tf.slice(timing_signal, [i, 0],
+                                  [1, timing_signal_shape[1]])
+        bias_shape = decoder_self_attention_bias.shape.as_list()
+        self_attention_bias = tf.slice(
+            decoder_self_attention_bias, [0, 0, i, 0],
+            [bias_shape[0], bias_shape[1], 1, bias_shape[3]])
+      else:
+        decoder_input += timing_signal[i:i + 1]
+        self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
       decoder_outputs = self.decoder_stack(
           decoder_input,
           cache.get("encoder_outputs"),
           self_attention_bias,
           cache.get("encoder_decoder_attention_bias"),
           training=training,
-          cache=cache)
+          cache=cache,
+          decode_loop_step=i if self.params["padded_decode"] else None)
       logits = self.embedding_softmax_layer(decoder_outputs, mode="linear")
       logits = tf.squeeze(logits, axis=[1])
       return logits, cache
@@ -263,8 +290,12 @@ class Transformer(tf.keras.Model):
   def predict(self, encoder_outputs, encoder_decoder_attention_bias, training):
     """Return predicted sequence."""
-    batch_size = tf.shape(encoder_outputs)[0]
-    input_length = tf.shape(encoder_outputs)[1]
+    if self.params["padded_decode"]:
+      batch_size = encoder_outputs.shape.as_list()[0]
+      input_length = encoder_outputs.shape.as_list()[1]
+    else:
+      batch_size = tf.shape(encoder_outputs)[0]
+      input_length = tf.shape(encoder_outputs)[1]
     max_decode_length = input_length + self.params["extra_decode_length"]
     encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
                                              self.params["dtype"])
@@ -277,12 +308,20 @@ class Transformer(tf.keras.Model):
     # Create cache storing decoder attention values for each layer.
     # pylint: disable=g-complex-comprehension
+    init_decode_length = (
+        max_decode_length if self.params["padded_decode"] else 0)
     cache = {
         "layer_%d" % layer: {
-            "k": tf.zeros([batch_size, 0, self.params["hidden_size"]],
-                          dtype=self.params["dtype"]),
-            "v": tf.zeros([batch_size, 0, self.params["hidden_size"]],
-                          dtype=self.params["dtype"])
+            "k":
+                tf.zeros([
+                    batch_size, init_decode_length, self.params["hidden_size"]
+                ],
+                         dtype=self.params["dtype"]),
+            "v":
+                tf.zeros([
+                    batch_size, init_decode_length, self.params["hidden_size"]
+                ],
+                         dtype=self.params["dtype"])
         } for layer in range(self.params["num_hidden_layers"])
     }
     # pylint: enable=g-complex-comprehension
@@ -301,6 +340,7 @@ class Transformer(tf.keras.Model):
         alpha=self.params["alpha"],
         max_decode_length=max_decode_length,
         eos_id=EOS_ID,
+        padded_decode=self.params["padded_decode"],
         dtype=self.params["dtype"])
     # Get the top sequence for each batch element
@@ -505,22 +545,28 @@ class DecoderStack(tf.keras.layers.Layer):
            decoder_self_attention_bias,
            attention_bias,
            training,
-           cache=None):
+           cache=None,
+           decode_loop_step=None):
    """Return the output of the decoder layer stacks.
    Args:
-      decoder_inputs: tensor with shape [batch_size, target_length, hidden_size]
-      encoder_outputs: tensor with shape [batch_size, input_length, hidden_size]
-      decoder_self_attention_bias: bias for decoder self-attention layer. [1, 1,
-        target_len, target_length]
-      attention_bias: bias for encoder-decoder attention layer. [batch_size, 1,
-        1, input_length]
-      training: boolean, whether in training mode or not.
+      decoder_inputs: A tensor with shape
+        [batch_size, target_length, hidden_size].
+      encoder_outputs: A tensor with shape
+        [batch_size, input_length, hidden_size]
+      decoder_self_attention_bias: A tensor with shape
+        [1, 1, target_len, target_length], the bias for decoder self-attention
+        layer.
+      attention_bias: A tensor with shape [batch_size, 1, 1, input_length],
+        the bias for encoder-decoder attention layer.
+      training: A bool, whether in training mode or not.
      cache: (Used for fast decoding) A nested dictionary storing previous
        decoder self-attention values. The items are:
-          {layer_n: {"k": tensor with shape [batch_size, i, key_channels],
-                     "v": tensor with shape [batch_size, i, value_channels]},
+          {layer_n: {"k": A tensor with shape [batch_size, i, key_channels],
+                     "v": A tensor with shape [batch_size, i, value_channels]},
           ...}
+      decode_loop_step: An integer, the step number of the decoding loop. Used
+        only for autoregressive inference on TPU.
    Returns:
      Output of decoder layer stack.
@@ -540,7 +586,8 @@ class DecoderStack(tf.keras.layers.Layer):
              decoder_inputs,
              decoder_self_attention_bias,
              training=training,
-             cache=layer_cache)
+             cache=layer_cache,
+             decode_loop_step=decode_loop_step)
        with tf.name_scope("encdec_attention"):
          decoder_inputs = enc_dec_attention_layer(
              decoder_inputs,
...
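Note: the padded_decode path in this file trades a little compute for static shapes: the per-layer "k"/"v" caches are preallocated to the full decode length instead of starting empty and growing, the timing signal and attention bias are read with fixed-size tf.slice, and each step writes into its own slot. A tiny sketch contrasting the two cache initializations; batch_size, hidden_size and max_decode_length are illustrative values:

import tensorflow as tf

batch_size, hidden_size, max_decode_length = 8, 512, 97  # illustrative values

# Dynamic decode (CPU/GPU): the cache starts empty and grows by tf.concat each
# step, so its length dimension changes from iteration to iteration.
dynamic_cache_k = tf.zeros([batch_size, 0, hidden_size])

# Padded decode (TPU/XLA): the cache is preallocated to the maximum length and
# each step writes into its own position, so every tensor keeps a static shape.
padded_cache_k = tf.zeros([batch_size, max_decode_length, hidden_size])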