Commit ab70b0f2 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 406906304
parent c3becc45
@@ -103,7 +103,7 @@ class Seq2SeqTransformer(tf.keras.Model):
"beam_size": self._beam_size,
"alpha": self._alpha,
"encoder_layer": self.encoder_layer,
"decoder_layer": self.decoder_layer
"decoder_layer": self.decoder_layer,
}
base_config = super(Seq2SeqTransformer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
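For context, a minimal sketch of how this config-merging pattern round-trips through Keras serialization; `ToyBeamSearch` and its kwargs are invented for the example and are not part of this diff:

import tensorflow as tf

class ToyBeamSearch(tf.keras.layers.Layer):
  """Invented layer reusing the config-merging pattern from the hunk above."""

  def __init__(self, beam_size=4, alpha=0.6, **kwargs):
    super().__init__(**kwargs)
    self._beam_size = beam_size
    self._alpha = alpha

  def get_config(self):
    config = {
        "beam_size": self._beam_size,
        "alpha": self._alpha,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

layer = ToyBeamSearch(beam_size=2)
restored = ToyBeamSearch.from_config(layer.get_config())
assert restored._beam_size == 2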
@@ -122,14 +122,47 @@ class Seq2SeqTransformer(tf.keras.Model):
return tf.reshape(logits, [batch_size, length, vocab_size])
def _parse_inputs(self, inputs):
"""Parses the `call` inputs and returns an uniformed output."""
sources = inputs.get("inputs", None)
input_mask = inputs.get("input_masks", None)
embedded = inputs.get("embedded_inputs", None)
if sources is None and embedded is not None:
embedded_inputs = embedded
boolean_mask = input_mask
input_shape = tf_utils.get_shape_list(embedded, expected_rank=3)
source_dtype = embedded.dtype
elif sources is not None:
embedded_inputs = self.embedding_lookup(sources)
boolean_mask = tf.not_equal(sources, 0)
input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
source_dtype = sources.dtype
else:
raise KeyError(
"The call method expects either `inputs` or `embedded_inputs` and "
"`input_masks` as input features.")
return embedded_inputs, boolean_mask, input_shape, source_dtype
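# A minimal sketch of the two feature dictionaries `_parse_inputs` accepts;
# all shapes and token values below are illustrative.
import numpy as np

token_features = {
    # Token-ID route: the model performs the embedding lookup and derives
    # the boolean mask from the padding ID 0.
    "inputs": np.array([[5, 12, 7, 0]], dtype=np.int32),
}
embedded_features = {
    # Pre-embedded route: the lookup is skipped, so an explicit boolean mask
    # marking the filled time steps must accompany the embeddings.
    "embedded_inputs": np.zeros((1, 4, 16), dtype=np.float32),
    "input_masks": np.array([[True, True, True, False]]),
}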
def call(self, inputs):
"""Calculate target logits or inferred target sequences.
Args:
inputs: a dictionary of tensors.
Feature `inputs`: int tensor with shape `[batch_size, input_length]`.
Feature `inputs` (optional): int tensor with shape
`[batch_size, input_length]`.
Feature `embedded_inputs` (optional): float tensor with shape
`[batch_size, input_length, embedding_width]`.
Feature `targets` (optional): None or int tensor with shape
`[batch_size, target_length]`.
Feature `input_masks` (optional): When providing the `embedded_inputs`,
the dictionary must provide a boolean mask marking the filled time
steps. The shape of the tensor is `[batch_size, input_length]`.
Either `inputs` or `embedded_inputs` and `input_masks` must be present
in the input dictionary. In the latter case the projection of the
integer tokens into the transformer embedding space is skipped and
`input_masks` is expected to be present.
Returns:
If targets is defined, then return logits for each word in the target
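A minimal usage sketch of `call`, matching the shapes exercised in the test further below; `model` is assumed to be an already built Seq2SeqTransformer:

import numpy as np

fake_inputs = dict(inputs=np.zeros((4, 10), dtype=np.int32))

# With `targets` present the call returns per-token logits of shape
# [batch_size, target_length, vocab_size].
logits = model(dict(fake_inputs, targets=np.zeros((4, 8), dtype=np.int32)))

# Without `targets` the model decodes, returning a dict whose "outputs"
# entry holds the generated token IDs of shape [batch_size, decode_length].
decoded = model(fake_inputs)
token_ids = decoded["outputs"]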
@@ -144,21 +177,19 @@ class Seq2SeqTransformer(tf.keras.Model):
Raises:
NotImplementedError: If the padded decode method is used on CPU/GPUs.
"""
sources = inputs["inputs"]
targets = inputs.get("targets", None)
# Prepare inputs to the layer stack by adding positional encodings and
# applying dropout.
embedded_inputs = self.embedding_lookup(sources)
embedding_mask = tf.cast(tf.not_equal(sources, 0), embedded_inputs.dtype)
targets = inputs.get("targets", None)
(embedded_inputs, boolean_mask,
input_shape, source_dtype) = self._parse_inputs(inputs)
embedding_mask = tf.cast(boolean_mask, embedded_inputs.dtype)
embedded_inputs *= tf.expand_dims(embedding_mask, -1)
# Attention_mask generation.
input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
attention_mask = tf.cast(
tf.reshape(
tf.not_equal(sources, 0), [input_shape[0], 1, input_shape[1]]),
dtype=sources.dtype)
tf.reshape(boolean_mask, [input_shape[0], 1, input_shape[1]]),
dtype=source_dtype)
broadcast_ones = tf.ones(
shape=[input_shape[0], input_shape[1], 1], dtype=sources.dtype)
shape=[input_shape[0], input_shape[1], 1], dtype=source_dtype)
attention_mask = broadcast_ones * attention_mask
pos_encoding = self.position_embedding(embedded_inputs)
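A toy trace of the attention-mask broadcast above, for one sequence of length 3 whose last step is padding (values invented for the example):

import tensorflow as tf

boolean_mask = tf.constant([[True, True, False]])               # [1, 3]
row = tf.cast(tf.reshape(boolean_mask, [1, 1, 3]), tf.int32)    # [1, 1, 3]
broadcast_ones = tf.ones([1, 3, 1], dtype=tf.int32)             # [1, 3, 1]
attention_mask = broadcast_ones * row                           # [1, 3, 3]
# Every query position may attend to the two real tokens only:
# [[[1, 1, 0],
#   [1, 1, 0],
#   [1, 1, 0]]]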
@@ -206,8 +237,7 @@ class Seq2SeqTransformer(tf.keras.Model):
# Add encoder output and attention bias to the cache.
encoder_outputs = tf.cast(encoder_outputs, dtype=self.compute_dtype)
attention_mask = tf.cast(
tf.reshape(
tf.not_equal(sources, 0), [input_shape[0], 1, input_shape[1]]),
tf.reshape(boolean_mask, [input_shape[0], 1, input_shape[1]]),
dtype=self.compute_dtype)
cache["encoder_outputs"] = encoder_outputs
cache["encoder_decoder_attention_mask"] = attention_mask
@@ -252,7 +282,7 @@ class Seq2SeqTransformer(tf.keras.Model):
self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])
attention_mask = tf.cast(
tf.expand_dims(tf.not_equal(sources, 0), axis=1), dtype=sources.dtype)
tf.expand_dims(boolean_mask, axis=1), dtype=source_dtype)
attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])
outputs = self.decoder_layer(
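Similarly, a toy trace of the cross-attention mask tiling in the decode path above, for a source of length 2 and a decode length of 2 (values invented for the example):

import tensorflow as tf

boolean_mask = tf.constant([[True, False]])                     # [1, 2]
mask = tf.cast(tf.expand_dims(boolean_mask, axis=1), tf.int32)  # [1, 1, 2]
mask = tf.tile(mask, [1, 2, 1])                                 # [1, 2, 2]
# Both decoder steps may attend only to the first (unpadded) source token:
# [[[1, 0],
#   [1, 0]]]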
@@ -26,12 +26,11 @@ from official.nlp.modeling.models import seq2seq_transformer
class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
def _build_model(self, padded_decode, decode_max_length):
def _build_model(self, padded_decode, decode_max_length, embedding_width):
num_layers = 1
num_attention_heads = 2
intermediate_size = 32
vocab_size = 100
embedding_width = 16
encdec_kwargs = dict(
num_layers=num_layers,
num_attention_heads=num_attention_heads,
@@ -63,15 +62,19 @@ class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
],
embed=[True, False],
is_training=[True, False],
mode="eager"))
def test_create_model_with_ds(self, distribution):
def test_create_model_with_ds(self, distribution, embed, is_training):
with distribution.scope():
padded_decode = isinstance(
distribution,
(tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy))
decode_max_length = 10
batch_size = 4
model = self._build_model(padded_decode, decode_max_length)
embedding_width = 16
model = self._build_model(
padded_decode, decode_max_length, embedding_width)
@tf.function
def step(inputs):
@@ -83,23 +86,32 @@ class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
return tf.nest.map_structure(distribution.experimental_local_results,
outputs)
if embed:
fake_inputs = dict(
embedded_inputs=np.zeros(
(batch_size, decode_max_length, embedding_width),
dtype=np.float32),
input_masks=np.ones((batch_size, decode_max_length), dtype=bool))
else:
fake_inputs = dict(
inputs=np.zeros((batch_size, decode_max_length), dtype=np.int32))
local_outputs = step(fake_inputs)
logging.info("local_outputs=%s", local_outputs)
self.assertEqual(local_outputs["outputs"][0].shape, (4, 10))
fake_inputs = dict(
inputs=np.zeros((batch_size, decode_max_length), dtype=np.int32),
targets=np.zeros((batch_size, 8), dtype=np.int32))
if is_training:
fake_inputs["targets"] = np.zeros((batch_size, 8), dtype=np.int32)
local_outputs = step(fake_inputs)
logging.info("local_outputs=%s", local_outputs)
self.assertEqual(local_outputs[0].shape, (4, 8, 100))
else:
local_outputs = step(fake_inputs)
logging.info("local_outputs=%s", local_outputs)
self.assertEqual(local_outputs["outputs"][0].shape, (4, 10))
@parameterized.parameters(True, False)
def test_create_savedmodel(self, padded_decode):
decode_max_length = 10
model = self._build_model(padded_decode, decode_max_length)
embedding_width = 16
model = self._build_model(
padded_decode, decode_max_length, embedding_width)
class SaveModule(tf.Module):
@@ -111,14 +123,28 @@ class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
def serve(self, inputs):
return self.model.call(dict(inputs=inputs))
@tf.function
def embedded_serve(self, embedded_inputs, input_masks):
return self.model.call(
dict(embedded_inputs=embedded_inputs, input_masks=input_masks))
save_module = SaveModule(model)
if padded_decode:
tensor_shape = (4, 10)
tensor_shape = (4, decode_max_length)
embedded_tensor_shape = (4, decode_max_length, embedding_width)
else:
tensor_shape = (None, None)
embedded_tensor_shape = (None, None, embedding_width)
signatures = dict(
serving_default=save_module.serve.get_concrete_function(
tf.TensorSpec(shape=tensor_shape, dtype=tf.int32, name="inputs")))
tf.TensorSpec(shape=tensor_shape, dtype=tf.int32, name="inputs")),
embedded_serving=save_module.embedded_serve.get_concrete_function(
tf.TensorSpec(
shape=embedded_tensor_shape, dtype=tf.float32,
name="embedded_inputs"),
tf.TensorSpec(
shape=tensor_shape, dtype=tf.bool, name="input_masks"),
))
tf.saved_model.save(save_module, self.get_temp_dir(), signatures=signatures)
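A minimal sketch of restoring and invoking the two signatures exported above; `export_dir` stands in for the temp directory used in the test:

import tensorflow as tf

restored = tf.saved_model.load(export_dir)
serve = restored.signatures["serving_default"]
embedded_serve = restored.signatures["embedded_serving"]

# Signatures are invoked with keyword arguments named after their TensorSpecs.
outputs = serve(inputs=tf.zeros((4, 10), tf.int32))
embedded_outputs = embedded_serve(
    embedded_inputs=tf.zeros((4, 10, 16), tf.float32),
    input_masks=tf.ones((4, 10), tf.bool))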