Commit f79858bf authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 340580428
parent b63c955f
@@ -306,21 +306,14 @@ class Seq2SeqTransformer(tf.keras.Model):
           tf.not_equal(source_decoder_input, 0),
           self.embedding_lookup.embeddings.dtype)
       decoder_input *= tf.expand_dims(embedding_mask, -1)
+      decoder_input += timing_signal[i]
       if self._padded_decode:
-        timing_signal_shape = timing_signal.shape.as_list()
-        decoder_input += tf.slice(timing_signal, [i, 0],
-                                  [1, timing_signal_shape[1]])
         bias_shape = decoder_self_attention_bias.shape.as_list()
         self_attention_bias = tf.slice(
             decoder_self_attention_bias, [0, 0, i, 0],
             [bias_shape[0], bias_shape[1], 1, bias_shape[3]])
       else:
-        decoder_input += timing_signal[i:i + 1]
         self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
       decoder_shape = tf_utils.get_shape_list(decoder_input, expected_rank=3)
       batch_size = decoder_shape[0]
       decoder_length = decoder_shape[1]
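Note (not part of the commit): a minimal sketch of why the single `decoder_input += timing_signal[i]` added before the branch can stand in for the two removed branch-specific forms. All three expressions broadcast to the same [batch_size, 1, hidden_size] result; the sizes below are hypothetical and chosen only for illustration.

import tensorflow as tf

# Hypothetical sizes for illustration only.
max_decode_length, hidden_size, batch_size = 8, 4, 2
timing_signal = tf.random.uniform([max_decode_length, hidden_size])
decoder_input = tf.zeros([batch_size, 1, hidden_size])
i = 3  # current decode step

# Consolidated form added before the branch: shape [hidden_size], broadcasts.
a = decoder_input + timing_signal[i]
# Removed padded_decode form: shape [1, hidden_size].
b = decoder_input + tf.slice(timing_signal, [i, 0], [1, hidden_size])
# Removed non-padded form: shape [1, hidden_size].
c = decoder_input + timing_signal[i:i + 1]

tf.debugging.assert_near(a, b)
tf.debugging.assert_near(a, c)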
@@ -253,19 +253,13 @@ class Transformer(tf.keras.Model):
       # Preprocess decoder input by getting embeddings and adding timing signal.
       decoder_input = self.embedding_softmax_layer(decoder_input)
+      decoder_input += timing_signal[i]
       if self.params["padded_decode"]:
-        timing_signal_shape = timing_signal.shape.as_list()
-        decoder_input += tf.slice(timing_signal, [i, 0],
-                                  [1, timing_signal_shape[1]])
         bias_shape = decoder_self_attention_bias.shape.as_list()
         self_attention_bias = tf.slice(
             decoder_self_attention_bias, [0, 0, i, 0],
             [bias_shape[0], bias_shape[1], 1, bias_shape[3]])
       else:
-        decoder_input += timing_signal[i:i + 1]
         self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
       decoder_outputs = self.decoder_stack(
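For context, a small sketch (again outside the commit) of the self-attention bias slicing that both branches keep unchanged: under padded_decode the row for step i retains the full static length, which matters for XLA/TPU decoding, while the dynamic slice grows with the step. The causal-bias construction here is a hypothetical stand-in for the model's actual bias tensor.

import tensorflow as tf

# Hypothetical causal bias: shape [1, 1, length, length], large negative above the diagonal.
max_decode_length = 5
neg_inf = -1e9
bias = (1.0 - tf.linalg.band_part(
    tf.ones([max_decode_length, max_decode_length]), -1, 0)) * neg_inf
decoder_self_attention_bias = bias[tf.newaxis, tf.newaxis, :, :]

i = 2  # current decode step

# padded_decode branch: keep the full (static) row length.
bias_shape = decoder_self_attention_bias.shape.as_list()
padded = tf.slice(decoder_self_attention_bias, [0, 0, i, 0],
                  [bias_shape[0], bias_shape[1], 1, bias_shape[3]])
print(padded.shape)   # (1, 1, 1, 5)

# non-padded branch: the sliced row grows with the step index.
dynamic = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
print(dynamic.shape)  # (1, 1, 1, 3)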