Commit 10ec8d08 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 353681984
parent 316a2977
...@@ -160,8 +160,8 @@ class Seq2SeqTransformer(tf.keras.Model): ...@@ -160,8 +160,8 @@ class Seq2SeqTransformer(tf.keras.Model):
embedded_inputs = self.embedding_lookup(sources) embedded_inputs = self.embedding_lookup(sources)
embedding_mask = tf.cast( embedding_mask = tf.cast(
tf.not_equal(sources, 0), self.embedding_lookup.embeddings.dtype) tf.not_equal(sources, 0), self.embedding_lookup.embeddings.dtype)
embedded_inputs *= tf.expand_dims(embedding_mask, -1)
embedded_inputs = tf.cast(embedded_inputs, self._dtype) embedded_inputs = tf.cast(embedded_inputs, self._dtype)
embedded_inputs *= tf.expand_dims(embedding_mask, -1)
# Attention_mask generation. # Attention_mask generation.
input_shape = tf_utils.get_shape_list(sources, expected_rank=2) input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
attention_mask = tf.cast( attention_mask = tf.cast(
...@@ -243,8 +243,8 @@ class Seq2SeqTransformer(tf.keras.Model): ...@@ -243,8 +243,8 @@ class Seq2SeqTransformer(tf.keras.Model):
decoder_inputs = self.embedding_lookup(targets) decoder_inputs = self.embedding_lookup(targets)
embedding_mask = tf.cast( embedding_mask = tf.cast(
tf.not_equal(targets, 0), self.embedding_lookup.embeddings.dtype) tf.not_equal(targets, 0), self.embedding_lookup.embeddings.dtype)
decoder_inputs *= tf.expand_dims(embedding_mask, -1)
decoder_inputs = tf.cast(decoder_inputs, self._dtype) decoder_inputs = tf.cast(decoder_inputs, self._dtype)
decoder_inputs *= tf.expand_dims(embedding_mask, -1)
# Shift targets to the right, and remove the last element # Shift targets to the right, and remove the last element
decoder_inputs = tf.pad(decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :] decoder_inputs = tf.pad(decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
length = tf.shape(decoder_inputs)[1] length = tf.shape(decoder_inputs)[1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment