Commit a5857963 authored by xinliupitt

remove timing_signal cast

parent 8028eee4
@@ -394,7 +394,6 @@ class Seq2SeqTransformer(tf.keras.Model):
     """Returns a decoding function that calculates logits of the next tokens."""
     timing_signal = self.position_embedding(
         inputs=None, length=max_decode_length + 1)
-    timing_signal = tf.cast(timing_signal, self._dtype)
     decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
         max_decode_length, dtype=self._dtype)
@@ -541,7 +540,7 @@ class TransformerEncoder(tf.keras.layers.Layer):
     super(TransformerEncoder, self).build(input_shape)

   def get_config(self):
-    return {
+    config = {
         "num_layers":
             self._num_layers,
         "num_attention_heads":
@@ -563,6 +562,8 @@ class TransformerEncoder(tf.keras.layers.Layer):
         "intermediate_dropout":
             self._intermediate_dropout
     }
+    base_config = super(TransformerEncoder, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))

   def call(self,
            encoder_inputs,
@@ -657,7 +658,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
     super(TransformerDecoder, self).build(input_shape)

   def get_config(self):
-    return {
+    config = {
         "num_layers":
             self._num_layers,
         "num_attention_heads":
@@ -679,6 +680,8 @@ class TransformerDecoder(tf.keras.layers.Layer):
         "intermediate_dropout":
             self._intermediate_dropout
     }
+    base_config = super(TransformerDecoder, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))

   def call(self,
            target,
...
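The `get_config` change in both hunks follows the standard Keras serialization pattern: the layer's own constructor arguments are merged with the base `Layer` config (name, dtype, trainable, ...) so the layer can round-trip through `get_config`/`from_config`. Below is a minimal sketch of the same pattern on a hypothetical layer; `ScaledDense`, `units`, and `scale` are illustrative names, not part of this commit.

import tensorflow as tf

class ScaledDense(tf.keras.layers.Layer):
  """Hypothetical layer used only to illustrate the get_config pattern."""

  def __init__(self, units, scale=1.0, **kwargs):
    super(ScaledDense, self).__init__(**kwargs)
    self._units = units
    self._scale = scale
    self._dense = tf.keras.layers.Dense(units)

  def call(self, inputs):
    return self._dense(inputs) * self._scale

  def get_config(self):
    # Layer-specific arguments, analogous to the dicts built in the diff above.
    config = {
        "units": self._units,
        "scale": self._scale,
    }
    # Merge with the base Layer config so name, dtype, trainable, etc. are
    # preserved when the layer is serialized and deserialized.
    base_config = super(ScaledDense, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

With the merged dict, `ScaledDense.from_config(layer.get_config())` rebuilds an equivalent layer, which is what the added `base_config` lines provide for `TransformerEncoder` and `TransformerDecoder`.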