Commit f3641f23 authored by Hongkun Yu, committed by A. Unique TensorFlower

Change Seq2SeqTransformer inputs to dictionary.

PiperOrigin-RevId: 338309524
parent 09b9dad7
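
For context, the sketch below shows the calling convention after this change, mirroring the updated test: callers pass a feature dictionary with an `inputs` key and an optional `targets` key instead of a positional list. The constructed `model` instance (a `Seq2SeqTransformer`) and the concrete shapes are assumptions for illustration, not part of this commit.

```python
# Minimal usage sketch, assuming `model` is an already-constructed
# Seq2SeqTransformer instance (constructor arguments are omitted here).
import numpy as np

batch_size, input_length, target_length = 4, 10, 8

# Decode path: only the "inputs" feature is provided, so `targets` resolves
# to None inside the model and it returns a dict with "outputs" and "scores".
decode_features = dict(
    inputs=np.zeros((batch_size, input_length), dtype=np.int32))
decoded = model(decode_features, training=False)

# Train path: both features are provided; the model returns target logits of
# shape [batch_size, target_length, vocab_size].
train_features = dict(
    inputs=np.zeros((batch_size, input_length), dtype=np.int32),
    targets=np.zeros((batch_size, target_length), dtype=np.int32))
logits = model(train_features, training=True)
```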
@@ -130,9 +130,9 @@ class Seq2SeqTransformer(tf.keras.Model):
     """Calculate target logits or inferred target sequences.
 
     Args:
-      inputs: input tensor list of size 1 or 2.
-        First item, inputs: int tensor with shape [batch_size, input_length].
-        Second item (optional), targets: None or int tensor with shape
+      inputs: a dictionary of tensors.
+        Feature `inputs`: int tensor with shape [batch_size, input_length].
+        Feature `targets` (optional): None or int tensor with shape
         [batch_size, target_length].
 
     Returns:
@@ -147,12 +147,8 @@ class Seq2SeqTransformer(tf.keras.Model):
     Raises:
       NotImplementedError: If try to use padded decode method on CPU/GPUs.
     """
-    inputs = inputs if isinstance(inputs, list) else [inputs]
-    if len(inputs) == 2:
-      sources, targets = inputs[0], inputs[1]
-    else:
-      # Decoding path.
-      sources, targets = inputs[0], None
+    sources = inputs["inputs"]
+    targets = inputs.get("targets", None)
     attention_bias = model_utils.get_padding_bias(sources)
     attention_bias = tf.cast(attention_bias, self._dtype)
     # Prepare inputs to the layer stack by adding positional encodings and
...
@@ -82,15 +82,15 @@ class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
       return tf.nest.map_structure(distribution.experimental_local_results,
                                    outputs)
 
-    fake_inputs = [np.zeros((batch_size, decode_max_length), dtype=np.int32)]
+    fake_inputs = dict(
+        inputs=np.zeros((batch_size, decode_max_length), dtype=np.int32))
     local_outputs = step(fake_inputs)
     logging.info("local_outputs=%s", local_outputs)
     self.assertEqual(local_outputs["outputs"][0].shape, (4, 10))
 
-    fake_inputs = [
-        np.zeros((batch_size, decode_max_length), dtype=np.int32),
-        np.zeros((batch_size, 8), dtype=np.int32)
-    ]
+    fake_inputs = dict(
+        inputs=np.zeros((batch_size, decode_max_length), dtype=np.int32),
+        targets=np.zeros((batch_size, 8), dtype=np.int32))
     local_outputs = step(fake_inputs)
     logging.info("local_outputs=%s", local_outputs)
     self.assertEqual(local_outputs[0].shape, (4, 8, 100))
@@ -108,7 +108,7 @@ class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
 
       @tf.function
      def serve(self, inputs):
-        return self.model.call([inputs])
+        return self.model.call(dict(inputs=inputs))
 
     save_module = SaveModule(model)
     if padded_decode:
...
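
For completeness, here is a minimal sketch of an export wrapper in the spirit of the test's `SaveModule` above; the `tf.Module` base class, the class name `ExportModule`, and the input signature are assumptions, since the commit only shows the `serve` method.

```python
import tensorflow as tf


class ExportModule(tf.Module):
  """Hedged sketch of a serving wrapper around a Seq2SeqTransformer."""

  def __init__(self, model):
    super().__init__()
    self.model = model

  # The input signature is an assumption chosen for illustration.
  @tf.function(input_signature=[
      tf.TensorSpec(shape=[None, None], dtype=tf.int32, name="inputs")
  ])
  def serve(self, inputs):
    # Decode-only path: wrap the raw token-id tensor in the feature
    # dictionary that the model now expects.
    return self.model.call(dict(inputs=inputs))
```

Such a wrapper could then be exported with `tf.saved_model.save`, exposing `serve` as a serving signature.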
@@ -70,7 +70,8 @@ def _create_model(params, is_train):
     inputs = tf.keras.layers.Input((None,), dtype="int64", name="inputs")
     targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
     internal_model = models.Seq2SeqTransformer(**model_kwargs)
-    logits = internal_model([inputs, targets], training=is_train)
+    logits = internal_model(
+        dict(inputs=inputs, targets=targets), training=is_train)
     vocab_size = params["vocab_size"]
     label_smoothing = params["label_smoothing"]
     if params["enable_metrics_in_training"]:
@@ -90,7 +91,7 @@ def _create_model(params, is_train):
         dtype="int64",
         name="inputs")
     internal_model = models.Seq2SeqTransformer(**model_kwargs)
-    ret = internal_model([inputs], training=is_train)
+    ret = internal_model(dict(inputs=inputs), training=is_train)
     outputs, scores = ret["outputs"], ret["scores"]
     return tf.keras.Model(inputs, [outputs, scores])