Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
a5857963
Commit
a5857963
authored
Aug 13, 2020
by
xinliupitt
Browse files
remove timing_signal cast
parent
8028eee4
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
3 deletions
+6
-3
official/nlp/modeling/models/seq2seq_transformer.py
official/nlp/modeling/models/seq2seq_transformer.py
+6
-3
No files found.
official/nlp/modeling/models/seq2seq_transformer.py
View file @
a5857963
...
...
@@ -394,7 +394,6 @@ class Seq2SeqTransformer(tf.keras.Model):
"""Returns a decoding function that calculates logits of the next tokens."""
timing_signal
=
self
.
position_embedding
(
inputs
=
None
,
length
=
max_decode_length
+
1
)
timing_signal
=
tf
.
cast
(
timing_signal
,
self
.
_dtype
)
decoder_self_attention_bias
=
model_utils
.
get_decoder_self_attention_bias
(
max_decode_length
,
dtype
=
self
.
_dtype
)
...
...
@@ -541,7 +540,7 @@ class TransformerEncoder(tf.keras.layers.Layer):
super
(
TransformerEncoder
,
self
).
build
(
input_shape
)
def
get_config
(
self
):
return
{
config
=
{
"num_layers"
:
self
.
_num_layers
,
"num_attention_heads"
:
...
...
@@ -563,6 +562,8 @@ class TransformerEncoder(tf.keras.layers.Layer):
"intermediate_dropout"
:
self
.
_intermediate_dropout
}
base_config
=
super
(
TransformerEncoder
,
self
).
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
def
call
(
self
,
encoder_inputs
,
...
...
@@ -657,7 +658,7 @@ class TransformerDecoder(tf.keras.layers.Layer):
super
(
TransformerDecoder
,
self
).
build
(
input_shape
)
def
get_config
(
self
):
return
{
config
=
{
"num_layers"
:
self
.
_num_layers
,
"num_attention_heads"
:
...
...
@@ -679,6 +680,8 @@ class TransformerDecoder(tf.keras.layers.Layer):
"intermediate_dropout"
:
self
.
_intermediate_dropout
}
base_config
=
super
(
TransformerDecoder
,
self
).
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
def
call
(
self
,
target
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment