Commit 1180f37e authored by Frederick Liu, committed by A. Unique TensorFlower

[nlp][translation] Remove seq2seq model _dtype argument and break transformer utils dependency.

PiperOrigin-RevId: 363121226
parent f06dc1a6
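Note on the dtype half of the change: the model no longer takes a `dtype` constructor argument; its floating-point behavior now follows the Keras layer's `compute_dtype`, which is driven by the global mixed-precision policy. The sketch below is illustrative only (the toy layer and names are assumptions, not code from this commit) and shows where `self.compute_dtype` comes from.

```python
import tensorflow as tf

# Illustrative only: with a mixed-precision policy set globally, every Keras
# layer/model reports compute_dtype == float16 and variable_dtype == float32,
# so no per-model `dtype` argument is needed.
tf.keras.mixed_precision.set_global_policy("mixed_float16")


class TinyLayer(tf.keras.layers.Layer):
  """Hypothetical stand-in showing the pattern used after this change."""

  def call(self, inputs):
    # Cast activations to the policy-driven compute dtype, mirroring the
    # tf.cast(..., self.compute_dtype) calls added in the diff below.
    return tf.cast(inputs, self.compute_dtype)


layer = TinyLayer()
print(layer.compute_dtype)   # float16
print(layer.variable_dtype)  # float32
```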
@@ -23,10 +23,8 @@ from official.modeling import tf_utils
 from official.nlp import keras_nlp
 from official.nlp.modeling import layers
 from official.nlp.modeling.ops import beam_search
-from official.nlp.transformer import model_utils
 EOS_ID = 1
-# pylint: disable=g-classes-have-attributes
 @tf.keras.utils.register_keras_serializable(package="Text")
@@ -52,7 +50,6 @@ class Seq2SeqTransformer(tf.keras.Model):
                alpha=0.6,
                encoder_layer=None,
                decoder_layer=None,
-               dtype=tf.float32,
                eos_id=EOS_ID,
                **kwargs):
     """Initialize layers to build Transformer model.
@@ -69,7 +66,6 @@ class Seq2SeqTransformer(tf.keras.Model):
       alpha: The strength of length normalization for beam search.
       encoder_layer: An initialized encoder layer.
       decoder_layer: An initialized decoder layer.
-      dtype: float dtype.
       eos_id: Id of end of sentence token.
       **kwargs: other keyword arguments.
     """
@@ -82,7 +78,6 @@ class Seq2SeqTransformer(tf.keras.Model):
     self._extra_decode_length = extra_decode_length
     self._beam_size = beam_size
     self._alpha = alpha
-    self._dtype = dtype
     self._eos_id = eos_id
     self.embedding_lookup = keras_nlp.layers.OnDeviceEmbedding(
         vocab_size=self._vocab_size,
@@ -104,7 +99,6 @@ class Seq2SeqTransformer(tf.keras.Model):
         "dropout_rate": self._dropout_rate,
         "padded_decode": self._padded_decode,
         "decode_max_length": self._decode_max_length,
-        "dtype": self._dtype,
         "eos_id": self._eos_id,
         "extra_decode_length": self._extra_decode_length,
         "beam_size": self._beam_size,
@@ -123,10 +117,7 @@ class Seq2SeqTransformer(tf.keras.Model):
     vocab_size = tf.shape(embedding_matrix)[0]
     x = tf.reshape(x, [-1, hidden_size])
-    logits = tf.matmul(
-        tf.cast(x, dtype=self._dtype),
-        tf.cast(embedding_matrix, self._dtype),
-        transpose_b=True)
+    logits = tf.matmul(x, tf.cast(embedding_matrix, x.dtype), transpose_b=True)
     return tf.reshape(logits, [batch_size, length, vocab_size])
@@ -154,14 +145,10 @@ class Seq2SeqTransformer(tf.keras.Model):
     """
     sources = inputs["inputs"]
     targets = inputs.get("targets", None)
-    attention_bias = model_utils.get_padding_bias(sources)
-    attention_bias = tf.cast(attention_bias, self._dtype)
     # Prepare inputs to the layer stack by adding positional encodings and
     # applying dropout.
     embedded_inputs = self.embedding_lookup(sources)
-    embedding_mask = tf.cast(
-        tf.not_equal(sources, 0), self.embedding_lookup.embeddings.dtype)
-    embedded_inputs = tf.cast(embedded_inputs, self._dtype)
+    embedding_mask = tf.cast(tf.not_equal(sources, 0), embedded_inputs.dtype)
     embedded_inputs *= tf.expand_dims(embedding_mask, -1)
     # Attention_mask generation.
     input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
@@ -173,8 +160,8 @@ class Seq2SeqTransformer(tf.keras.Model):
         shape=[input_shape[0], input_shape[1], 1], dtype=sources.dtype)
     attention_mask = broadcast_ones * attention_mask
-    pos_encoding = self.position_embedding(inputs=embedded_inputs)
-    pos_encoding = tf.cast(pos_encoding, self._dtype)
+    pos_encoding = self.position_embedding(embedded_inputs)
+    pos_encoding = tf.cast(pos_encoding, embedded_inputs.dtype)
     encoder_inputs = embedded_inputs + pos_encoding
     encoder_inputs = self.encoder_dropout(encoder_inputs)
@@ -183,15 +170,11 @@ class Seq2SeqTransformer(tf.keras.Model):
         encoder_inputs, attention_mask=attention_mask)
     if targets is None:
-      encoder_decoder_attention_bias = attention_bias
-      encoder_outputs = tf.cast(encoder_outputs, self._dtype)
       if self._padded_decode:
         max_decode_length = self._decode_max_length
       else:
         max_decode_length = self._decode_max_length or (
             tf.shape(encoder_outputs)[1] + self._extra_decode_length)
-      encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
-                                               self._dtype)
       symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)
       batch_size = tf.shape(encoder_outputs)[0]
@@ -199,28 +182,35 @@ class Seq2SeqTransformer(tf.keras.Model):
       initial_ids = tf.zeros([batch_size], dtype=tf.int32)
       # Create cache storing decoder attention values for each layer.
-      # pylint: disable=g-complex-comprehension
       init_decode_length = (max_decode_length if self._padded_decode else 0)
       num_heads = self.decoder_layer.num_attention_heads
       dim_per_head = self._embedding_width // num_heads
+      # Cache dtype needs to match beam_search dtype.
+      # pylint: disable=g-complex-comprehension
       cache = {
           str(layer): {
               "key":
                   tf.zeros(
                       [batch_size, init_decode_length, num_heads, dim_per_head],
-                      dtype=self._dtype),
+                      dtype=self.compute_dtype),
               "value":
                   tf.zeros(
                       [batch_size, init_decode_length, num_heads, dim_per_head],
-                      dtype=self._dtype)
+                      dtype=self.compute_dtype)
           } for layer in range(self.decoder_layer.num_layers)
       }
       # pylint: enable=g-complex-comprehension
       # Add encoder output and attention bias to the cache.
+      encoder_outputs = tf.cast(encoder_outputs, dtype=self.compute_dtype)
+      attention_mask = tf.cast(
+          tf.reshape(
+              tf.not_equal(sources, 0), [input_shape[0], 1, input_shape[1]]),
+          dtype=self.compute_dtype)
       cache["encoder_outputs"] = encoder_outputs
-      cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias
+      cache["encoder_decoder_attention_mask"] = attention_mask
       # Use beam search to find the top beam_size sequences and scores.
       decoded_ids, scores = beam_search.sequence_beam_search(
@@ -233,7 +223,7 @@ class Seq2SeqTransformer(tf.keras.Model):
          max_decode_length=max_decode_length,
          eos_id=self._eos_id,
          padded_decode=self._padded_decode,
-          dtype=self._dtype)
+          dtype=self.compute_dtype)
       # Get the top sequence for each batch element
       top_decoded_ids = decoded_ids[:, 0, 1:]
@@ -242,15 +232,13 @@ class Seq2SeqTransformer(tf.keras.Model):
       return {"outputs": top_decoded_ids, "scores": top_scores}
     decoder_inputs = self.embedding_lookup(targets)
-    embedding_mask = tf.cast(
-        tf.not_equal(targets, 0), self.embedding_lookup.embeddings.dtype)
-    decoder_inputs = tf.cast(decoder_inputs, self._dtype)
+    embedding_mask = tf.cast(tf.not_equal(targets, 0), decoder_inputs.dtype)
     decoder_inputs *= tf.expand_dims(embedding_mask, -1)
     # Shift targets to the right, and remove the last element
     decoder_inputs = tf.pad(decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
     length = tf.shape(decoder_inputs)[1]
     pos_encoding = self.position_embedding(decoder_inputs)
-    pos_encoding = tf.cast(pos_encoding, self._dtype)
+    pos_encoding = tf.cast(pos_encoding, embedded_inputs.dtype)
     decoder_inputs += pos_encoding
     decoder_inputs = self.decoder_dropout(decoder_inputs)
@@ -259,8 +247,7 @@ class Seq2SeqTransformer(tf.keras.Model):
     batch_size = decoder_shape[0]
     decoder_length = decoder_shape[1]
-    self_attention_mask = tf.linalg.band_part(
-        tf.ones([length, length], dtype=tf.float32), -1, 0)
+    self_attention_mask = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
     self_attention_mask = tf.reshape(self_attention_mask, [1, length, length])
     self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])
@@ -274,6 +261,8 @@ class Seq2SeqTransformer(tf.keras.Model):
         memory_mask=self_attention_mask,
         target_mask=attention_mask)
     logits = self._embedding_linear(self.embedding_lookup.embeddings, outputs)
+    # Model outputs should be float32 to avoid numeric issues.
+    # https://www.tensorflow.org/guide/mixed_precision#building_the_model
     logits = tf.cast(logits, tf.float32)
     return logits
@@ -281,9 +270,12 @@ class Seq2SeqTransformer(tf.keras.Model):
     """Returns a decoding function that calculates logits of the next tokens."""
     timing_signal = self.position_embedding(
         inputs=None, length=max_decode_length + 1)
-    timing_signal = tf.cast(timing_signal, self._dtype)
-    decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
-        max_decode_length, dtype=self._dtype)
+    timing_signal = tf.cast(timing_signal, dtype=self.compute_dtype)
+    decoder_self_attention_mask = tf.linalg.band_part(
+        tf.ones([max_decode_length, max_decode_length],
+                dtype=self.compute_dtype), -1, 0)
+    decoder_self_attention_mask = tf.reshape(
+        decoder_self_attention_mask, [1, max_decode_length, max_decode_length])
     def symbols_to_logits_fn(ids, i, cache):
       """Generate logits for next potential IDs.
@@ -308,33 +300,24 @@ class Seq2SeqTransformer(tf.keras.Model):
       source_decoder_input = decoder_input
       decoder_input = self.embedding_lookup(decoder_input)
       embedding_mask = tf.cast(
-          tf.not_equal(source_decoder_input, 0),
-          self.embedding_lookup.embeddings.dtype)
+          tf.not_equal(source_decoder_input, 0), decoder_input.dtype)
       decoder_input *= tf.expand_dims(embedding_mask, -1)
       decoder_input += timing_signal[i]
       if self._padded_decode:
-        bias_shape = decoder_self_attention_bias.shape.as_list()
-        self_attention_bias = tf.slice(
-            decoder_self_attention_bias, [0, 0, i, 0],
-            [bias_shape[0], bias_shape[1], 1, bias_shape[3]])
+        # indexing does not work on TPU.
+        bias_shape = decoder_self_attention_mask.shape.as_list()
+        self_attention_mask = tf.slice(
+            decoder_self_attention_mask, [0, i, 0],
+            [bias_shape[0], 1, bias_shape[2]])
       else:
-        self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
+        self_attention_mask = decoder_self_attention_mask[:, i:i+1, :i+1]
       decoder_shape = tf_utils.get_shape_list(decoder_input, expected_rank=3)
       batch_size = decoder_shape[0]
       decoder_length = decoder_shape[1]
-      attention_bias = cache.get("encoder_decoder_attention_bias")
-      attention_bias = tf.where(attention_bias < 0,
-                                tf.zeros_like(attention_bias),
-                                tf.ones_like(attention_bias))
-      attention_bias = tf.squeeze(attention_bias, axis=[1])
-      attention_mask = tf.tile(attention_bias, [1, decoder_length, 1])
-      self_attention_bias = tf.where(self_attention_bias < 0,
-                                     tf.zeros_like(self_attention_bias),
-                                     tf.ones_like(self_attention_bias))
-      self_attention_bias = tf.squeeze(self_attention_bias, axis=[1])
-      self_attention_mask = tf.tile(self_attention_bias, [batch_size, 1, 1])
+      self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])
+      attention_mask = cache.get("encoder_decoder_attention_mask")
+      attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])
       decoder_outputs = self.decoder_layer(
           decoder_input,
@@ -344,6 +327,7 @@ class Seq2SeqTransformer(tf.keras.Model):
          cache=cache,
          decode_loop_step=i if self._padded_decode else None)
+      decoder_outputs = tf.cast(decoder_outputs, dtype=self.compute_dtype)
       logits = self._embedding_linear(self.embedding_lookup.embeddings,
                                       decoder_outputs)
       logits = tf.squeeze(logits, axis=[1])
@@ -359,21 +343,6 @@ class TransformerEncoder(tf.keras.layers.Layer):
   of the sublayers:
     1. Self-attention layer
     2. Feedforward network (which is 2 fully-connected layers)
-  Args:
-    num_layers: Number of layers.
-    num_attention_heads: Number of attention heads.
-    intermediate_size: Size of the intermediate (Feedforward) layer.
-    activation: Activation for the intermediate layer.
-    dropout_rate: Dropout probability.
-    attention_dropout_rate: Dropout probability for attention layers.
-    use_bias: Whether to enable use_bias in attention layer. If set False,
-      use_bias in attention layer is disabled.
-    norm_first: Whether to normalize inputs to attention and intermediate dense
-      layers. If set False, output of attention and intermediate dense layers is
-      normalized.
-    norm_epsilon: Epsilon value to initialize normalization layers.
-    intermediate_dropout: Dropout probability for intermediate_dropout_layer.
   """
 
   def __init__(self,
@@ -388,6 +357,25 @@ class TransformerEncoder(tf.keras.layers.Layer):
                norm_epsilon=1e-6,
                intermediate_dropout=0.0,
                **kwargs):
"""Initialize a Transformer encoder.
Args:
num_layers: Number of layers.
num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate (Feedforward) layer.
activation: Activation for the intermediate layer.
dropout_rate: Dropout probability.
attention_dropout_rate: Dropout probability for attention layers.
use_bias: Whether to enable use_bias in attention layer. If set False,
use_bias in attention layer is disabled.
norm_first: Whether to normalize inputs to attention and intermediate
dense layers. If set False, output of attention and intermediate dense
layers is normalized.
norm_epsilon: Epsilon value to initialize normalization layers.
intermediate_dropout: Dropout probability for intermediate_dropout_layer.
**kwargs: key word arguemnts passed to tf.keras.layers.Layer.
"""
     super(TransformerEncoder, self).__init__(**kwargs)
     self.num_layers = num_layers
     self.num_attention_heads = num_attention_heads
@@ -469,21 +457,6 @@ class TransformerDecoder(tf.keras.layers.Layer):
     2. Multi-headed attention layer combining encoder outputs with results from
        the previous self-attention layer.
     3. Feedforward network (2 fully-connected layers)
-  Args:
-    num_layers: Number of layers.
-    num_attention_heads: Number of attention heads.
-    intermediate_size: Size of the intermediate (Feedforward) layer.
-    activation: Activation for the intermediate layer.
-    dropout_rate: Dropout probability.
-    attention_dropout_rate: Dropout probability for attention layers.
-    use_bias: Whether to enable use_bias in attention layer. If set `False`,
-      use_bias in attention layer is disabled.
-    norm_first: Whether to normalize inputs to attention and intermediate dense
-      layers. If set `False`, output of attention and intermediate dense layers
-      is normalized.
-    norm_epsilon: Epsilon value to initialize normalization layers.
-    intermediate_dropout: Dropout probability for intermediate_dropout_layer.
   """
 
   def __init__(self,
@@ -498,6 +471,24 @@ class TransformerDecoder(tf.keras.layers.Layer):
                norm_epsilon=1e-6,
                intermediate_dropout=0.0,
                **kwargs):
"""Initialize a Transformer decoder.
Args:
num_layers: Number of layers.
num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate (Feedforward) layer.
activation: Activation for the intermediate layer.
dropout_rate: Dropout probability.
attention_dropout_rate: Dropout probability for attention layers.
use_bias: Whether to enable use_bias in attention layer. If set `False`,
use_bias in attention layer is disabled.
norm_first: Whether to normalize inputs to attention and intermediate
dense layers. If set `False`, output of attention and intermediate
dense layers is normalized.
norm_epsilon: Epsilon value to initialize normalization layers.
intermediate_dropout: Dropout probability for intermediate_dropout_layer.
**kwargs: key word arguemnts passed to tf.keras.layers.Layer.
"""
     super(TransformerDecoder, self).__init__(**kwargs)
     self.num_layers = num_layers
     self.num_attention_heads = num_attention_heads
...
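For readers skimming the diff: the transformer-utils half of the change swaps the additive attention biases from `official.nlp.transformer.model_utils` (large negative values at masked positions) for plain 0/1 attention masks built inline with `tf.linalg.band_part` and `tf.not_equal`. Below is a minimal standalone sketch of the two constructions, assuming padding id 0 as in the diff; the helper names are ours for illustration, not functions from the file.

```python
import tensorflow as tf


def causal_self_attention_mask(max_decode_length, dtype=tf.float32):
  """[1, len, len] lower-triangular 0/1 mask; position i sees positions <= i."""
  mask = tf.linalg.band_part(
      tf.ones([max_decode_length, max_decode_length], dtype=dtype), -1, 0)
  return tf.reshape(mask, [1, max_decode_length, max_decode_length])


def encoder_decoder_attention_mask(sources, dtype=tf.float32):
  """[batch, 1, src_len] 0/1 mask that is 0 at padded (id == 0) source tokens."""
  shape = tf.shape(sources)
  return tf.cast(
      tf.reshape(tf.not_equal(sources, 0), [shape[0], 1, shape[1]]), dtype)


print(causal_self_attention_mask(3)[0])
# [[1. 0. 0.]
#  [1. 1. 0.]
#  [1. 1. 1.]]
print(encoder_decoder_attention_mask(tf.constant([[5, 7, 0, 0]]))[0])
# [[1. 1. 0. 0.]]
```

Downstream, the decoder consumes these masks directly (tiled to the batch and decode length), whereas the old bias tensors had to be converted back into masks with `tf.where` and `tf.squeeze` inside `symbols_to_logits_fn`, which is the code the commit deletes.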