Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
6c03d4ac
Unverified
Commit
6c03d4ac
authored
Jan 04, 2021
by
Julien Plu
Committed by
GitHub
Jan 04, 2021
Browse files
Fix CTRL (#9291)
parent
c581d8af
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
src/transformers/models/ctrl/modeling_tf_ctrl.py
src/transformers/models/ctrl/modeling_tf_ctrl.py
+3
-3
No files found.
src/transformers/models/ctrl/modeling_tf_ctrl.py
View file @
6c03d4ac
...
...
@@ -375,7 +375,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
all_hidden_states
=
()
if
inputs
[
"output_hidden_states"
]
else
None
all_attentions
=
()
if
inputs
[
"output_attentions"
]
else
None
for
i
,
(
h
,
layer_past
)
in
enumerate
(
zip
(
self
.
h
,
inputs
[
"past"
])):
if
output_hidden_states
:
if
inputs
[
"
output_hidden_states
"
]
:
all_hidden_states
=
all_hidden_states
+
(
tf
.
reshape
(
hidden_states
,
output_shape
),)
outputs
=
h
(
hidden_states
,
...
...
@@ -384,7 +384,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
inputs
[
"attention_mask"
],
inputs
[
"head_mask"
][
i
],
inputs
[
"use_cache"
],
output_attentions
,
inputs
[
"
output_attentions
"
]
,
training
=
inputs
[
"training"
],
)
hidden_states
,
present
=
outputs
[:
2
]
...
...
@@ -392,7 +392,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
if
inputs
[
"use_cache"
]:
presents
=
presents
+
(
present
,)
if
output_attentions
:
if
inputs
[
"
output_attentions
"
]
:
all_attentions
=
all_attentions
+
(
outputs
[
2
],)
hidden_states
=
self
.
layernorm
(
hidden_states
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment