Commit 22b84f74 authored by Hongkun Yu, committed by A. Unique TensorFlower
Browse files

Transformer Nits: reshape -> set_shape

PiperOrigin-RevId: 281872406
parent 48693cad
......@@ -121,6 +121,7 @@ class Transformer(tf.keras.Model):
if len(inputs) == 2:
inputs, targets = inputs[0], inputs[1]
else:
# Decoding path.
inputs, targets = inputs[0], None
if self.params["padded_decode"]:
if not self.params["num_replicas"]:
......@@ -128,8 +129,9 @@ class Transformer(tf.keras.Model):
"Padded decoding on CPU/GPUs is not supported.")
decode_batch_size = int(self.params["decode_batch_size"] /
self.params["num_replicas"])
-inputs = tf.reshape(
-inputs, [decode_batch_size, self.params["decode_max_length"]])
+inputs.set_shape([
+decode_batch_size, self.params["decode_max_length"]
+])
# Variance scaling is used here because it seems to work in many problems.
# Other reasonable initializers may also work just as well.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment