Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
0b395f65
Commit
0b395f65
authored
Aug 08, 2020
by
xinliupitt
Browse files
remove training arg
parent
ef800b03
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
11 deletions
+3
-11
official/nlp/modeling/models/seq2seq_transformer.py
official/nlp/modeling/models/seq2seq_transformer.py
+3
-11
No files found.
official/nlp/modeling/models/seq2seq_transformer.py
View file @
0b395f65
...
...
@@ -98,7 +98,7 @@ class Seq2SeqTransformer(tf.keras.Model):
"params"
:
self
.
params
,
}
def
call
(
self
,
inputs
,
training
):
def
call
(
self
,
inputs
):
"""Calculate target logits or inferred target sequences.
Args:
...
...
@@ -162,10 +162,6 @@ class Seq2SeqTransformer(tf.keras.Model):
pos_encoding
=
tf
.
cast
(
pos_encoding
,
self
.
params
[
"dtype"
])
encoder_inputs
=
embedded_inputs
+
pos_encoding
# if training:
# encoder_inputs = tf.nn.dropout(
# encoder_inputs, rate=self.params["layer_postprocess_dropout"])
encoder_inputs
=
self
.
encoder_dropout
(
encoder_inputs
)
encoder_outputs
=
self
.
encoder_layer
(
encoder_inputs
,
...
...
@@ -185,7 +181,7 @@ class Seq2SeqTransformer(tf.keras.Model):
self
.
params
[
"dtype"
])
symbols_to_logits_fn
=
self
.
_get_symbols_to_logits_fn
(
max_decode_length
,
training
)
max_decode_length
)
# Create initial set of IDs that will be passed to symbols_to_logits_fn.
initial_ids
=
tf
.
zeros
([
batch_size
],
dtype
=
tf
.
int32
)
...
...
@@ -254,10 +250,6 @@ class Seq2SeqTransformer(tf.keras.Model):
pos_encoding
=
tf
.
cast
(
pos_encoding
,
self
.
params
[
"dtype"
])
decoder_inputs
+=
pos_encoding
# if training:
# decoder_inputs = tf.nn.dropout(
# decoder_inputs, rate=self.params["layer_postprocess_dropout"])
decoder_inputs
=
self
.
decoder_dropout
(
decoder_inputs
)
decoder_shape
=
tf_utils
.
get_shape_list
(
decoder_inputs
,
...
...
@@ -287,7 +279,7 @@ class Seq2SeqTransformer(tf.keras.Model):
return
logits
def
_get_symbols_to_logits_fn
(
self
,
max_decode_length
,
training
):
def
_get_symbols_to_logits_fn
(
self
,
max_decode_length
):
"""Returns a decoding function that calculates logits of the next tokens."""
timing_signal
=
self
.
position_embedding
(
inputs
=
None
,
length
=
max_decode_length
+
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment