Commit 19e3d234 authored by Frederick Liu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 411088753
parent d1ed379e
@@ -260,11 +260,11 @@ class Seq2SeqTransformer(tf.keras.Model):
       return {"outputs": top_decoded_ids, "scores": top_scores}

-    # Shift targets to the right, and remove the last element
-    targets = tf.pad(targets, [[0, 0], [1, 0]])[:, :-1]
     decoder_inputs = self.embedding_lookup(targets)
     embedding_mask = tf.cast(tf.not_equal(targets, 0), decoder_inputs.dtype)
     decoder_inputs *= tf.expand_dims(embedding_mask, -1)
+    # Shift targets to the right, and remove the last element
+    decoder_inputs = tf.pad(decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
     length = tf.shape(decoder_inputs)[1]
     pos_encoding = self.position_embedding(decoder_inputs)
     pos_encoding = tf.cast(pos_encoding, embedded_inputs.dtype)
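The hunk above moves the shift-right of the decoder inputs from ID space (pad the target IDs, then embed) to embedding space (embed and mask the raw targets, then pad the embeddings). Below is a minimal sketch, not part of the commit, contrasting the two orderings; the toy embedding table, vocabulary size, and example targets are assumptions for illustration, and only the pad/mask/lookup pattern is taken from the diff. Assuming padding uses id 0, as the mask computation suggests, both orderings yield a zero vector at position 0 and the embedding of target t-1 at position t, so the decoder still conditions on the previous token under teacher forcing.

```python
import tensorflow as tf

# Illustrative assumptions: a tiny random embedding table and one padded sequence.
vocab_size, hidden_size = 8, 4
table = tf.random.stateless_normal([vocab_size, hidden_size], seed=[0, 1])
targets = tf.constant([[5, 3, 7, 0]])  # batch of 1, padded with id 0


def shift_ids_then_embed(targets):
  # Old ordering: shift the IDs right and drop the last one, then embed and mask.
  shifted = tf.pad(targets, [[0, 0], [1, 0]])[:, :-1]
  inputs = tf.nn.embedding_lookup(table, shifted)
  mask = tf.cast(tf.not_equal(shifted, 0), inputs.dtype)
  return inputs * tf.expand_dims(mask, -1)


def embed_then_shift(targets):
  # New ordering: embed and mask the raw targets, then shift the embeddings right.
  inputs = tf.nn.embedding_lookup(table, targets)
  mask = tf.cast(tf.not_equal(targets, 0), inputs.dtype)
  inputs *= tf.expand_dims(mask, -1)
  return tf.pad(inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]


# Maximum elementwise difference is 0.0 for 0-padded targets.
diff = tf.reduce_max(tf.abs(shift_ids_then_embed(targets) - embed_then_shift(targets)))
print(diff.numpy())
```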
@@ -325,7 +325,6 @@ class Seq2SeqTransformer(tf.keras.Model):
       decoder_input = ids[:, -1:]

       # Preprocess decoder input by getting embeddings and adding timing signal.
       # decoder_input = self.embedding_softmax_layer(decoder_input)
-      source_decoder_input = decoder_input
       decoder_input = self.embedding_lookup(decoder_input)
       embedding_mask = tf.cast(