Commit a5857963 authored by xinliupitt

remove timing_signal cast

parent 8028eee4
@@ -394,7 +394,6 @@ class Seq2SeqTransformer(tf.keras.Model):
     """Returns a decoding function that calculates logits of the next tokens."""
     timing_signal = self.position_embedding(
         inputs=None, length=max_decode_length + 1)
-    timing_signal = tf.cast(timing_signal, self._dtype)
     decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
         max_decode_length, dtype=self._dtype)
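The dropped line explicitly cast the position-embedding output to self._dtype. As a hedged illustration of why such a cast can be redundant (this rationale is an assumption, not stated in the commit), a Keras layer already returns its outputs in its compute dtype, for example under a mixed-precision policy:

import tensorflow as tf

# Illustrative sketch, not taken from the commit: with a mixed-precision
# policy active, a Keras layer already returns outputs in its compute dtype,
# so an extra tf.cast of those outputs to the same dtype adds nothing.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

dense = tf.keras.layers.Dense(8)
outputs = dense(tf.ones((2, 4)))
print(outputs.dtype)  # float16 -- already the compute dtype, no manual cast needed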
@@ -541,7 +540,7 @@ class TransformerEncoder(tf.keras.layers.Layer):
     super(TransformerEncoder, self).build(input_shape)
 
   def get_config(self):
-    return {
+    config = {
         "num_layers":
             self._num_layers,
         "num_attention_heads":
@@ -563,6 +562,8 @@ class TransformerEncoder(tf.keras.layers.Layer):
         "intermediate_dropout":
             self._intermediate_dropout
     }
+    base_config = super(TransformerEncoder, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
 
   def call(self,
            encoder_inputs,
@@ -657,7 +658,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
     super(TransformerDecoder, self).build(input_shape)
 
   def get_config(self):
-    return {
+    config = {
         "num_layers":
             self._num_layers,
         "num_attention_heads":
@@ -679,6 +680,8 @@ class TransformerDecoder(tf.keras.layers.Layer):
         "intermediate_dropout":
             self._intermediate_dropout
     }
+    base_config = super(TransformerDecoder, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
 
   def call(self,
            target,
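Both get_config edits switch from returning only the layer-specific options to merging them with the base tf.keras.layers.Layer config, which preserves standard keys such as name and dtype when the layer is serialized. A minimal, self-contained sketch of the same pattern follows; the class and constructor argument are illustrative placeholders, not code from this file:

import tensorflow as tf

class DemoLayer(tf.keras.layers.Layer):
  """Toy layer showing the get_config pattern adopted in this commit."""

  def __init__(self, num_layers=2, **kwargs):
    super(DemoLayer, self).__init__(**kwargs)
    self._num_layers = num_layers

  def get_config(self):
    # Layer-specific options first...
    config = {
        "num_layers": self._num_layers,
    }
    # ...then merge in the base Layer config (name, dtype, trainable, ...).
    base_config = super(DemoLayer, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

layer = DemoLayer(num_layers=4, name="demo")
clone = DemoLayer.from_config(layer.get_config())
print(clone.get_config())  # includes both "num_layers" and base keys like "name"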