Commit 54832af8 authored by Hongkun Yu, committed by A. Unique TensorFlower

Fix Transformer inference SavedModel export.

Keras model serialization causes a lot of problems, so export goes through a plain tf.Module wrapper instead.

PiperOrigin-RevId: 328162551
parent 7f824479
...
@@ -207,12 +207,12 @@ class Seq2SeqTransformer(tf.keras.Model):
     Raises:
       NotImplementedError: If try to use padded decode method on CPU/GPUs.
     """
+    inputs = inputs if isinstance(inputs, list) else [inputs]
     if len(inputs) == 2:
       sources, targets = inputs[0], inputs[1]
     else:
       # Decoding path.
       sources, targets = inputs[0], None
     attention_bias = model_utils.get_padding_bias(sources)
     attention_bias = tf.cast(attention_bias, self._dtype)
     # Prepare inputs to the layer stack by adding positional encodings and
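The added line makes call() tolerant of receiving a single tensor rather than a [sources, targets] list, which is what an inference-only serving path provides. A standalone sketch of the idiom (names and values here are illustrative, not from the commit):

import tensorflow as tf

def _normalize_inputs(inputs):
  # Wrap a bare tensor so downstream code can always index inputs[0].
  return inputs if isinstance(inputs, list) else [inputs]

sources = tf.constant([[1, 2, 3]], dtype=tf.int32)
targets = tf.constant([[4, 5, 6]], dtype=tf.int32)
assert len(_normalize_inputs(sources)) == 1             # decoding path: no targets
assert len(_normalize_inputs([sources, targets])) == 2  # training path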
...
@@ -245,17 +245,15 @@ class Seq2SeqTransformer(tf.keras.Model):
     encoder_decoder_attention_bias = attention_bias
     encoder_outputs = tf.cast(encoder_outputs, self._dtype)
     if self._padded_decode:
-      batch_size = encoder_outputs.shape.as_list()[0]
       max_decode_length = self._decode_max_length
     else:
-      batch_size = tf.shape(encoder_outputs)[0]
       max_decode_length = self._decode_max_length or (
           tf.shape(encoder_outputs)[1] + self._extra_decode_length)
     encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
                                              self._dtype)
     symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)
+    batch_size = tf.shape(encoder_outputs)[0]
     # Create initial set of IDs that will be passed to symbols_to_logits_fn.
     initial_ids = tf.zeros([batch_size], dtype=tf.int32)
...
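The batch_size change matters for export as well: encoder_outputs.shape.as_list()[0] reads the static shape, which is None when the graph is traced with a dynamic batch dimension, and tf.zeros([None]) fails at trace time. tf.shape(encoder_outputs)[0] yields a scalar tensor resolved at run time, valid for both padded and dynamic decoding, so the computation can be hoisted out of the if/else. A minimal standalone sketch of the difference (not part of the commit):

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=(None, None), dtype=tf.int32)])
def make_initial_ids(encoder_outputs):
  # Static shape: None under a dynamic signature; unusable for tf.zeros().
  static_batch = encoder_outputs.shape.as_list()[0]
  del static_batch
  # Dynamic shape: a scalar tensor resolved when the function runs.
  dynamic_batch = tf.shape(encoder_outputs)[0]
  return tf.zeros([dynamic_batch], dtype=tf.int32)

print(make_initial_ids(tf.ones((3, 5), dtype=tf.int32)))  # [0 0 0]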
...
@@ -129,6 +129,31 @@ class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):
     logging.info("local_outputs=%s", local_outputs)
     self.assertEqual(local_outputs[0].shape, (4, 8, 100))

+  @parameterized.parameters(True, False)
+  def test_create_savedmodel(self, padded_decode):
+    decode_max_length = 10
+    model = self._build_model(padded_decode, decode_max_length)
+
+    class SaveModule(tf.Module):
+
+      def __init__(self, model):
+        super(SaveModule, self).__init__()
+        self.model = model
+
+      @tf.function
+      def serve(self, inputs):
+        return self.model.call([inputs])
+
+    save_module = SaveModule(model)
+    if padded_decode:
+      tensor_shape = (4, 10)
+    else:
+      tensor_shape = (None, None)
+    signatures = dict(
+        serving_default=save_module.serve.get_concrete_function(
+            tf.TensorSpec(shape=tensor_shape, dtype=tf.int32, name="inputs")))
+    tf.saved_model.save(save_module, self.get_temp_dir(), signatures=signatures)
+
 if __name__ == "__main__":
   tf.test.main()
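For completeness, a sketch of consuming the exported SavedModel at inference time; the export directory is hypothetical, while the signature key and input name follow the test above:

import tensorflow as tf

export_dir = "/tmp/transformer_savedmodel"  # Hypothetical; the test saves to get_temp_dir().
loaded = tf.saved_model.load(export_dir)
serve_fn = loaded.signatures["serving_default"]

# Source token IDs: fixed (4, 10) when padded_decode=True, any (batch, length)
# when the signature was traced with shape (None, None).
source_ids = tf.zeros((4, 10), dtype=tf.int32)
outputs = serve_fn(inputs=source_ids)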
...
@@ -119,6 +119,7 @@ class Transformer(tf.keras.Model):
     Raises:
       NotImplementedError: If try to use padded decode method on CPU/GPUs.
     """
+    inputs = inputs if isinstance(inputs, list) else [inputs]
     if len(inputs) == 2:
       inputs, targets = inputs[0], inputs[1]
     else:
...