chenpangpang / transformers

Commit a2c8e516, authored Mar 09, 2020 by Patrick von Platen

fix torch to tf translation

parent ca2047bc
Showing 1 changed file with 4 additions and 1 deletion.
src/transformers/modeling_tf_utils.py (+4, -1)
```diff
@@ -641,7 +641,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         # create attention mask if necessary
         # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
-        if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
+        if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
+            import ipdb
+            ipdb.set_trace()
+        if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids.numpy()):
             attention_mask = tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32)
         elif attention_mask is None:
             attention_mask = tf.ones_like(input_ids)
```
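What the fix addresses: in PyTorch, `pad_token_id in input_ids` performs an element-wise membership test on the tensor, but Python's `in` applied to a multi-dimensional eager `tf.Tensor` does not behave the same way, so the hunk converts the tensor with `.numpy()` first (NumPy arrays do support element-wise `in`). The `import ipdb` / `ipdb.set_trace()` lines appear to be a leftover debugging breakpoint committed alongside the fix. Below is a minimal sketch of the two pieces the hunk touches; the toy `input_ids` batch and the `pad_token_id` value of 0 are invented for illustration, and TF 2.x eager execution is assumed:

```python
import tensorflow as tf  # assumes TF 2.x eager execution

# Hypothetical toy batch; pad_token_id = 0 is an assumption for illustration.
input_ids = tf.constant([[5, 7, 0], [5, 7, 7]])
pad_token_id = 0

# `pad_token_id in input_ids` on the raw tf.Tensor does not reliably perform
# the element-wise membership test that the equivalent torch expression does;
# converting to NumPy first gives the intended behavior.
print(pad_token_id in input_ids.numpy())  # True

# The mask construction from the hunk: 1 where the token is not padding, 0 where it is.
attention_mask = tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32)
print(attention_mask.numpy())  # [[1 1 0]
                               #  [1 1 1]]
```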