Commit 0b395f65 authored by xinliupitt

remove training arg

parent ef800b03
@@ -98,7 +98,7 @@ class Seq2SeqTransformer(tf.keras.Model):
         "params": self.params,
     }
-  def call(self, inputs, training):
+  def call(self, inputs):
     """Calculate target logits or inferred target sequences.
     Args:
@@ -162,10 +162,6 @@ class Seq2SeqTransformer(tf.keras.Model):
       pos_encoding = tf.cast(pos_encoding, self.params["dtype"])
       encoder_inputs = embedded_inputs + pos_encoding
-      # if training:
-      #   encoder_inputs = tf.nn.dropout(
-      #       encoder_inputs, rate=self.params["layer_postprocess_dropout"])
       encoder_inputs = self.encoder_dropout(encoder_inputs)
       encoder_outputs = self.encoder_layer(encoder_inputs,
@@ -185,7 +181,7 @@ class Seq2SeqTransformer(tf.keras.Model):
                                self.params["dtype"])
       symbols_to_logits_fn = self._get_symbols_to_logits_fn(
-          max_decode_length, training)
+          max_decode_length)
       # Create initial set of IDs that will be passed to symbols_to_logits_fn.
       initial_ids = tf.zeros([batch_size], dtype=tf.int32)
@@ -254,10 +250,6 @@ class Seq2SeqTransformer(tf.keras.Model):
       pos_encoding = tf.cast(pos_encoding, self.params["dtype"])
       decoder_inputs += pos_encoding
-      # if training:
-      #   decoder_inputs = tf.nn.dropout(
-      #       decoder_inputs, rate=self.params["layer_postprocess_dropout"])
       decoder_inputs = self.decoder_dropout(decoder_inputs)
       decoder_shape = tf_utils.get_shape_list(decoder_inputs,
@@ -287,7 +279,7 @@ class Seq2SeqTransformer(tf.keras.Model):
       return logits
-  def _get_symbols_to_logits_fn(self, max_decode_length, training):
+  def _get_symbols_to_logits_fn(self, max_decode_length):
     """Returns a decoding function that calculates logits of the next tokens."""
     timing_signal = self.position_embedding(
         inputs=None, length=max_decode_length + 1)
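For context, here is a minimal sketch of the pattern this commit moves to. It is not the repository's Seq2SeqTransformer (the class, layer names, and sizes below are made up for illustration): when dropout is applied through a tf.keras.layers.Dropout sub-layer, Keras propagates the training phase from the outer model call, so the explicit `training` argument and the manual `if training: tf.nn.dropout(...)` guard are no longer needed in `call`.

import tensorflow as tf


class TinyEncoder(tf.keras.Model):
  """Hypothetical stand-in for the model, showing the dropout-layer pattern."""

  def __init__(self, hidden_size=16, dropout_rate=0.1):
    super().__init__()
    self.dense = tf.keras.layers.Dense(hidden_size)
    # Replaces the removed `if training: tf.nn.dropout(...)` pattern.
    self.encoder_dropout = tf.keras.layers.Dropout(dropout_rate)

  def call(self, inputs):
    x = self.dense(inputs)
    # No `training` parameter in the signature: Keras forwards the flag
    # from model(x, training=...) / fit() to the Dropout sub-layer.
    return self.encoder_dropout(x)


model = TinyEncoder()
x = tf.ones([2, 8])
out_train = model(x, training=True)   # dropout active
out_infer = model(x, training=False)  # dropout disabled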