Commit c4ebfef2 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 406906304
parent 460890ed
@@ -103,7 +103,7 @@ class Seq2SeqTransformer(tf.keras.Model):
         "beam_size": self._beam_size,
         "alpha": self._alpha,
         "encoder_layer": self.encoder_layer,
-        "decoder_layer": self.decoder_layer
+        "decoder_layer": self.decoder_layer,
     }
     base_config = super(Seq2SeqTransformer, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
@@ -122,14 +122,47 @@ class Seq2SeqTransformer(tf.keras.Model):
     return tf.reshape(logits, [batch_size, length, vocab_size])

+  def _parse_inputs(self, inputs):
+    """Parses the `call` inputs and returns a uniform output."""
+    sources = inputs.get("inputs", None)
+    input_mask = inputs.get("input_masks", None)
+    embedded = inputs.get("embedded_inputs", None)
+    if sources is None and embedded is not None:
+      embedded_inputs = embedded
+      boolean_mask = input_mask
+      input_shape = tf_utils.get_shape_list(embedded, expected_rank=3)
+      source_dtype = embedded.dtype
+    elif sources is not None:
+      embedded_inputs = self.embedding_lookup(sources)
+      boolean_mask = tf.not_equal(sources, 0)
+      input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
+      source_dtype = sources.dtype
+    else:
+      raise KeyError(
+          "The call method expects either `inputs` or `embedded_inputs` and "
+          "`input_masks` as input features.")
+    return embedded_inputs, boolean_mask, input_shape, source_dtype
+
   def call(self, inputs):
     """Calculate target logits or inferred target sequences.

     Args:
       inputs: a dictionary of tensors.
-        Feature `inputs`: int tensor with shape `[batch_size, input_length]`.
+        Feature `inputs` (optional): int tensor with shape
+          `[batch_size, input_length]`.
+        Feature `embedded_inputs` (optional): float tensor with shape
+          `[batch_size, input_length, embedding_width]`.
         Feature `targets` (optional): None or int tensor with shape
           `[batch_size, target_length]`.
+        Feature `input_masks` (optional): When providing the `embedded_inputs`,
+          the dictionary must provide a boolean mask marking the filled time
+          steps. The shape of the tensor is `[batch_size, input_length]`.
+        Either `inputs` or `embedded_inputs` and `input_masks` must be present
+        in the input dictionary. In the second case the projection of the
+        integer tokens to the transformer embedding space is skipped and
+        `input_masks` is expected to be present.

     Returns:
       If targets is defined, then return logits for each word in the target
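
For illustration (not part of the commit), the two input dictionaries this `call` contract accepts can be built as below; `model` is assumed to be an already constructed `Seq2SeqTransformer` with `embedding_width=16`, and all shapes are placeholders:

import numpy as np

# Token-id path: embeddings are looked up inside `call` and the boolean
# mask is derived from padding id 0.
logits = model(dict(
    inputs=np.zeros((4, 10), dtype=np.int32),
    targets=np.zeros((4, 8), dtype=np.int32)))

# Pre-embedded path: the embedding lookup is skipped, so `input_masks`
# must explicitly mark the filled time steps.
logits = model(dict(
    embedded_inputs=np.zeros((4, 10, 16), dtype=np.float32),
    input_masks=np.ones((4, 10), dtype=bool),
    targets=np.zeros((4, 8), dtype=np.int32)))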
@@ -144,21 +177,19 @@ class Seq2SeqTransformer(tf.keras.Model):

     Raises:
       NotImplementedError: If the padded decode method is used on CPU/GPUs.
     """
-    sources = inputs["inputs"]
-    targets = inputs.get("targets", None)
     # Prepare inputs to the layer stack by adding positional encodings and
     # applying dropout.
-    embedded_inputs = self.embedding_lookup(sources)
-    embedding_mask = tf.cast(tf.not_equal(sources, 0), embedded_inputs.dtype)
+    targets = inputs.get("targets", None)
+    (embedded_inputs, boolean_mask,
+     input_shape, source_dtype) = self._parse_inputs(inputs)
+    embedding_mask = tf.cast(boolean_mask, embedded_inputs.dtype)
     embedded_inputs *= tf.expand_dims(embedding_mask, -1)
     # Attention_mask generation.
-    input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
     attention_mask = tf.cast(
-        tf.reshape(
-            tf.not_equal(sources, 0), [input_shape[0], 1, input_shape[1]]),
-        dtype=sources.dtype)
+        tf.reshape(boolean_mask, [input_shape[0], 1, input_shape[1]]),
+        dtype=source_dtype)
     broadcast_ones = tf.ones(
-        shape=[input_shape[0], input_shape[1], 1], dtype=sources.dtype)
+        shape=[input_shape[0], input_shape[1], 1], dtype=source_dtype)
     attention_mask = broadcast_ones * attention_mask

     pos_encoding = self.position_embedding(embedded_inputs)
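
As a side note, the attention-mask construction above broadcasts a per-token boolean mask into a square source-to-source mask. A minimal standalone sketch of the same arithmetic (all names and values are illustrative):

import tensorflow as tf

boolean_mask = tf.constant([[True, True, False]])    # [batch=1, length=3]
row_mask = tf.cast(tf.reshape(boolean_mask, [1, 1, 3]), tf.int32)  # [1, 1, 3]
broadcast_ones = tf.ones([1, 3, 1], tf.int32)                      # [1, 3, 1]
attention_mask = broadcast_ones * row_mask                         # [1, 3, 3]
# Every query position now carries the same row marking which source
# positions are valid to attend to.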
@@ -206,8 +237,7 @@ class Seq2SeqTransformer(tf.keras.Model):
       # Add encoder output and attention bias to the cache.
       encoder_outputs = tf.cast(encoder_outputs, dtype=self.compute_dtype)
       attention_mask = tf.cast(
-          tf.reshape(
-              tf.not_equal(sources, 0), [input_shape[0], 1, input_shape[1]]),
+          tf.reshape(boolean_mask, [input_shape[0], 1, input_shape[1]]),
           dtype=self.compute_dtype)
       cache["encoder_outputs"] = encoder_outputs
       cache["encoder_decoder_attention_mask"] = attention_mask
@@ -252,7 +282,7 @@ class Seq2SeqTransformer(tf.keras.Model):
       self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])
       attention_mask = tf.cast(
-          tf.expand_dims(tf.not_equal(sources, 0), axis=1), dtype=sources.dtype)
+          tf.expand_dims(boolean_mask, axis=1), dtype=source_dtype)
       attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])
       outputs = self.decoder_layer(
@@ -26,12 +26,11 @@ from official.nlp.modeling.models import seq2seq_transformer

 class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):

-  def _build_model(self, padded_decode, decode_max_length):
+  def _build_model(self, padded_decode, decode_max_length, embedding_width):
     num_layers = 1
     num_attention_heads = 2
     intermediate_size = 32
     vocab_size = 100
-    embedding_width = 16
     encdec_kwargs = dict(
         num_layers=num_layers,
         num_attention_heads=num_attention_heads,
@@ -63,15 +62,19 @@ class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
           strategy_combinations.default_strategy,
           strategy_combinations.cloud_tpu_strategy,
       ],
+      embed=[True, False],
+      is_training=[True, False],
       mode="eager"))
-  def test_create_model_with_ds(self, distribution):
+  def test_create_model_with_ds(self, distribution, embed, is_training):
     with distribution.scope():
       padded_decode = isinstance(
           distribution,
           (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy))
       decode_max_length = 10
       batch_size = 4
-      model = self._build_model(padded_decode, decode_max_length)
+      embedding_width = 16
+      model = self._build_model(
+          padded_decode, decode_max_length, embedding_width)

       @tf.function
       def step(inputs):
@@ -83,23 +86,32 @@ class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
         return tf.nest.map_structure(distribution.experimental_local_results,
                                      outputs)

-      fake_inputs = dict(
-          inputs=np.zeros((batch_size, decode_max_length), dtype=np.int32))
-      local_outputs = step(fake_inputs)
-      logging.info("local_outputs=%s", local_outputs)
-      self.assertEqual(local_outputs["outputs"][0].shape, (4, 10))
-
-      fake_inputs = dict(
-          inputs=np.zeros((batch_size, decode_max_length), dtype=np.int32),
-          targets=np.zeros((batch_size, 8), dtype=np.int32))
-      local_outputs = step(fake_inputs)
-      logging.info("local_outputs=%s", local_outputs)
-      self.assertEqual(local_outputs[0].shape, (4, 8, 100))
+      if embed:
+        fake_inputs = dict(
+            embedded_inputs=np.zeros(
+                (batch_size, decode_max_length, embedding_width),
+                dtype=np.float32),
+            input_masks=np.ones((batch_size, decode_max_length),
+                                dtype=np.bool))
+      else:
+        fake_inputs = dict(
+            inputs=np.zeros((batch_size, decode_max_length), dtype=np.int32))
+
+      if is_training:
+        fake_inputs["targets"] = np.zeros((batch_size, 8), dtype=np.int32)
+        local_outputs = step(fake_inputs)
+        logging.info("local_outputs=%s", local_outputs)
+        self.assertEqual(local_outputs[0].shape, (4, 8, 100))
+      else:
+        local_outputs = step(fake_inputs)
+        logging.info("local_outputs=%s", local_outputs)
+        self.assertEqual(local_outputs["outputs"][0].shape, (4, 10))

   @parameterized.parameters(True, False)
   def test_create_savedmodel(self, padded_decode):
     decode_max_length = 10
-    model = self._build_model(padded_decode, decode_max_length)
+    embedding_width = 16
+    model = self._build_model(
+        padded_decode, decode_max_length, embedding_width)

     class SaveModule(tf.Module):
@@ -111,14 +123,28 @@ class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
       def serve(self, inputs):
         return self.model.call(dict(inputs=inputs))

+      @tf.function
+      def embedded_serve(self, embedded_inputs, input_masks):
+        return self.model.call(
+            dict(embedded_inputs=embedded_inputs, input_masks=input_masks))
+
     save_module = SaveModule(model)
     if padded_decode:
-      tensor_shape = (4, 10)
+      tensor_shape = (4, decode_max_length)
+      embedded_tensor_shape = (4, decode_max_length, embedding_width)
     else:
       tensor_shape = (None, None)
+      embedded_tensor_shape = (None, None, embedding_width)
     signatures = dict(
         serving_default=save_module.serve.get_concrete_function(
-            tf.TensorSpec(shape=tensor_shape, dtype=tf.int32, name="inputs")))
+            tf.TensorSpec(shape=tensor_shape, dtype=tf.int32, name="inputs")),
+        embedded_serving=save_module.embedded_serve.get_concrete_function(
+            tf.TensorSpec(
+                shape=embedded_tensor_shape, dtype=tf.float32,
+                name="embedded_inputs"),
+            tf.TensorSpec(
+                shape=tensor_shape, dtype=tf.bool, name="input_masks"),
+        ))
     tf.saved_model.save(save_module, self.get_temp_dir(), signatures=signatures)
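
Once exported, the new signature can be called next to the default one. A usage sketch, assuming the SavedModel was written to a hypothetical `export_dir` with `padded_decode=True` (so the shapes below match the concrete functions above):

import tensorflow as tf

loaded = tf.saved_model.load(export_dir)  # export_dir is a hypothetical path
serve = loaded.signatures["serving_default"]
embedded_serve = loaded.signatures["embedded_serving"]

# Token-id serving path.
outputs = serve(inputs=tf.zeros((4, 10), tf.int32))
# Pre-embedded serving path added by this change.
embedded_outputs = embedded_serve(
    embedded_inputs=tf.zeros((4, 10, 16), tf.float32),
    input_masks=tf.ones((4, 10), tf.bool))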