Unverified Commit a35e09d2 authored by Vinh Nguyen's avatar Vinh Nguyen Committed by GitHub
Browse files

Merge branch 'master' into amp_resnet50

parents d5722dcd 1f5a5e9d
......@@ -20,7 +20,7 @@ from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.export import export
from official.r1.utils import export
class ExportUtilsTest(tf.test.TestCase):
......
......@@ -29,7 +29,7 @@ import tensorflow as tf
# pylint: enable=wrong-import-order
from official.datasets import movielens
from official.utils.data import file_io
from official.r1.utils.data import file_io
from official.utils.flags import core as flags_core
......
......@@ -65,8 +65,8 @@ def prepare_raw_data(flag_obj):
data_processing_params = {
"train_epochs": flag_obj.num_train_epochs,
"batch_size": flag_obj.prebatch_size,
"eval_batch_size": flag_obj.prebatch_size,
"batch_size": flag_obj.train_prebatch_size,
"eval_batch_size": flag_obj.eval_prebatch_size,
"batches_per_step": 1,
"stream_files": True,
"num_neg": flag_obj.num_negative_samples,
......
......@@ -154,8 +154,10 @@ def define_ncf_flags():
intra_op=False,
synthetic_data=True,
max_train_steps=False,
dtype=False,
dtype=True,
all_reduce_alg=False,
loss_scale=True,
dynamic_loss_scale=True,
enable_xla=True,
force_v2_in_keras_compile=True
)
......
......@@ -21,7 +21,6 @@ from __future__ import print_function
import functools
# pylint: disable=g-bad-import-order
import numpy as np
import tensorflow.compat.v2 as tf
# pylint: enable=g-bad-import-order
......@@ -42,6 +41,9 @@ def create_dataset_from_tf_record_files(input_file_pattern,
def make_dataset(files_dataset, shard_index):
"""Returns dataset for sharded tf record files."""
if pre_batch_size != batch_size:
raise ValueError("Pre-batch ({}) size is not equal to batch "
"size ({})".format(pre_batch_size, batch_size))
files_dataset = files_dataset.shard(NUM_SHARDS, shard_index)
dataset = files_dataset.interleave(tf.data.TFRecordDataset)
decode_fn = functools.partial(
......@@ -50,8 +52,6 @@ def create_dataset_from_tf_record_files(input_file_pattern,
is_training=is_training)
dataset = dataset.map(
decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.apply(tf.data.experimental.unbatch())
dataset = dataset.batch(batch_size, drop_remainder=True)
return dataset
dataset = tf.data.Dataset.range(NUM_SHARDS)
......
......@@ -117,7 +117,7 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
"""
self._run_and_report_benchmark(hr_at_10_min=0.61)
def _run_and_report_benchmark(self, hr_at_10_min=0.630, hr_at_10_max=0.640):
def _run_and_report_benchmark(self, hr_at_10_min=0.630, hr_at_10_max=0.645):
"""Run test and report results.
Note: Target is 0.635, but some runs are below that level. Until we have
......@@ -263,6 +263,15 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
FLAGS.train_epochs = 7
self._run_and_report_benchmark_mlperf_like()
def benchmark_1_gpu_ctl_fp16_mlperf_like(self):
"""1 GPU using CTL."""
self._setup()
FLAGS.keras_use_ctl = True
FLAGS.train_epochs = 7
FLAGS.dtype = 'fp16'
FLAGS.loss_scale = 8192
self._run_and_report_benchmark_mlperf_like()
def benchmark_1_gpu_ctl_run_eagerly_mlperf_like(self):
"""1 GPU using CTL with eager and distribution strategy."""
self._setup()
......@@ -279,6 +288,16 @@ class NCFKerasAccuracy(NCFKerasBenchmarkBase):
FLAGS.train_epochs = 7
self._run_and_report_benchmark_mlperf_like()
def benchmark_xla_1_gpu_ctl_fp16_mlperf_like(self):
"""1 GPU using CTL with XLA."""
self._setup()
FLAGS.keras_use_ctl = True
FLAGS.enable_xla = True
FLAGS.train_epochs = 7
FLAGS.dtype = 'fp16'
FLAGS.loss_scale = 8192
self._run_and_report_benchmark_mlperf_like()
def benchmark_8_gpu_mlperf_like(self):
"""8 GPU using keras fit/compile."""
self._setup()
......
......@@ -42,6 +42,7 @@ from official.utils.logs import mlperf_helper
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils.misc import model_helpers
from official.utils.flags import core as flags_core
from official.utils.misc import tpu_lib
FLAGS = flags.FLAGS
......@@ -267,6 +268,12 @@ def run_ncf(_):
beta_1=params["beta1"],
beta_2=params["beta2"],
epsilon=params["epsilon"])
if FLAGS.dtype == "fp16":
optimizer = \
tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite(
optimizer,
loss_scale=flags_core.get_loss_scale(FLAGS,
default_for_fp16="dynamic"))
if params["keras_use_ctl"]:
train_loss, eval_results = run_ncf_custom_training(
......@@ -371,8 +378,12 @@ def run_ncf_custom_training(params,
softmax_logits,
sample_weight=features[rconst.VALID_POINT_MASK])
loss *= (1.0 / params["batch_size"])
if FLAGS.dtype == "fp16":
loss = optimizer.get_scaled_loss(loss)
grads = tape.gradient(loss, keras_model.trainable_variables)
if FLAGS.dtype == "fp16":
grads = optimizer.get_unscaled_gradients(grads)
# Converting gradients to dense form helps in perf on GPU for NCF
grads = neumf_model.sparse_to_dense_grads(
list(zip(grads, keras_model.trainable_variables)))
......
......@@ -27,3 +27,6 @@ def define_ctl_flags():
flags.DEFINE_boolean(name='use_tf_function', default=True,
help='Wrap the train and test step inside a '
'tf.function.')
flags.DEFINE_boolean(name='single_l2_loss_op', default=False,
help='Calculate L2_loss on concatenated weights, '
'instead of using Keras per-layer L2 loss.')
......@@ -22,7 +22,7 @@ import time
from absl import flags
import tensorflow as tf
from official.resnet.keras import keras_common
from official.vision.image_classification import common
from official.resnet.ctl import ctl_imagenet_main
from official.resnet.ctl import ctl_common
from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
......@@ -118,7 +118,7 @@ class Resnet50CtlAccuracy(CtlBenchmark):
flag_methods = [
ctl_common.define_ctl_flags,
keras_common.define_keras_flags
common.define_keras_flags
]
self.data_dir = os.path.join(root_data_dir, 'imagenet')
......@@ -162,7 +162,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
def __init__(self, output_dir=None, default_flags=None):
flag_methods = [
ctl_common.define_ctl_flags,
keras_common.define_keras_flags
common.define_keras_flags
]
super(Resnet50CtlBenchmarkBase, self).__init__(
......@@ -215,6 +215,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_eager')
FLAGS.batch_size = 64
FLAGS.use_tf_function = False
FLAGS.single_l2_loss_op = True
self._run_and_report_benchmark()
def benchmark_8_gpu(self):
......
......@@ -24,10 +24,10 @@ from absl import logging
import tensorflow as tf
from official.resnet.ctl import ctl_common
from official.resnet.keras import imagenet_preprocessing
from official.resnet.keras import keras_common
from official.resnet.keras import keras_imagenet_main
from official.resnet.keras import resnet_model
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import common
from official.vision.image_classification import resnet_imagenet_main
from official.vision.image_classification import resnet_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
......@@ -73,7 +73,7 @@ def get_input_dataset(flags_obj, strategy):
"""Returns the test and train input datasets."""
dtype = flags_core.get_tf_dtype(flags_obj)
if flags_obj.use_synthetic_data:
input_fn = keras_common.get_synth_input_fn(
input_fn = common.get_synth_input_fn(
height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
num_channels=imagenet_preprocessing.NUM_CHANNELS,
......@@ -137,6 +137,10 @@ def run(flags_obj):
Returns:
Dictionary of training and eval stats.
"""
keras_utils.set_session_config(
enable_eager=flags_obj.enable_eager,
enable_xla=flags_obj.enable_xla)
dtype = flags_core.get_tf_dtype(flags_obj)
# TODO(anj-s): Set data_format without using Keras.
......@@ -163,10 +167,11 @@ def run(flags_obj):
with strategy_scope:
model = resnet_model.resnet50(
num_classes=imagenet_preprocessing.NUM_CLASSES,
dtype=dtype, batch_size=flags_obj.batch_size)
dtype=dtype, batch_size=flags_obj.batch_size,
use_l2_regularizer=not flags_obj.single_l2_loss_op)
optimizer = tf.keras.optimizers.SGD(
learning_rate=keras_common.BASE_LEARNING_RATE, momentum=0.9,
learning_rate=common.BASE_LEARNING_RATE, momentum=0.9,
nesterov=True)
training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
......@@ -175,6 +180,8 @@ def run(flags_obj):
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
'test_accuracy', dtype=tf.float32)
trainable_variables = model.trainable_variables
def train_step(train_ds_inputs):
"""Training StepFn."""
def step_fn(inputs):
......@@ -185,13 +192,22 @@ def run(flags_obj):
prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
labels, logits)
loss1 = tf.reduce_sum(prediction_loss) * (1.0/ flags_obj.batch_size)
loss2 = (tf.reduce_sum(model.losses) /
tf.distribute.get_strategy().num_replicas_in_sync)
loss = loss1 + loss2
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
loss = tf.reduce_sum(prediction_loss) * (1.0/ flags_obj.batch_size)
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
if flags_obj.single_l2_loss_op:
filtered_variables = [
tf.reshape(v, (-1,))
for v in trainable_variables
if 'bn' not in v.name
]
l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
tf.concat(filtered_variables, axis=0))
loss += (l2_loss / num_replicas)
else:
loss += (tf.reduce_sum(model.losses) / num_replicas)
grads = tape.gradient(loss, trainable_variables)
optimizer.apply_gradients(zip(grads, trainable_variables))
training_accuracy.update_state(labels, logits)
return loss
......@@ -232,7 +248,7 @@ def run(flags_obj):
training_accuracy.reset_states()
for step in range(train_steps):
optimizer.lr = keras_imagenet_main.learning_rate_schedule(
optimizer.lr = resnet_imagenet_main.learning_rate_schedule(
epoch, step, train_steps, flags_obj.batch_size)
time_callback.on_batch_begin(step+epoch*train_steps)
......@@ -281,7 +297,7 @@ def main(_):
if __name__ == '__main__':
logging.set_verbosity(logging.INFO)
keras_common.define_keras_flags()
common.define_keras_flags()
ctl_common.define_ctl_flags()
flags.adopt_module_key_flags(keras_common)
flags.adopt_module_key_flags(ctl_common)
......
......@@ -25,8 +25,8 @@ from tensorflow.python.eager import context
from tensorflow.python.platform import googletest
from official.resnet.ctl import ctl_common
from official.resnet.ctl import ctl_imagenet_main
from official.resnet.keras import imagenet_preprocessing
from official.resnet.keras import keras_common
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import common
from official.utils.misc import keras_utils
from official.utils.testing import integration
......@@ -49,7 +49,7 @@ class CtlImagenetTest(googletest.TestCase):
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(CtlImagenetTest, cls).setUpClass()
keras_common.define_keras_flags()
common.define_keras_flags()
ctl_common.define_ctl_flags()
def setUp(self):
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bring in the shared Keras ResNet modules into this module.
The TensorFlow official Keras models are moved under
official/vision/image_classification
In order to be backward compatible with models that directly import its modules,
we import the Keras ResNet modules under official.resnet.keras.
New TF models should not depend on modules directly under this path.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from official.vision.image_classification import cifar_preprocessing
from official.vision.image_classification import common as keras_common
from official.vision.image_classification import imagenet_preprocessing
from official.vision.image_classification import resnet_cifar_main as keras_cifar_main
from official.vision.image_classification import resnet_cifar_model
from official.vision.image_classification import resnet_imagenet_main as keras_imagenet_main
from official.vision.image_classification import resnet_model
del absolute_import
del division
del print_function
......@@ -41,8 +41,8 @@ class ShakespeareBenchmarkBase(PerfZeroBenchmark):
flag_methods=[shakespeare_main.define_flags])
def _run_and_report_benchmark(self,
top_1_train_min=0.923,
top_1_train_max=0.93,
top_1_train_min=0.91,
top_1_train_max=0.94,
warmup=1,
log_steps=100):
"""Report benchmark results by writing to local protobuf file.
......
......@@ -79,8 +79,41 @@ class _StateKeys(object):
class SequenceBeamSearch(object):
"""Implementation of beam search loop."""
def __init__(self, symbols_to_logits_fn, vocab_size, batch_size,
beam_size, alpha, max_decode_length, eos_id, dtype=tf.float32):
def __init__(self,
symbols_to_logits_fn,
vocab_size,
batch_size,
beam_size,
alpha,
max_decode_length,
eos_id,
padded_decode,
dtype=tf.float32):
"""Initialize sequence beam search.
Args:
symbols_to_logits_fn: A function to provide logits, which is the
interface to the Transformer model. The passed in arguments are:
ids -> A tensor with shape [batch_size * beam_size, index].
index -> A scalar.
cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
The function must return a tuple of logits and the updated cache:
logits -> A tensor with shape [batch * beam_size, vocab_size].
updated cache -> A nested dictionary with the same structure as the
input cache.
vocab_size: An integer, the size of the vocabulary, used for topk
computation.
batch_size: An integer, the decode batch size.
beam_size: An integer, number of beams for beam search.
alpha: A float, defining the strength of length normalization.
max_decode_length: An integer, the maximum number of steps to decode
a sequence.
eos_id: An integer. ID of end of sentence token.
padded_decode: A bool, indicating if max_sequence_length padding is used
for beam search.
dtype: A tensorflow data type used for score computation. The default is
tf.float32.
"""
self.symbols_to_logits_fn = symbols_to_logits_fn
self.vocab_size = vocab_size
self.batch_size = batch_size
......@@ -88,6 +121,7 @@ class SequenceBeamSearch(object):
self.alpha = alpha
self.max_decode_length = max_decode_length
self.eos_id = eos_id
self.padded_decode = padded_decode
self.dtype = tf.as_dtype(dtype)
def search(self, initial_ids, initial_cache):
......@@ -140,6 +174,8 @@ class SequenceBeamSearch(object):
# Create alive sequence with shape [batch_size, beam_size, 1]
alive_seq = _expand_to_beam_size(initial_ids, self.beam_size)
alive_seq = tf.expand_dims(alive_seq, axis=2)
if self.padded_decode:
alive_seq = tf.tile(alive_seq, [1, 1, self.max_decode_length + 1])
# Create tensor for storing initial log probabilities.
# Assume initial_ids are prob 1.0
......@@ -178,16 +214,44 @@ class SequenceBeamSearch(object):
# 1) the dimension's value is a tensor that remains the same but may
# depend on the input sequence to the model (e.g. batch size).
# 2) the dimension may have different values on different iterations.
state_shape_invariants = {
_StateKeys.CUR_INDEX: tf.TensorShape([]),
_StateKeys.ALIVE_SEQ: tf.TensorShape([None, self.beam_size, None]),
_StateKeys.ALIVE_LOG_PROBS: tf.TensorShape([None, self.beam_size]),
_StateKeys.ALIVE_CACHE: nest.map_structure(
_get_shape_keep_last_dim, alive_cache),
_StateKeys.FINISHED_SEQ: tf.TensorShape([None, self.beam_size, None]),
_StateKeys.FINISHED_SCORES: tf.TensorShape([None, self.beam_size]),
_StateKeys.FINISHED_FLAGS: tf.TensorShape([None, self.beam_size])
}
if self.padded_decode:
state_shape_invariants = {
_StateKeys.CUR_INDEX:
tf.TensorShape([]),
_StateKeys.ALIVE_SEQ:
tf.TensorShape(
[self.batch_size, self.beam_size,
self.max_decode_length + 1]),
_StateKeys.ALIVE_LOG_PROBS:
tf.TensorShape([self.batch_size, self.beam_size]),
_StateKeys.ALIVE_CACHE:
nest.map_structure(_get_shape, alive_cache),
_StateKeys.FINISHED_SEQ:
tf.TensorShape(
[self.batch_size, self.beam_size,
self.max_decode_length + 1]),
_StateKeys.FINISHED_SCORES:
tf.TensorShape([self.batch_size, self.beam_size]),
_StateKeys.FINISHED_FLAGS:
tf.TensorShape([self.batch_size, self.beam_size])
}
else:
state_shape_invariants = {
_StateKeys.CUR_INDEX:
tf.TensorShape([]),
_StateKeys.ALIVE_SEQ:
tf.TensorShape([None, self.beam_size, None]),
_StateKeys.ALIVE_LOG_PROBS:
tf.TensorShape([None, self.beam_size]),
_StateKeys.ALIVE_CACHE:
nest.map_structure(_get_shape_keep_last_dim, alive_cache),
_StateKeys.FINISHED_SEQ:
tf.TensorShape([None, self.beam_size, None]),
_StateKeys.FINISHED_SCORES:
tf.TensorShape([None, self.beam_size]),
_StateKeys.FINISHED_FLAGS:
tf.TensorShape([None, self.beam_size])
}
return state, state_shape_invariants
......@@ -297,7 +361,12 @@ class SequenceBeamSearch(object):
# Get logits for the next candidate IDs for the alive sequences. Get the new
# cache values at the same time.
flat_ids = _flatten_beam_dim(alive_seq) # [batch_size * beam_size]
if self.padded_decode:
flat_ids = tf.reshape(
tf.slice(alive_seq, [0, 0, i], [self.batch_size, self.beam_size, 1]),
[self.batch_size * self.beam_size, -1])
else:
flat_ids = _flatten_beam_dim(alive_seq) # [batch_size * beam_size]
flat_cache = nest.map_structure(_flatten_beam_dim, alive_cache)
flat_logits, flat_cache = self.symbols_to_logits_fn(flat_ids, i, flat_cache)
......@@ -331,8 +400,13 @@ class SequenceBeamSearch(object):
# Append the most probable IDs to the topk sequences
topk_ids = topk_indices % self.vocab_size
topk_ids = tf.expand_dims(topk_ids, axis=2)
topk_seq = tf.concat([topk_seq, topk_ids], axis=2)
if self.padded_decode:
topk_seq = tf.transpose(topk_seq, perm=[2, 0, 1])
topk_seq = tf.tensor_scatter_update(topk_seq, [i + 1], topk_ids)
topk_seq = tf.transpose(topk_seq, perm=[1, 2, 0])
else:
topk_ids = tf.expand_dims(topk_ids, axis=2)
topk_seq = tf.concat([topk_seq, topk_ids], axis=2)
return topk_seq, topk_log_probs, new_cache
def _get_new_alive_state(self, new_seq, new_log_probs, new_cache):
......@@ -388,9 +462,12 @@ class SequenceBeamSearch(object):
# First append a column of 0-ids to finished_seq to increment the length.
# New shape of finished_seq: [batch_size, beam_size, i + 1]
finished_seq = tf.concat(
[finished_seq,
tf.zeros([self.batch_size, self.beam_size, 1], tf.int32)], axis=2)
if not self.padded_decode:
finished_seq = tf.concat([
finished_seq,
tf.zeros([self.batch_size, self.beam_size, 1], tf.int32)
],
axis=2)
# Calculate new seq scores from log probabilities.
length_norm = _length_normalization(self.alpha, i + 1, dtype=self.dtype)
......@@ -420,34 +497,43 @@ class SequenceBeamSearch(object):
def sequence_beam_search(
symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size,
alpha, max_decode_length, eos_id):
alpha, max_decode_length, eos_id, padded_decode=False):
"""Search for sequence of subtoken ids with the largest probability.
Args:
symbols_to_logits_fn: A function that takes in ids, index, and cache as
arguments. The passed in arguments will have shape:
ids -> [batch_size * beam_size, index]
index -> [] (scalar)
cache -> nested dictionary of tensors [batch_size * beam_size, ...]
The function must return logits and new cache.
logits -> [batch * beam_size, vocab_size]
new cache -> same shape/structure as inputted cache
initial_ids: Starting ids for each batch item.
int32 tensor with shape [batch_size]
initial_cache: dict containing starting decoder variables information
vocab_size: int size of tokens
beam_size: int number of beams
alpha: float defining the strength of length normalization
max_decode_length: maximum length to decoded sequence
eos_id: int id of eos token, used to determine when a sequence has finished
ids -> A tensor with shape [batch_size * beam_size, index].
index -> A scalar.
cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
The function must return a tuple of logits and new cache:
logits -> A tensor with shape [batch * beam_size, vocab_size].
new cache -> A nested dictionary with the same shape/structure as the
inputted cache.
initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
each batch item.
initial_cache: A dictionary, containing starting decoder variables
information.
vocab_size: An integer, the size of the vocabulary, used for topk
computation.
beam_size: An integer, the number of beams.
alpha: A float, defining the strength of length normalization.
max_decode_length: An integer, the maximum length to decoded a sequence.
eos_id: An integer, ID of eos token, used to determine when a sequence has
finished.
padded_decode: A bool, indicating if max_sequence_length padding is used
for beam search.
Returns:
Top decoded sequences [batch_size, beam_size, max_decode_length]
sequence scores [batch_size, beam_size]
"""
batch_size = tf.shape(initial_ids)[0]
batch_size = (
initial_ids.shape.as_list()[0] if padded_decode else
tf.shape(initial_ids)[0])
sbs = SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
beam_size, alpha, max_decode_length, eos_id)
beam_size, alpha, max_decode_length, eos_id,
padded_decode)
return sbs.search(initial_ids, initial_cache)
......@@ -502,6 +588,11 @@ def _get_shape_keep_last_dim(tensor):
return tf.TensorShape(shape_list)
def _get_shape(tensor):
"""Return the shape of the input tensor."""
return tf.TensorShape(_shape_list(tensor))
def _flatten_beam_dim(tensor):
"""Reshapes first two dimensions in to single dimension.
......
......@@ -32,6 +32,7 @@ from absl import flags
import tensorflow as tf
# pylint: enable=g-bad-import-order
from official.r1.utils import export
from official.transformer import compute_bleu
from official.transformer import translate
from official.transformer.model import model_params
......@@ -41,7 +42,6 @@ from official.transformer.utils import metrics
from official.transformer.utils import schedule
from official.transformer.utils import tokenizer
from official.utils.accelerator import tpu as tpu_util
from official.utils.export import export
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.logs import logger
......@@ -56,7 +56,7 @@ PARAMS_MAP = {
DEFAULT_TRAIN_EPOCHS = 10
INF = int(1e9)
INF = 1000000000 # 1e9
BLEU_DIR = "bleu"
# Dictionary containing tensors that are logged by the logging hooks. Each item
......
......@@ -102,51 +102,67 @@ class Attention(tf.keras.layers.Layer):
x = tf.transpose(x, [0, 2, 1, 3]) # --> [batch, length, num_heads, depth]
return tf.reshape(x, [batch_size, length, self.hidden_size])
def call(self, x, y, bias, training, cache=None):
def call(self, x, y, bias, training, cache=None, decode_loop_step=None):
"""Apply attention mechanism to x and y.
Args:
x: a tensor with shape [batch_size, length_x, hidden_size]
y: a tensor with shape [batch_size, length_y, hidden_size]
bias: attention bias that will be added to the result of the dot product.
training: boolean, whether in training mode or not.
cache: (Used during prediction) dictionary with tensors containing results
of previous attentions. The dictionary must have the items:
x: A tensor with shape [batch_size, length_x, hidden_size].
y: A tensor with shape [batch_size, length_y, hidden_size].
bias: A bool, the attention bias that will be added to the result of the
dot product.
training: A bool, whether in training mode or not.
cache: (Used during prediction) A dictionary with tensors containing
results of previous attentions. The dictionary must have the items:
{"k": tensor with shape [batch_size, i, key_channels],
"v": tensor with shape [batch_size, i, value_channels]}
where i is the current decoded length.
decode_loop_step: An integer, step number of the decoding loop. Used only
for autoregressive inference on TPU.
Returns:
Attention layer output with shape [batch_size, length_x, hidden_size]
"""
# Linearly project the query (q), key (k) and value (v) using different
# learned projections. This is in preparation of splitting them into
# multiple heads. Multi-head attention uses multiple queries, keys, and
# values rather than regular attention (which uses a single q, k, v).
q = self.q_dense_layer(x)
k = self.k_dense_layer(y)
v = self.v_dense_layer(y)
# Linearly project the query, key and value using different learned
# projections. This is in preparation of splitting them into multiple
# heads. Multi-head attention uses multiple queries, keys, and values
# rather than regular attention (which uses a single query, key, value).
query = self.q_dense_layer(x)
key = self.k_dense_layer(y)
value = self.v_dense_layer(y)
if cache is not None:
# Combine cached keys and values with new keys and values.
k = tf.concat([tf.cast(cache["k"], k.dtype), k], axis=1)
v = tf.concat([tf.cast(cache["v"], k.dtype), v], axis=1)
if decode_loop_step is not None:
cache_k_shape = cache["k"].shape.as_list()
indices = tf.reshape(
tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=key.dtype),
[1, cache_k_shape[1], 1])
key = cache["k"] + key * indices
cache_v_shape = cache["v"].shape.as_list()
indices = tf.reshape(
tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=value.dtype),
[1, cache_v_shape[1], 1])
value = cache["v"] + value * indices
else:
key = tf.concat([tf.cast(cache["k"], key.dtype), key], axis=1)
value = tf.concat([tf.cast(cache["v"], value.dtype), value], axis=1)
# Update cache
cache["k"] = k
cache["v"] = v
cache["k"] = key
cache["v"] = value
# Split q, k, v into heads.
q = self.split_heads(q)
k = self.split_heads(k)
v = self.split_heads(v)
# Split query, key, value into heads.
query = self.split_heads(query)
key = self.split_heads(key)
value = self.split_heads(value)
# Scale q to prevent the dot product between q and k from growing too large.
# Scale query to prevent the dot product between query and key from growing
# too large.
depth = (self.hidden_size // self.num_heads)
q *= depth ** -0.5
query *= depth ** -0.5
# Calculate dot product attention
logits = tf.matmul(q, k, transpose_b=True)
logits = tf.matmul(query, key, transpose_b=True)
logits += bias
# Note that softmax internally performs math operations using float32
# for numeric stability. When training with float16, we keep the input
......@@ -154,7 +170,7 @@ class Attention(tf.keras.layers.Layer):
weights = tf.nn.softmax(logits, name="attention_weights")
if training:
weights = tf.nn.dropout(weights, rate=self.attention_dropout)
attention_output = tf.matmul(weights, v)
attention_output = tf.matmul(weights, value)
# Recombine heads --> [batch_size, length, hidden_size]
attention_output = self.combine_heads(attention_output)
......@@ -167,5 +183,6 @@ class Attention(tf.keras.layers.Layer):
class SelfAttention(Attention):
"""Multiheaded self-attention layer."""
def call(self, x, bias, training, cache=None):
return super(SelfAttention, self).call(x, x, bias, training, cache)
def call(self, x, bias, training, cache=None, decode_loop_step=None):
return super(SelfAttention, self).call(x, x, bias, training, cache,
decode_loop_step)
......@@ -55,43 +55,58 @@ class SequenceBeamSearchV2(v1.SequenceBeamSearch):
return finished_seq, finished_scores
def sequence_beam_search(
symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size,
alpha, max_decode_length, eos_id, dtype="float32"):
def sequence_beam_search(symbols_to_logits_fn,
initial_ids,
initial_cache,
vocab_size,
beam_size,
alpha,
max_decode_length,
eos_id,
padded_decode=False,
dtype="float32"):
"""Search for sequence of subtoken ids with the largest probability.
Args:
symbols_to_logits_fn: A function that takes in ids, index, and cache as
arguments. The passed in arguments will have shape:
ids -> [batch_size * beam_size, index]
index -> [] (scalar)
cache -> nested dictionary of tensors [batch_size * beam_size, ...]
The function must return logits and new cache.
logits -> [batch * beam_size, vocab_size]
new cache -> same shape/structure as inputted cache
initial_ids: Starting ids for each batch item.
int32 tensor with shape [batch_size]
initial_cache: dict containing starting decoder variables information
vocab_size: int size of tokens
beam_size: int number of beams
alpha: float defining the strength of length normalization
max_decode_length: maximum length to decoded sequence
eos_id: int id of eos token, used to determine when a sequence has finished,
dtype: The dtype to use.
ids -> A tensor with shape [batch_size * beam_size, index].
index -> A scalar.
cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
The function must return a tuple of logits and new cache:
logits -> A tensor with shape [batch * beam_size, vocab_size].
new cache -> A nested dictionary with the same shape/structure as the
inputted cache.
initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
each batch item.
initial_cache: A dictionary, containing starting decoder variables
information.
vocab_size: An integer, the size of tokens.
beam_size: An integer, the number of beams.
alpha: A float, defining the strength of length normalization.
max_decode_length: An integer, the maximum length to decoded a sequence.
eos_id: An integer, ID of eos token, used to determine when a sequence has
finished.
padded_decode: A bool, indicating if max_sequence_length padding is used
for beam search.
dtype: A tensorflow data type used for score computation. The default is
tf.float32.
Returns:
Top decoded sequences [batch_size, beam_size, max_decode_length]
sequence scores [batch_size, beam_size]
"""
batch_size = tf.shape(initial_ids)[0]
batch_size = (
initial_ids.shape.as_list()[0] if padded_decode else
tf.shape(initial_ids)[0])
if misc.is_v2():
sbs = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size, batch_size,
beam_size, alpha, max_decode_length, eos_id,
dtype)
padded_decode, dtype)
else:
sbs = v1.SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
beam_size, alpha, max_decode_length, eos_id,
dtype)
padded_decode, dtype)
return sbs.search(initial_ids, initial_cache)
......
......@@ -24,24 +24,14 @@ import tensorflow as tf
class EmbeddingSharedWeights(tf.keras.layers.Layer):
"""Calculates input embeddings and pre-softmax linear with shared weights."""
def __init__(self, vocab_size, hidden_size, dtype=None):
def __init__(self, vocab_size, hidden_size):
"""Specify characteristic parameters of embedding layer.
Args:
vocab_size: Number of tokens in the embedding. (Typically ~32,000)
hidden_size: Dimensionality of the embedding. (Typically 512 or 1024)
dtype: The dtype of the layer: float16 or float32.
"""
if dtype == tf.float16:
# We cannot rely on the global policy of "infer_with_float32_vars", as
# this layer is called on both int64 inputs and floating-point inputs.
# If "infer_with_float32_vars" is used, the dtype will be inferred to be
# int64, which means floating-point inputs would not be casted.
# TODO(b/138859351): Remove this logic once we stop using the deprecated
# "infer_with_float32_vars" policy
dtype = tf.keras.mixed_precision.experimental.Policy(
"float16_with_float32_vars")
super(EmbeddingSharedWeights, self).__init__(dtype=dtype)
super(EmbeddingSharedWeights, self).__init__()
self.vocab_size = vocab_size
self.hidden_size = hidden_size
......@@ -53,7 +43,6 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
self.shared_weights = self.add_weight(
"weights",
shape=[self.vocab_size, self.hidden_size],
dtype="float32",
initializer=tf.random_normal_initializer(
mean=0., stddev=self.hidden_size**-0.5))
super(EmbeddingSharedWeights, self).build(input_shape)
......
......@@ -192,6 +192,29 @@ def define_transformer_flags():
help=flags_core.help_wrap(
'Whether the model runs in 2VM mode, Headless server and unit test '
'all use 1VM config.'))
flags.DEFINE_integer(
name='decode_batch_size',
default=32,
help=flags_core.help_wrap(
'Global batch size used for Transformer autoregressive decoding on '
'TPU.'))
flags.DEFINE_integer(
name='decode_max_length',
default=97,
help=flags_core.help_wrap(
'Max sequence length of the decode/eval data. This is used by '
'Transformer autoregressive decoding on TPU to have minimum '
'paddings.'))
flags.DEFINE_bool(
name='padded_decode',
default=False,
help=flags_core.help_wrap(
'Whether the autoregressive decoding runs with input data padded to '
'the decode_max_length. For TPU/XLA-GPU runs, this flag has to be '
'set due the static shape requirement. Although CPU/GPU could also '
'use padded_decode, it has not been tested. In addition, this method '
'will introduce unnecessary overheads which grow quadratically with '
'the max sequence length.'))
flags_core.set_defaults(data_dir='/tmp/translate_ende',
model_dir='/tmp/transformer_model',
......
......@@ -49,8 +49,10 @@ def create_model(params, is_train):
label_smoothing = params["label_smoothing"]
if params["enable_metrics_in_training"]:
logits = metrics.MetricLayer(vocab_size)([logits, targets])
logits = tf.keras.layers.Lambda(lambda x: x, name="logits")(logits)
logits = tf.keras.layers.Lambda(lambda x: x, name="logits",
dtype=tf.float32)(logits)
model = tf.keras.Model([inputs, targets], logits)
# TODO(reedwm): Can we do this loss in float16 instead of float32?
loss = metrics.transformer_loss(
logits, targets, label_smoothing, vocab_size)
model.add_loss(loss)
......@@ -85,7 +87,7 @@ class Transformer(tf.keras.Model):
super(Transformer, self).__init__(name=name)
self.params = params
self.embedding_softmax_layer = embedding_layer.EmbeddingSharedWeights(
params["vocab_size"], params["hidden_size"], dtype=params["dtype"])
params["vocab_size"], params["hidden_size"])
self.encoder_stack = EncoderStack(params)
self.decoder_stack = DecoderStack(params)
......@@ -112,11 +114,22 @@ class Transformer(tf.keras.Model):
outputs: [batch_size, decoded length]
scores: [batch_size, float]}
Even when float16 is used, the output tensor(s) are always float32.
Raises:
NotImplementedError: If try to use padded decode method on CPU/GPUs.
"""
if len(inputs) == 2:
inputs, targets = inputs[0], inputs[1]
else:
inputs, targets = inputs[0], None
if self.params["padded_decode"]:
if not self.params["num_replicas"]:
raise NotImplementedError(
"Padded decoding on CPU/GPUs is not supported.")
decode_batch_size = int(self.params["decode_batch_size"] /
self.params["num_replicas"])
inputs = tf.reshape(
inputs, [decode_batch_size, self.params["decode_max_length"]])
# Variance scaling is used here because it seems to work in many problems.
# Other reasonable initializers may also work just as well.
......@@ -225,13 +238,14 @@ class Transformer(tf.keras.Model):
decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
max_decode_length, dtype=self.params["dtype"])
# TODO(b/139770046): Refactor code with better naming of i.
def symbols_to_logits_fn(ids, i, cache):
"""Generate logits for next potential IDs.
Args:
ids: Current decoded sequences. int tensor with shape [batch_size *
beam_size, i + 1]
i: Loop index
beam_size, i + 1].
i: Loop index.
cache: dictionary of values storing the encoder output, encoder-decoder
attention bias, and previous decoder attention values.
......@@ -245,16 +259,29 @@ class Transformer(tf.keras.Model):
# Preprocess decoder input by getting embeddings and adding timing signal.
decoder_input = self.embedding_softmax_layer(decoder_input)
decoder_input += timing_signal[i:i + 1]
self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
if self.params["padded_decode"]:
timing_signal_shape = timing_signal.shape.as_list()
decoder_input += tf.slice(timing_signal, [i, 0],
[1, timing_signal_shape[1]])
bias_shape = decoder_self_attention_bias.shape.as_list()
self_attention_bias = tf.slice(
decoder_self_attention_bias, [0, 0, i, 0],
[bias_shape[0], bias_shape[1], 1, bias_shape[3]])
else:
decoder_input += timing_signal[i:i + 1]
self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
decoder_outputs = self.decoder_stack(
decoder_input,
cache.get("encoder_outputs"),
self_attention_bias,
cache.get("encoder_decoder_attention_bias"),
training=training,
cache=cache)
cache=cache,
decode_loop_step=i if self.params["padded_decode"] else None)
logits = self.embedding_softmax_layer(decoder_outputs, mode="linear")
logits = tf.squeeze(logits, axis=[1])
return logits, cache
......@@ -263,8 +290,12 @@ class Transformer(tf.keras.Model):
def predict(self, encoder_outputs, encoder_decoder_attention_bias, training):
"""Return predicted sequence."""
batch_size = tf.shape(encoder_outputs)[0]
input_length = tf.shape(encoder_outputs)[1]
if self.params["padded_decode"]:
batch_size = encoder_outputs.shape.as_list()[0]
input_length = encoder_outputs.shape.as_list()[1]
else:
batch_size = tf.shape(encoder_outputs)[0]
input_length = tf.shape(encoder_outputs)[1]
max_decode_length = input_length + self.params["extra_decode_length"]
encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
self.params["dtype"])
......@@ -277,12 +308,20 @@ class Transformer(tf.keras.Model):
# Create cache storing decoder attention values for each layer.
# pylint: disable=g-complex-comprehension
init_decode_length = (
max_decode_length if self.params["padded_decode"] else 0)
cache = {
"layer_%d" % layer: {
"k": tf.zeros([batch_size, 0, self.params["hidden_size"]],
dtype=self.params["dtype"]),
"v": tf.zeros([batch_size, 0, self.params["hidden_size"]],
dtype=self.params["dtype"])
"k":
tf.zeros([
batch_size, init_decode_length, self.params["hidden_size"]
],
dtype=self.params["dtype"]),
"v":
tf.zeros([
batch_size, init_decode_length, self.params["hidden_size"]
],
dtype=self.params["dtype"])
} for layer in range(self.params["num_hidden_layers"])
}
# pylint: enable=g-complex-comprehension
......@@ -301,6 +340,7 @@ class Transformer(tf.keras.Model):
alpha=self.params["alpha"],
max_decode_length=max_decode_length,
eos_id=EOS_ID,
padded_decode=self.params["padded_decode"],
dtype=self.params["dtype"])
# Get the top sequence for each batch element
......@@ -505,22 +545,28 @@ class DecoderStack(tf.keras.layers.Layer):
decoder_self_attention_bias,
attention_bias,
training,
cache=None):
cache=None,
decode_loop_step=None):
"""Return the output of the decoder layer stacks.
Args:
decoder_inputs: tensor with shape [batch_size, target_length, hidden_size]
encoder_outputs: tensor with shape [batch_size, input_length, hidden_size]
decoder_self_attention_bias: bias for decoder self-attention layer. [1, 1,
target_len, target_length]
attention_bias: bias for encoder-decoder attention layer. [batch_size, 1,
1, input_length]
training: boolean, whether in training mode or not.
decoder_inputs: A tensor with shape
[batch_size, target_length, hidden_size].
encoder_outputs: A tensor with shape
[batch_size, input_length, hidden_size]
decoder_self_attention_bias: A tensor with shape
[1, 1, target_len, target_length], the bias for decoder self-attention
layer.
attention_bias: A tensor with shape [batch_size, 1, 1, input_length],
the bias for encoder-decoder attention layer.
training: A bool, whether in training mode or not.
cache: (Used for fast decoding) A nested dictionary storing previous
decoder self-attention values. The items are:
{layer_n: {"k": tensor with shape [batch_size, i, key_channels],
"v": tensor with shape [batch_size, i, value_channels]},
{layer_n: {"k": A tensor with shape [batch_size, i, key_channels],
"v": A tensor with shape [batch_size, i, value_channels]},
...}
decode_loop_step: An integer, the step number of the decoding loop. Used
only for autoregressive inference on TPU.
Returns:
Output of decoder layer stack.
......@@ -540,7 +586,8 @@ class DecoderStack(tf.keras.layers.Layer):
decoder_inputs,
decoder_self_attention_bias,
training=training,
cache=layer_cache)
cache=layer_cache,
decode_loop_step=decode_loop_step)
with tf.name_scope("encdec_attention"):
decoder_inputs = enc_dec_attention_layer(
decoder_inputs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment