Commit f6557386 authored by Frederick Liu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 411729044
parent 65c81380
@@ -263,8 +263,6 @@ class Seq2SeqTransformer(tf.keras.Model):
       # Shift targets to the right, and remove the last element
       targets = tf.pad(targets, [[0, 0], [1, 0]])[:, :-1]
       decoder_inputs = self.embedding_lookup(targets)
-      embedding_mask = tf.cast(tf.not_equal(targets, 0), decoder_inputs.dtype)
-      decoder_inputs *= tf.expand_dims(embedding_mask, -1)
       length = tf.shape(decoder_inputs)[1]
       pos_encoding = self.position_embedding(decoder_inputs)
       pos_encoding = tf.cast(pos_encoding, embedded_inputs.dtype)
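
The two deleted lines implemented the standard padding-mask pattern: token id 0 is treated as padding, and its embedding row is zeroed out so padded positions contribute nothing downstream. A minimal standalone sketch of that pattern (not part of this commit; the small Embedding layer and toy ids are illustrative):

import tensorflow as tf

targets = tf.constant([[5, 3, 0, 0]])  # [batch, length]; trailing 0s are padding
embedding = tf.keras.layers.Embedding(input_dim=10, output_dim=4)
decoder_inputs = embedding(targets)    # [batch, length, hidden]

# The deleted masking step: zero every row whose token id is 0.
embedding_mask = tf.cast(tf.not_equal(targets, 0), decoder_inputs.dtype)
decoder_inputs *= tf.expand_dims(embedding_mask, -1)
print(decoder_inputs[0, 2])            # all zeros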
@@ -325,11 +323,7 @@ class Seq2SeqTransformer(tf.keras.Model):
       decoder_input = ids[:, -1:]
       # Preprocess decoder input by getting embeddings and adding timing signal.
-      source_decoder_input = decoder_input
       decoder_input = self.embedding_lookup(decoder_input)
-      embedding_mask = tf.cast(
-          tf.not_equal(source_decoder_input, 0), decoder_input.dtype)
-      decoder_input *= tf.expand_dims(embedding_mask, -1)
       decoder_input += timing_signal[i]
       if self._padded_decode:
         # indexing does not work on TPU.
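
The same mask removal applies to the autoregressive decode step above, where only the newest token is embedded before that step's timing signal is added. A toy sketch of the step's shapes (illustrative names; timing_signal_i stands in for timing_signal[i]):

import tensorflow as tf

ids = tf.constant([[1, 7, 4]])            # tokens generated so far
decoder_input = ids[:, -1:]               # [batch, 1]: only the last token
embedding = tf.keras.layers.Embedding(10, 4)
decoder_input = embedding(decoder_input)  # [batch, 1, hidden]
timing_signal_i = tf.zeros([1, 4])        # stand-in for timing_signal[i]
decoder_input += timing_signal_i          # broadcasts to [batch, 1, hidden]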
...
@@ -76,8 +76,8 @@ class EmbeddingSharedWeights(tf.keras.layers.Layer):
     with tf.name_scope("embedding"):
       # Create binary mask of size [batch_size, length]
       embeddings = tf.gather(self.shared_weights, inputs)
-      mask = tf.cast(tf.not_equal(inputs, 0), embeddings.dtype)
-      embeddings *= tf.expand_dims(mask, -1)
+      # mask = tf.cast(tf.not_equal(inputs, 0), embeddings.dtype)
+      # embeddings *= tf.expand_dims(mask, -1)
       # Scale embedding by the sqrt of the hidden size
       embeddings *= self.hidden_size**0.5
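
With the mask commented out, the layer reduces to a gather from the shared weight matrix followed by the usual sqrt(hidden_size) scaling. A self-contained sketch of that remaining path (toy vocabulary and sizes, not the layer's real configuration):

import tensorflow as tf

hidden_size = 4
shared_weights = tf.random.normal([10, hidden_size])  # [vocab, hidden]
inputs = tf.constant([[5, 3, 0]])
embeddings = tf.gather(shared_weights, inputs)         # [batch, length, hidden]
embeddings *= hidden_size**0.5                         # scale by sqrt(hidden)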
...
@@ -196,13 +196,12 @@ class Transformer(tf.keras.Model):
     with tf.name_scope("decode"):
       # Prepare inputs to decoder layers by shifting targets, adding positional
       # encoding and applying dropout.
-      with tf.name_scope("shift_targets"):
-        # Shift targets to the right, and remove the last element
-        targets = tf.pad(targets, [[0, 0], [1, 0]])[:, :-1]
       decoder_inputs = self.embedding_softmax_layer(targets)
       decoder_inputs = tf.cast(decoder_inputs, self.params["dtype"])
       attention_bias = tf.cast(attention_bias, self.params["dtype"])
+      with tf.name_scope("shift_targets"):
+        # Shift targets to the right, and remove the last element
+        decoder_inputs = tf.pad(decoder_inputs,
+                                [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
       with tf.name_scope("add_pos_encoding"):
         length = tf.shape(decoder_inputs)[1]
         pos_encoding = self.position_embedding(decoder_inputs)
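
The reordering above is what makes the dropped masks redundant at the shifted-in position: padding the token ids first and then embedding leaks the (generally nonzero) embedding of id 0 into position 0, while embedding first and then padding the [batch, length, hidden] tensor inserts an exact zero vector. A sketch contrasting the two orders (toy table and ids, not the model's weights):

import tensorflow as tf

targets = tf.constant([[5, 3, 7]])
table = tf.random.normal([10, 4])  # toy embedding table

# Old order: shift token ids, then embed; table row 0 appears at position 0.
shifted_ids = tf.pad(targets, [[0, 0], [1, 0]])[:, :-1]
old_inputs = tf.gather(table, shifted_ids)

# New order: embed, then shift the embedded tensor; position 0 is exactly zero.
decoder_inputs = tf.gather(table, targets)
new_inputs = tf.pad(decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]

print(old_inputs[0, 0])  # == table[0], generally nonzero
print(new_inputs[0, 0])  # zeros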
...