Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
842f3bf0
Commit 842f3bf0, authored Oct 30, 2019 by Timothy Liu
Browse files
Fixed training for TF XLM
Parent: 079bfb32
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing 1 changed file with 8 additions and 4 deletions
+8
-4
transformers/modeling_tf_xlm.py
transformers/modeling_tf_xlm.py
+8
-4
No files found.
transformers/modeling_tf_xlm.py
View file @ 842f3bf0
...
...
@@ -84,7 +84,8 @@ def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32):
attn_mask
=
mask
# sanity check
assert
shape_list
(
mask
)
==
[
bs
,
slen
]
# assert shape_list(mask) == [bs, slen]
tf
.
debugging
.
assert_equal
(
shape_list
(
mask
),
[
bs
,
slen
])
assert
causal
is
False
or
shape_list
(
attn_mask
)
==
[
bs
,
slen
,
slen
]
mask
=
tf
.
cast
(
mask
,
dtype
=
dtype
)
...
...
@@ -318,7 +319,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# check inputs
bs
,
slen
=
shape_list
(
input_ids
)
assert
shape_list
(
lengths
)[
0
]
==
bs
# assert shape_list(lengths)[0] == bs
tf
.
debugging
.
assert_equal
(
shape_list
(
lengths
)[
0
],
bs
)
# assert lengths.max().item() <= slen
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
# assert (src_enc is None) == (src_len is None)
...
...
@@ -335,12 +337,14 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
if
position_ids
is
None
:
position_ids
=
tf
.
expand_dims
(
tf
.
range
(
slen
),
axis
=
0
)
else
:
assert
shape_list
(
position_ids
)
==
[
bs
,
slen
]
# (slen, bs)
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
tf
.
debugging
.
assert_equal
(
shape_list
(
position_ids
),
[
bs
,
slen
])
# position_ids = position_ids.transpose(0, 1)
# langs
if
langs
is
not
None
:
assert
shape_list
(
langs
)
==
[
bs
,
slen
]
# (slen, bs)
# assert shape_list(langs) == [bs, slen] # (slen, bs)
tf
.
debugging
.
assert_equal
(
shape_list
(
langs
),
[
bs
,
slen
])
# langs = langs.transpose(0, 1)
# Prepare head mask if needed
...
...
Write
Preview
Markdown
is supported
0%
Try again or attach a new file.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment.