Commit 640ff472 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 264853703
parent 4a1354fe
...@@ -79,8 +79,41 @@ class _StateKeys(object): ...@@ -79,8 +79,41 @@ class _StateKeys(object):
class SequenceBeamSearch(object): class SequenceBeamSearch(object):
"""Implementation of beam search loop.""" """Implementation of beam search loop."""
def __init__(self, symbols_to_logits_fn, vocab_size, batch_size, def __init__(self,
beam_size, alpha, max_decode_length, eos_id, dtype=tf.float32): symbols_to_logits_fn,
vocab_size,
batch_size,
beam_size,
alpha,
max_decode_length,
eos_id,
padded_decode,
dtype=tf.float32):
"""Initialize sequence beam search.
Args:
symbols_to_logits_fn: A function to provide logits, which is the
interface to the Transformer model. The passed in arguments are:
ids -> A tensor with shape [batch_size * beam_size, index].
index -> A scalar.
cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
The function must return a tuple of logits and the updated cache:
logits -> A tensor with shape [batch * beam_size, vocab_size].
updated cache -> A nested dictionary with the same structure as the
input cache.
vocab_size: An integer, the size of the vocabulary, used for topk
computation.
batch_size: An integer, the decode batch size.
beam_size: An integer, number of beams for beam search.
alpha: A float, defining the strength of length normalization.
max_decode_length: An integer, the maximum number of steps to decode
a sequence.
eos_id: An integer. ID of end of sentence token.
padded_decode: A bool, indicating if max_sequence_length padding is used
for beam search.
dtype: A tensorflow data type used for score computation. The default is
tf.float32.
"""
self.symbols_to_logits_fn = symbols_to_logits_fn self.symbols_to_logits_fn = symbols_to_logits_fn
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.batch_size = batch_size self.batch_size = batch_size
...@@ -88,6 +121,7 @@ class SequenceBeamSearch(object): ...@@ -88,6 +121,7 @@ class SequenceBeamSearch(object):
self.alpha = alpha self.alpha = alpha
self.max_decode_length = max_decode_length self.max_decode_length = max_decode_length
self.eos_id = eos_id self.eos_id = eos_id
self.padded_decode = padded_decode
self.dtype = tf.as_dtype(dtype) self.dtype = tf.as_dtype(dtype)
def search(self, initial_ids, initial_cache): def search(self, initial_ids, initial_cache):
...@@ -140,6 +174,8 @@ class SequenceBeamSearch(object): ...@@ -140,6 +174,8 @@ class SequenceBeamSearch(object):
# Create alive sequence with shape [batch_size, beam_size, 1] # Create alive sequence with shape [batch_size, beam_size, 1]
alive_seq = _expand_to_beam_size(initial_ids, self.beam_size) alive_seq = _expand_to_beam_size(initial_ids, self.beam_size)
alive_seq = tf.expand_dims(alive_seq, axis=2) alive_seq = tf.expand_dims(alive_seq, axis=2)
if self.padded_decode:
alive_seq = tf.tile(alive_seq, [1, 1, self.max_decode_length + 1])
# Create tensor for storing initial log probabilities. # Create tensor for storing initial log probabilities.
# Assume initial_ids are prob 1.0 # Assume initial_ids are prob 1.0
...@@ -178,16 +214,44 @@ class SequenceBeamSearch(object): ...@@ -178,16 +214,44 @@ class SequenceBeamSearch(object):
# 1) the dimension's value is a tensor that remains the same but may # 1) the dimension's value is a tensor that remains the same but may
# depend on the input sequence to the model (e.g. batch size). # depend on the input sequence to the model (e.g. batch size).
# 2) the dimension may have different values on different iterations. # 2) the dimension may have different values on different iterations.
state_shape_invariants = { if self.padded_decode:
_StateKeys.CUR_INDEX: tf.TensorShape([]), state_shape_invariants = {
_StateKeys.ALIVE_SEQ: tf.TensorShape([None, self.beam_size, None]), _StateKeys.CUR_INDEX:
_StateKeys.ALIVE_LOG_PROBS: tf.TensorShape([None, self.beam_size]), tf.TensorShape([]),
_StateKeys.ALIVE_CACHE: nest.map_structure( _StateKeys.ALIVE_SEQ:
_get_shape_keep_last_dim, alive_cache), tf.TensorShape(
_StateKeys.FINISHED_SEQ: tf.TensorShape([None, self.beam_size, None]), [self.batch_size, self.beam_size,
_StateKeys.FINISHED_SCORES: tf.TensorShape([None, self.beam_size]), self.max_decode_length + 1]),
_StateKeys.FINISHED_FLAGS: tf.TensorShape([None, self.beam_size]) _StateKeys.ALIVE_LOG_PROBS:
} tf.TensorShape([self.batch_size, self.beam_size]),
_StateKeys.ALIVE_CACHE:
nest.map_structure(_get_shape, alive_cache),
_StateKeys.FINISHED_SEQ:
tf.TensorShape(
[self.batch_size, self.beam_size,
self.max_decode_length + 1]),
_StateKeys.FINISHED_SCORES:
tf.TensorShape([self.batch_size, self.beam_size]),
_StateKeys.FINISHED_FLAGS:
tf.TensorShape([self.batch_size, self.beam_size])
}
else:
state_shape_invariants = {
_StateKeys.CUR_INDEX:
tf.TensorShape([]),
_StateKeys.ALIVE_SEQ:
tf.TensorShape([None, self.beam_size, None]),
_StateKeys.ALIVE_LOG_PROBS:
tf.TensorShape([None, self.beam_size]),
_StateKeys.ALIVE_CACHE:
nest.map_structure(_get_shape_keep_last_dim, alive_cache),
_StateKeys.FINISHED_SEQ:
tf.TensorShape([None, self.beam_size, None]),
_StateKeys.FINISHED_SCORES:
tf.TensorShape([None, self.beam_size]),
_StateKeys.FINISHED_FLAGS:
tf.TensorShape([None, self.beam_size])
}
return state, state_shape_invariants return state, state_shape_invariants
...@@ -297,7 +361,12 @@ class SequenceBeamSearch(object): ...@@ -297,7 +361,12 @@ class SequenceBeamSearch(object):
# Get logits for the next candidate IDs for the alive sequences. Get the new # Get logits for the next candidate IDs for the alive sequences. Get the new
# cache values at the same time. # cache values at the same time.
flat_ids = _flatten_beam_dim(alive_seq) # [batch_size * beam_size] if self.padded_decode:
flat_ids = tf.reshape(
tf.slice(alive_seq, [0, 0, i], [self.batch_size, self.beam_size, 1]),
[self.batch_size * self.beam_size, -1])
else:
flat_ids = _flatten_beam_dim(alive_seq) # [batch_size * beam_size]
flat_cache = nest.map_structure(_flatten_beam_dim, alive_cache) flat_cache = nest.map_structure(_flatten_beam_dim, alive_cache)
flat_logits, flat_cache = self.symbols_to_logits_fn(flat_ids, i, flat_cache) flat_logits, flat_cache = self.symbols_to_logits_fn(flat_ids, i, flat_cache)
...@@ -331,8 +400,13 @@ class SequenceBeamSearch(object): ...@@ -331,8 +400,13 @@ class SequenceBeamSearch(object):
# Append the most probable IDs to the topk sequences # Append the most probable IDs to the topk sequences
topk_ids = topk_indices % self.vocab_size topk_ids = topk_indices % self.vocab_size
topk_ids = tf.expand_dims(topk_ids, axis=2) if self.padded_decode:
topk_seq = tf.concat([topk_seq, topk_ids], axis=2) topk_seq = tf.transpose(topk_seq, perm=[2, 0, 1])
topk_seq = tf.tensor_scatter_update(topk_seq, [i + 1], topk_ids)
topk_seq = tf.transpose(topk_seq, perm=[1, 2, 0])
else:
topk_ids = tf.expand_dims(topk_ids, axis=2)
topk_seq = tf.concat([topk_seq, topk_ids], axis=2)
return topk_seq, topk_log_probs, new_cache return topk_seq, topk_log_probs, new_cache
def _get_new_alive_state(self, new_seq, new_log_probs, new_cache): def _get_new_alive_state(self, new_seq, new_log_probs, new_cache):
...@@ -388,9 +462,12 @@ class SequenceBeamSearch(object): ...@@ -388,9 +462,12 @@ class SequenceBeamSearch(object):
# First append a column of 0-ids to finished_seq to increment the length. # First append a column of 0-ids to finished_seq to increment the length.
# New shape of finished_seq: [batch_size, beam_size, i + 1] # New shape of finished_seq: [batch_size, beam_size, i + 1]
finished_seq = tf.concat( if not self.padded_decode:
[finished_seq, finished_seq = tf.concat([
tf.zeros([self.batch_size, self.beam_size, 1], tf.int32)], axis=2) finished_seq,
tf.zeros([self.batch_size, self.beam_size, 1], tf.int32)
],
axis=2)
# Calculate new seq scores from log probabilities. # Calculate new seq scores from log probabilities.
length_norm = _length_normalization(self.alpha, i + 1, dtype=self.dtype) length_norm = _length_normalization(self.alpha, i + 1, dtype=self.dtype)
...@@ -420,34 +497,43 @@ class SequenceBeamSearch(object): ...@@ -420,34 +497,43 @@ class SequenceBeamSearch(object):
def sequence_beam_search( def sequence_beam_search(
symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size, symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size,
alpha, max_decode_length, eos_id): alpha, max_decode_length, eos_id, padded_decode=False):
"""Search for sequence of subtoken ids with the largest probability. """Search for sequence of subtoken ids with the largest probability.
Args: Args:
symbols_to_logits_fn: A function that takes in ids, index, and cache as symbols_to_logits_fn: A function that takes in ids, index, and cache as
arguments. The passed in arguments will have shape: arguments. The passed in arguments will have shape:
ids -> [batch_size * beam_size, index] ids -> A tensor with shape [batch_size * beam_size, index].
index -> [] (scalar) index -> A scalar.
cache -> nested dictionary of tensors [batch_size * beam_size, ...] cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
The function must return logits and new cache. The function must return a tuple of logits and new cache:
logits -> [batch * beam_size, vocab_size] logits -> A tensor with shape [batch * beam_size, vocab_size].
new cache -> same shape/structure as inputted cache new cache -> A nested dictionary with the same shape/structure as the
initial_ids: Starting ids for each batch item. inputted cache.
int32 tensor with shape [batch_size] initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
initial_cache: dict containing starting decoder variables information each batch item.
vocab_size: int size of tokens initial_cache: A dictionary, containing starting decoder variables
beam_size: int number of beams information.
alpha: float defining the strength of length normalization vocab_size: An integer, the size of the vocabulary, used for topk
max_decode_length: maximum length to decoded sequence computation.
eos_id: int id of eos token, used to determine when a sequence has finished beam_size: An integer, the number of beams.
alpha: A float, defining the strength of length normalization.
max_decode_length: An integer, the maximum length to decode a sequence.
eos_id: An integer, ID of eos token, used to determine when a sequence has
finished.
padded_decode: A bool, indicating if max_sequence_length padding is used
for beam search.
Returns: Returns:
Top decoded sequences [batch_size, beam_size, max_decode_length] Top decoded sequences [batch_size, beam_size, max_decode_length]
sequence scores [batch_size, beam_size] sequence scores [batch_size, beam_size]
""" """
batch_size = tf.shape(initial_ids)[0] batch_size = (
initial_ids.shape.as_list()[0] if padded_decode else
tf.shape(initial_ids)[0])
sbs = SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size, sbs = SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
beam_size, alpha, max_decode_length, eos_id) beam_size, alpha, max_decode_length, eos_id,
padded_decode)
return sbs.search(initial_ids, initial_cache) return sbs.search(initial_ids, initial_cache)
...@@ -502,6 +588,11 @@ def _get_shape_keep_last_dim(tensor): ...@@ -502,6 +588,11 @@ def _get_shape_keep_last_dim(tensor):
return tf.TensorShape(shape_list) return tf.TensorShape(shape_list)
def _get_shape(tensor):
"""Return the shape of the input tensor."""
return tf.TensorShape(_shape_list(tensor))
def _flatten_beam_dim(tensor): def _flatten_beam_dim(tensor):
"""Reshapes first two dimensions in to single dimension. """Reshapes first two dimensions in to single dimension.
......
...@@ -102,51 +102,67 @@ class Attention(tf.keras.layers.Layer): ...@@ -102,51 +102,67 @@ class Attention(tf.keras.layers.Layer):
x = tf.transpose(x, [0, 2, 1, 3]) # --> [batch, length, num_heads, depth] x = tf.transpose(x, [0, 2, 1, 3]) # --> [batch, length, num_heads, depth]
return tf.reshape(x, [batch_size, length, self.hidden_size]) return tf.reshape(x, [batch_size, length, self.hidden_size])
def call(self, x, y, bias, training, cache=None): def call(self, x, y, bias, training, cache=None, decode_loop_step=None):
"""Apply attention mechanism to x and y. """Apply attention mechanism to x and y.
Args: Args:
x: a tensor with shape [batch_size, length_x, hidden_size] x: A tensor with shape [batch_size, length_x, hidden_size].
y: a tensor with shape [batch_size, length_y, hidden_size] y: A tensor with shape [batch_size, length_y, hidden_size].
bias: attention bias that will be added to the result of the dot product. bias: A tensor, the attention bias that will be added to the result of the
training: boolean, whether in training mode or not. dot product.
cache: (Used during prediction) dictionary with tensors containing results training: A bool, whether in training mode or not.
of previous attentions. The dictionary must have the items: cache: (Used during prediction) A dictionary with tensors containing
results of previous attentions. The dictionary must have the items:
{"k": tensor with shape [batch_size, i, key_channels], {"k": tensor with shape [batch_size, i, key_channels],
"v": tensor with shape [batch_size, i, value_channels]} "v": tensor with shape [batch_size, i, value_channels]}
where i is the current decoded length. where i is the current decoded length.
decode_loop_step: An integer, step number of the decoding loop. Used only
for autoregressive inference on TPU.
Returns: Returns:
Attention layer output with shape [batch_size, length_x, hidden_size] Attention layer output with shape [batch_size, length_x, hidden_size]
""" """
# Linearly project the query (q), key (k) and value (v) using different # Linearly project the query, key and value using different learned
# learned projections. This is in preparation of splitting them into # projections. This is in preparation of splitting them into multiple
# multiple heads. Multi-head attention uses multiple queries, keys, and # heads. Multi-head attention uses multiple queries, keys, and values
# values rather than regular attention (which uses a single q, k, v). # rather than regular attention (which uses a single query, key, value).
q = self.q_dense_layer(x) query = self.q_dense_layer(x)
k = self.k_dense_layer(y) key = self.k_dense_layer(y)
v = self.v_dense_layer(y) value = self.v_dense_layer(y)
if cache is not None: if cache is not None:
# Combine cached keys and values with new keys and values. # Combine cached keys and values with new keys and values.
k = tf.concat([tf.cast(cache["k"], k.dtype), k], axis=1) if decode_loop_step is not None:
v = tf.concat([tf.cast(cache["v"], k.dtype), v], axis=1) cache_k_shape = cache["k"].shape.as_list()
indices = tf.reshape(
tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=key.dtype),
[1, cache_k_shape[1], 1])
key = cache["k"] + key * indices
cache_v_shape = cache["v"].shape.as_list()
indices = tf.reshape(
tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=value.dtype),
[1, cache_v_shape[1], 1])
value = cache["v"] + value * indices
else:
key = tf.concat([tf.cast(cache["k"], key.dtype), key], axis=1)
value = tf.concat([tf.cast(cache["v"], value.dtype), value], axis=1)
# Update cache # Update cache
cache["k"] = k cache["k"] = key
cache["v"] = v cache["v"] = value
# Split q, k, v into heads. # Split query, key, value into heads.
q = self.split_heads(q) query = self.split_heads(query)
k = self.split_heads(k) key = self.split_heads(key)
v = self.split_heads(v) value = self.split_heads(value)
# Scale q to prevent the dot product between q and k from growing too large. # Scale query to prevent the dot product between query and key from growing
# too large.
depth = (self.hidden_size // self.num_heads) depth = (self.hidden_size // self.num_heads)
q *= depth ** -0.5 query *= depth ** -0.5
# Calculate dot product attention # Calculate dot product attention
logits = tf.matmul(q, k, transpose_b=True) logits = tf.matmul(query, key, transpose_b=True)
logits += bias logits += bias
# Note that softmax internally performs math operations using float32 # Note that softmax internally performs math operations using float32
# for numeric stability. When training with float16, we keep the input # for numeric stability. When training with float16, we keep the input
...@@ -154,7 +170,7 @@ class Attention(tf.keras.layers.Layer): ...@@ -154,7 +170,7 @@ class Attention(tf.keras.layers.Layer):
weights = tf.nn.softmax(logits, name="attention_weights") weights = tf.nn.softmax(logits, name="attention_weights")
if training: if training:
weights = tf.nn.dropout(weights, rate=self.attention_dropout) weights = tf.nn.dropout(weights, rate=self.attention_dropout)
attention_output = tf.matmul(weights, v) attention_output = tf.matmul(weights, value)
# Recombine heads --> [batch_size, length, hidden_size] # Recombine heads --> [batch_size, length, hidden_size]
attention_output = self.combine_heads(attention_output) attention_output = self.combine_heads(attention_output)
...@@ -167,5 +183,6 @@ class Attention(tf.keras.layers.Layer): ...@@ -167,5 +183,6 @@ class Attention(tf.keras.layers.Layer):
class SelfAttention(Attention): class SelfAttention(Attention):
"""Multiheaded self-attention layer.""" """Multiheaded self-attention layer."""
def call(self, x, bias, training, cache=None): def call(self, x, bias, training, cache=None, decode_loop_step=None):
return super(SelfAttention, self).call(x, x, bias, training, cache) return super(SelfAttention, self).call(x, x, bias, training, cache,
decode_loop_step)
...@@ -55,43 +55,58 @@ class SequenceBeamSearchV2(v1.SequenceBeamSearch): ...@@ -55,43 +55,58 @@ class SequenceBeamSearchV2(v1.SequenceBeamSearch):
return finished_seq, finished_scores return finished_seq, finished_scores
def sequence_beam_search( def sequence_beam_search(symbols_to_logits_fn,
symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size, initial_ids,
alpha, max_decode_length, eos_id, dtype="float32"): initial_cache,
vocab_size,
beam_size,
alpha,
max_decode_length,
eos_id,
padded_decode,
dtype="float32"):
"""Search for sequence of subtoken ids with the largest probability. """Search for sequence of subtoken ids with the largest probability.
Args: Args:
symbols_to_logits_fn: A function that takes in ids, index, and cache as symbols_to_logits_fn: A function that takes in ids, index, and cache as
arguments. The passed in arguments will have shape: arguments. The passed in arguments will have shape:
ids -> [batch_size * beam_size, index] ids -> A tensor with shape [batch_size * beam_size, index].
index -> [] (scalar) index -> A scalar.
cache -> nested dictionary of tensors [batch_size * beam_size, ...] cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
The function must return logits and new cache. The function must return a tuple of logits and new cache:
logits -> [batch * beam_size, vocab_size] logits -> A tensor with shape [batch * beam_size, vocab_size].
new cache -> same shape/structure as inputted cache new cache -> A nested dictionary with the same shape/structure as the
initial_ids: Starting ids for each batch item. inputted cache.
int32 tensor with shape [batch_size] initial_ids: An int32 tensor with shape [batch_size]. Starting ids for
initial_cache: dict containing starting decoder variables information each batch item.
vocab_size: int size of tokens initial_cache: A dictionary, containing starting decoder variables
beam_size: int number of beams information.
alpha: float defining the strength of length normalization vocab_size: An integer, the size of tokens.
max_decode_length: maximum length to decoded sequence beam_size: An integer, the number of beams.
eos_id: int id of eos token, used to determine when a sequence has finished, alpha: A float, defining the strength of length normalization.
dtype: The dtype to use. max_decode_length: An integer, the maximum length to decode a sequence.
eos_id: An integer, ID of eos token, used to determine when a sequence has
finished.
padded_decode: A bool, indicating if max_sequence_length padding is used
for beam search.
dtype: A tensorflow data type used for score computation. The default is
tf.float32.
Returns: Returns:
Top decoded sequences [batch_size, beam_size, max_decode_length] Top decoded sequences [batch_size, beam_size, max_decode_length]
sequence scores [batch_size, beam_size] sequence scores [batch_size, beam_size]
""" """
batch_size = tf.shape(initial_ids)[0] batch_size = (
initial_ids.shape.as_list()[0] if padded_decode else
tf.shape(initial_ids)[0])
if misc.is_v2(): if misc.is_v2():
sbs = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size, batch_size, sbs = SequenceBeamSearchV2(symbols_to_logits_fn, vocab_size, batch_size,
beam_size, alpha, max_decode_length, eos_id, beam_size, alpha, max_decode_length, eos_id,
dtype) padded_decode, dtype)
else: else:
sbs = v1.SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size, sbs = v1.SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
beam_size, alpha, max_decode_length, eos_id, beam_size, alpha, max_decode_length, eos_id,
dtype) padded_decode, dtype)
return sbs.search(initial_ids, initial_cache) return sbs.search(initial_ids, initial_cache)
......
...@@ -191,6 +191,29 @@ def define_transformer_flags(): ...@@ -191,6 +191,29 @@ def define_transformer_flags():
help=flags_core.help_wrap( help=flags_core.help_wrap(
'Whether the model runs in 2VM mode, Headless server and unit test ' 'Whether the model runs in 2VM mode, Headless server and unit test '
'all use 1VM config.')) 'all use 1VM config.'))
flags.DEFINE_integer(
name='decode_batch_size',
default=32,
help=flags_core.help_wrap(
'Global batch size used for Transformer autoregressive decoding on '
'TPU.'))
flags.DEFINE_integer(
name='decode_max_length',
default=97,
help=flags_core.help_wrap(
'Max sequence length of the decode/eval data. This is used by '
'Transformer autoregressive decoding on TPU to have minimum '
'paddings.'))
flags.DEFINE_bool(
name='padded_decode',
default=False,
help=flags_core.help_wrap(
'Whether the autoregressive decoding runs with input data padded to '
'the decode_max_length. For TPU/XLA-GPU runs, this flag has to be '
'set due the static shape requirement. Although CPU/GPU could also '
'use padded_decode, it has not been tested. In addition, this method '
'will introduce unnecessary overheads which grow quadratically with '
'the max sequence length.'))
flags_core.set_defaults(data_dir='/tmp/translate_ende', flags_core.set_defaults(data_dir='/tmp/translate_ende',
model_dir='/tmp/transformer_model', model_dir='/tmp/transformer_model',
......
...@@ -112,11 +112,22 @@ class Transformer(tf.keras.Model): ...@@ -112,11 +112,22 @@ class Transformer(tf.keras.Model):
outputs: [batch_size, decoded length] outputs: [batch_size, decoded length]
scores: [batch_size, float]} scores: [batch_size, float]}
Even when float16 is used, the output tensor(s) are always float32. Even when float16 is used, the output tensor(s) are always float32.
Raises:
NotImplementedError: If try to use padded decode method on CPU/GPUs.
""" """
if len(inputs) == 2: if len(inputs) == 2:
inputs, targets = inputs[0], inputs[1] inputs, targets = inputs[0], inputs[1]
else: else:
inputs, targets = inputs[0], None inputs, targets = inputs[0], None
if self.params["padded_decode"]:
if not self.params["num_replicas"]:
raise NotImplementedError(
"Padded decoding on CPU/GPUs is not supported.")
decode_batch_size = int(self.params["decode_batch_size"] /
self.params["num_replicas"])
inputs = tf.reshape(
inputs, [decode_batch_size, self.params["decode_max_length"]])
# Variance scaling is used here because it seems to work in many problems. # Variance scaling is used here because it seems to work in many problems.
# Other reasonable initializers may also work just as well. # Other reasonable initializers may also work just as well.
...@@ -225,13 +236,14 @@ class Transformer(tf.keras.Model): ...@@ -225,13 +236,14 @@ class Transformer(tf.keras.Model):
decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias( decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
max_decode_length, dtype=self.params["dtype"]) max_decode_length, dtype=self.params["dtype"])
# TODO(b/139770046): Refactor code with better naming of i.
def symbols_to_logits_fn(ids, i, cache): def symbols_to_logits_fn(ids, i, cache):
"""Generate logits for next potential IDs. """Generate logits for next potential IDs.
Args: Args:
ids: Current decoded sequences. int tensor with shape [batch_size * ids: Current decoded sequences. int tensor with shape [batch_size *
beam_size, i + 1] beam_size, i + 1].
i: Loop index i: Loop index.
cache: dictionary of values storing the encoder output, encoder-decoder cache: dictionary of values storing the encoder output, encoder-decoder
attention bias, and previous decoder attention values. attention bias, and previous decoder attention values.
...@@ -245,16 +257,29 @@ class Transformer(tf.keras.Model): ...@@ -245,16 +257,29 @@ class Transformer(tf.keras.Model):
# Preprocess decoder input by getting embeddings and adding timing signal. # Preprocess decoder input by getting embeddings and adding timing signal.
decoder_input = self.embedding_softmax_layer(decoder_input) decoder_input = self.embedding_softmax_layer(decoder_input)
decoder_input += timing_signal[i:i + 1]
self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1] if self.params["padded_decode"]:
timing_signal_shape = timing_signal.shape.as_list()
decoder_input += tf.slice(timing_signal, [i, 0],
[1, timing_signal_shape[1]])
bias_shape = decoder_self_attention_bias.shape.as_list()
self_attention_bias = tf.slice(
decoder_self_attention_bias, [0, 0, i, 0],
[bias_shape[0], bias_shape[1], 1, bias_shape[3]])
else:
decoder_input += timing_signal[i:i + 1]
self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
decoder_outputs = self.decoder_stack( decoder_outputs = self.decoder_stack(
decoder_input, decoder_input,
cache.get("encoder_outputs"), cache.get("encoder_outputs"),
self_attention_bias, self_attention_bias,
cache.get("encoder_decoder_attention_bias"), cache.get("encoder_decoder_attention_bias"),
training=training, training=training,
cache=cache) cache=cache,
decode_loop_step=i if self.params["padded_decode"] else None)
logits = self.embedding_softmax_layer(decoder_outputs, mode="linear") logits = self.embedding_softmax_layer(decoder_outputs, mode="linear")
logits = tf.squeeze(logits, axis=[1]) logits = tf.squeeze(logits, axis=[1])
return logits, cache return logits, cache
...@@ -263,8 +288,12 @@ class Transformer(tf.keras.Model): ...@@ -263,8 +288,12 @@ class Transformer(tf.keras.Model):
def predict(self, encoder_outputs, encoder_decoder_attention_bias, training): def predict(self, encoder_outputs, encoder_decoder_attention_bias, training):
"""Return predicted sequence.""" """Return predicted sequence."""
batch_size = tf.shape(encoder_outputs)[0] if self.params["padded_decode"]:
input_length = tf.shape(encoder_outputs)[1] batch_size = encoder_outputs.shape.as_list()[0]
input_length = encoder_outputs.shape.as_list()[1]
else:
batch_size = tf.shape(encoder_outputs)[0]
input_length = tf.shape(encoder_outputs)[1]
max_decode_length = input_length + self.params["extra_decode_length"] max_decode_length = input_length + self.params["extra_decode_length"]
encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias, encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
self.params["dtype"]) self.params["dtype"])
...@@ -277,12 +306,20 @@ class Transformer(tf.keras.Model): ...@@ -277,12 +306,20 @@ class Transformer(tf.keras.Model):
# Create cache storing decoder attention values for each layer. # Create cache storing decoder attention values for each layer.
# pylint: disable=g-complex-comprehension # pylint: disable=g-complex-comprehension
init_decode_length = (
max_decode_length if self.params["padded_decode"] else 0)
cache = { cache = {
"layer_%d" % layer: { "layer_%d" % layer: {
"k": tf.zeros([batch_size, 0, self.params["hidden_size"]], "k":
dtype=self.params["dtype"]), tf.zeros([
"v": tf.zeros([batch_size, 0, self.params["hidden_size"]], batch_size, init_decode_length, self.params["hidden_size"]
dtype=self.params["dtype"]) ],
dtype=self.params["dtype"]),
"v":
tf.zeros([
batch_size, init_decode_length, self.params["hidden_size"]
],
dtype=self.params["dtype"])
} for layer in range(self.params["num_hidden_layers"]) } for layer in range(self.params["num_hidden_layers"])
} }
# pylint: enable=g-complex-comprehension # pylint: enable=g-complex-comprehension
...@@ -301,6 +338,7 @@ class Transformer(tf.keras.Model): ...@@ -301,6 +338,7 @@ class Transformer(tf.keras.Model):
alpha=self.params["alpha"], alpha=self.params["alpha"],
max_decode_length=max_decode_length, max_decode_length=max_decode_length,
eos_id=EOS_ID, eos_id=EOS_ID,
padded_decode=self.params["padded_decode"],
dtype=self.params["dtype"]) dtype=self.params["dtype"])
# Get the top sequence for each batch element # Get the top sequence for each batch element
...@@ -505,22 +543,28 @@ class DecoderStack(tf.keras.layers.Layer): ...@@ -505,22 +543,28 @@ class DecoderStack(tf.keras.layers.Layer):
decoder_self_attention_bias, decoder_self_attention_bias,
attention_bias, attention_bias,
training, training,
cache=None): cache=None,
decode_loop_step=None):
"""Return the output of the decoder layer stacks. """Return the output of the decoder layer stacks.
Args: Args:
decoder_inputs: tensor with shape [batch_size, target_length, hidden_size] decoder_inputs: A tensor with shape
encoder_outputs: tensor with shape [batch_size, input_length, hidden_size] [batch_size, target_length, hidden_size].
decoder_self_attention_bias: bias for decoder self-attention layer. [1, 1, encoder_outputs: A tensor with shape
target_len, target_length] [batch_size, input_length, hidden_size]
attention_bias: bias for encoder-decoder attention layer. [batch_size, 1, decoder_self_attention_bias: A tensor with shape
1, input_length] [1, 1, target_len, target_length], the bias for decoder self-attention
training: boolean, whether in training mode or not. layer.
attention_bias: A tensor with shape [batch_size, 1, 1, input_length],
the bias for encoder-decoder attention layer.
training: A bool, whether in training mode or not.
cache: (Used for fast decoding) A nested dictionary storing previous cache: (Used for fast decoding) A nested dictionary storing previous
decoder self-attention values. The items are: decoder self-attention values. The items are:
{layer_n: {"k": tensor with shape [batch_size, i, key_channels], {layer_n: {"k": A tensor with shape [batch_size, i, key_channels],
"v": tensor with shape [batch_size, i, value_channels]}, "v": A tensor with shape [batch_size, i, value_channels]},
...} ...}
decode_loop_step: An integer, the step number of the decoding loop. Used
only for autoregressive inference on TPU.
Returns: Returns:
Output of decoder layer stack. Output of decoder layer stack.
...@@ -540,7 +584,8 @@ class DecoderStack(tf.keras.layers.Layer): ...@@ -540,7 +584,8 @@ class DecoderStack(tf.keras.layers.Layer):
decoder_inputs, decoder_inputs,
decoder_self_attention_bias, decoder_self_attention_bias,
training=training, training=training,
cache=layer_cache) cache=layer_cache,
decode_loop_step=decode_loop_step)
with tf.name_scope("encdec_attention"): with tf.name_scope("encdec_attention"):
decoder_inputs = enc_dec_attention_layer( decoder_inputs = enc_dec_attention_layer(
decoder_inputs, decoder_inputs,
......
...@@ -52,18 +52,40 @@ BLEU_DIR = "bleu" ...@@ -52,18 +52,40 @@ BLEU_DIR = "bleu"
_SINGLE_SAMPLE = 1 _SINGLE_SAMPLE = 1
def translate_and_compute_bleu(model,
                               params,
                               subtokenizer,
                               bleu_source,
                               bleu_ref,
                               distribution_strategy=None):
  """Translate file and report the cased and uncased bleu scores.

  Args:
    model: A Keras model, used to generate the translations.
    params: A dictionary, containing the translation related parameters.
    subtokenizer: A subtokenizer object, used for encoding and decoding source
      and translated lines.
    bleu_source: A file containing source sentences for translation.
    bleu_ref: A file containing the reference for the translated sentences.
    distribution_strategy: A platform distribution strategy, used for TPU based
      translation.

  Returns:
    uncased_score: A float, the case insensitive BLEU score.
    cased_score: A float, the case sensitive BLEU score.
  """
  # Create temporary file to store translation. delete=False so the file
  # survives the context of NamedTemporaryFile and can be re-read by the
  # BLEU wrapper below; it is removed explicitly once scoring is done.
  tmp = tempfile.NamedTemporaryFile(delete=False)
  tmp_filename = tmp.name

  translate.translate_file(
      model,
      params,
      subtokenizer,
      bleu_source,
      output_file=tmp_filename,
      print_all_translations=False,
      distribution_strategy=distribution_strategy)

  # Compute uncased and cased bleu scores.
  uncased_score = compute_bleu.bleu_wrapper(bleu_ref, tmp_filename, False)
  cased_score = compute_bleu.bleu_wrapper(bleu_ref, tmp_filename, True)
  os.remove(tmp_filename)
  return uncased_score, cased_score
def evaluate_and_log_bleu(model,
                          params,
                          bleu_source,
                          bleu_ref,
                          vocab_file,
                          distribution_strategy=None):
  """Calculate and record the BLEU score.

  Builds a subtokenizer from the vocabulary file, delegates translation and
  scoring to translate_and_compute_bleu, and logs both scores.

  Args:
    model: A Keras model, used to generate the translations.
    params: A dictionary, containing the translation related parameters.
    bleu_source: A file containing source sentences for translation.
    bleu_ref: A file containing the reference for the translated sentences.
    vocab_file: A file containing the vocabulary for translation.
    distribution_strategy: A platform distribution strategy, used for TPU based
      translation.

  Returns:
    uncased_score: A float, the case insensitive BLEU score.
    cased_score: A float, the case sensitive BLEU score.
  """
  subtokenizer = tokenizer.Subtokenizer(vocab_file)

  uncased_score, cased_score = translate_and_compute_bleu(
      model, params, subtokenizer, bleu_source, bleu_ref, distribution_strategy)

  logging.info("Bleu score (uncased): %s", uncased_score)
  logging.info("Bleu score (cased): %s", cased_score)
  return uncased_score, cased_score
...@@ -110,6 +151,9 @@ class TransformerTask(object): ...@@ -110,6 +151,9 @@ class TransformerTask(object):
params["model_dir"] = flags_obj.model_dir params["model_dir"] = flags_obj.model_dir
params["static_batch"] = flags_obj.static_batch params["static_batch"] = flags_obj.static_batch
params["max_length"] = flags_obj.max_length params["max_length"] = flags_obj.max_length
params["decode_batch_size"] = flags_obj.decode_batch_size
params["decode_max_length"] = flags_obj.decode_max_length
params["padded_decode"] = flags_obj.padded_decode
params["num_parallel_calls"] = ( params["num_parallel_calls"] = (
flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE) flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE)
...@@ -133,6 +177,7 @@ class TransformerTask(object): ...@@ -133,6 +177,7 @@ class TransformerTask(object):
num_gpus=num_gpus, num_gpus=num_gpus,
tpu_address=flags_obj.tpu or "") tpu_address=flags_obj.tpu or "")
if self.use_tpu: if self.use_tpu:
params["num_replicas"] = self.distribution_strategy.num_replicas_in_sync
if not params["static_batch"]: if not params["static_batch"]:
raise ValueError("TPU requires static batch for input data.") raise ValueError("TPU requires static batch for input data.")
else: else:
...@@ -306,10 +351,10 @@ class TransformerTask(object): ...@@ -306,10 +351,10 @@ class TransformerTask(object):
self.predict_model, self.predict_model,
tf.train.latest_checkpoint(self.flags_obj.model_dir)) tf.train.latest_checkpoint(self.flags_obj.model_dir))
self.predict_model.summary() self.predict_model.summary()
return evaluate_and_log_bleu(self.predict_model, return evaluate_and_log_bleu(
self.flags_obj.bleu_source, self.predict_model, self.params, self.flags_obj.bleu_source,
self.flags_obj.bleu_ref, self.flags_obj.bleu_ref, self.flags_obj.vocab_file,
self.flags_obj.vocab_file) self.distribution_strategy if self.use_tpu else None)
def predict(self): def predict(self):
"""Predicts result from the model.""" """Predicts result from the model."""
......
...@@ -18,11 +18,12 @@ from __future__ import absolute_import ...@@ -18,11 +18,12 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.python.distribute import values
from official.transformer.utils import tokenizer from official.transformer.utils import tokenizer
_DECODE_BATCH_SIZE = 32
_EXTRA_DECODE_LENGTH = 100 _EXTRA_DECODE_LENGTH = 100
_BEAM_SIZE = 4 _BEAM_SIZE = 4
_ALPHA = 0.6 _ALPHA = 0.6
...@@ -68,23 +69,31 @@ def _trim_and_decode(ids, subtokenizer): ...@@ -68,23 +69,31 @@ def _trim_and_decode(ids, subtokenizer):
return subtokenizer.decode(ids) return subtokenizer.decode(ids)
def translate_file( def translate_file(model,
model, subtokenizer, input_file, output_file=None, params,
print_all_translations=True): subtokenizer,
input_file,
output_file=None,
print_all_translations=True,
distribution_strategy=None):
"""Translate lines in file, and save to output file if specified. """Translate lines in file, and save to output file if specified.
Args: Args:
model: Keras model used to generate the translations. model: A Keras model, used to generate the translations.
subtokenizer: Subtokenizer object for encoding and decoding source and params: A dictionary, containing the translation related parameters.
translated lines. subtokenizer: A subtokenizer object, used for encoding and decoding source
input_file: file containing lines to translate and translated lines.
output_file: file that stores the generated translations. input_file: A file containing lines to translate.
print_all_translations: If true, all translations are printed to stdout. output_file: A file that stores the generated translations.
print_all_translations: A bool. If true, all translations are printed to
stdout.
distribution_strategy: A distribution strategy, used to perform inference
directly with tf.function instead of Keras model.predict().
Raises: Raises:
ValueError: if output file is invalid. ValueError: if output file is invalid.
""" """
batch_size = _DECODE_BATCH_SIZE batch_size = params["decode_batch_size"]
# Read and sort inputs by length. Keep dictionary (original index-->new index # Read and sort inputs by length. Keep dictionary (original index-->new index
# in sorted list) to write translations in the original order. # in sorted list) to write translations in the original order.
...@@ -101,24 +110,59 @@ def translate_file( ...@@ -101,24 +110,59 @@ def translate_file(
if j + i * batch_size < total_samples if j + i * batch_size < total_samples
] ]
lines = [_encode_and_add_eos(l, subtokenizer) for l in lines] lines = [_encode_and_add_eos(l, subtokenizer) for l in lines]
if distribution_strategy:
for j in range(batch_size - len(lines)):
lines.append([tokenizer.EOS_ID])
batch = tf.keras.preprocessing.sequence.pad_sequences( batch = tf.keras.preprocessing.sequence.pad_sequences(
lines, dtype="int64", padding="post") lines,
maxlen=params["decode_max_length"],
dtype="int32",
padding="post")
tf.compat.v1.logging.info("Decoding batch %d out of %d.", i, tf.compat.v1.logging.info("Decoding batch %d out of %d.", i,
num_decode_batches) num_decode_batches)
yield batch yield batch
@tf.function
def predict_step(inputs):
"""Decoding step function for TPU runs."""
def _step_fn(inputs):
"""Per replica step function."""
val_outputs, _ = model([inputs], training=False)
return val_outputs
return distribution_strategy.experimental_run_v2(_step_fn, args=(inputs,))
translations = [] translations = []
if distribution_strategy:
num_replicas = distribution_strategy.num_replicas_in_sync
local_batch_size = params["decode_batch_size"] // num_replicas
for i, text in enumerate(input_generator()): for i, text in enumerate(input_generator()):
val_outputs, _ = model.predict(text) if distribution_strategy:
text = np.reshape(text, [num_replicas, local_batch_size, -1])
text = [
tf.convert_to_tensor(per_replica_text) for per_replica_text in text
]
# pylint: disable=protected-access
text = values.PerReplica(distribution_strategy.extended._device_map, text)
# pylint: enable=protected-access
val_outputs = distribution_strategy.experimental_local_results(
predict_step(text))
val_outputs = np.reshape(
[val_output.numpy() for val_output in val_outputs],
[params["decode_batch_size"], -1])
else:
val_outputs, _ = model.predict(text)
length = len(val_outputs) length = len(val_outputs)
for j in range(length): for j in range(length):
translation = _trim_and_decode(val_outputs[j], subtokenizer) if j + i * batch_size < total_samples:
translations.append(translation) translation = _trim_and_decode(val_outputs[j], subtokenizer)
if print_all_translations: translations.append(translation)
tf.compat.v1.logging.info( if print_all_translations:
"Translating:\n\tInput: %s\n\tOutput: %s" % tf.compat.v1.logging.info(
(sorted_inputs[j + i * batch_size], translation)) "Translating:\n\tInput: %s\n\tOutput: %s" %
(sorted_inputs[j + i * batch_size], translation))
# Write translations in the order they appeared in the original file. # Write translations in the order they appeared in the original file.
if output_file is not None: if output_file is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment