chenpangpang / transformers · Commits · 1d690289

Unverified commit 1d690289, authored May 26, 2020 by ZhuBaohe, committed by GitHub on May 26, 2020
fix (#4410)
parent b86e42e0
Changes: 1
Showing 1 changed file with 11 additions and 13 deletions

src/transformers/modeling_tf_xlnet.py  +11 −13
src/transformers/modeling_tf_xlnet.py  (view file @ 1d690289)
...
@@ -586,8 +586,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
 if data_mask is not None:
     # all mems can be attended to
-    mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz], dtype=dtype_float)
-    data_mask = tf.concat([mems_mask, data_mask], axis=1)
+    if mlen > 0:
+        mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz], dtype=dtype_float)
+        data_mask = tf.concat([mems_mask, data_mask], axis=1)
     if attn_mask is None:
         attn_mask = data_mask[:, :, :, None]
     else:
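For context (not part of the commit), a minimal sketch of what the new `if mlen > 0:` guard in this hunk changes. The dimensions below are toy values assumed for illustration, and `qlen` stands in for `shape_list(data_mask)[0]`: when no memory is cached, `mlen` is 0 and the concatenation with an all-zeros `mems_mask` is skipped instead of prepending a zero-width tensor.

    import tensorflow as tf

    # Toy dimensions, assumed for illustration only.
    qlen, bsz = 3, 1
    dtype_float = tf.float32
    data_mask = tf.ones([qlen, qlen, bsz], dtype=dtype_float)

    def pad_with_mems(data_mask, mlen):
        # Mirrors the guarded branch above: only prepend the all-zeros
        # memory mask when there actually are memory slots to attend to.
        if mlen > 0:
            mems_mask = tf.zeros([qlen, mlen, bsz], dtype=dtype_float)
            data_mask = tf.concat([mems_mask, data_mask], axis=1)
        return data_mask

    print(pad_with_mems(data_mask, mlen=0).shape)  # (3, 3, 1): unchanged
    print(pad_with_mems(data_mask, mlen=2).shape)  # (3, 5, 1): memory columns prepended

The same guard is applied to `non_tgt_mask` in the next hunk; for `mlen == 0` the old code concatenated a zero-width tensor, so skipping the concat should leave the resulting mask unchanged.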
...
@@ -598,7 +599,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
 if attn_mask is not None:
     non_tgt_mask = -tf.eye(qlen, dtype=dtype_float)
-    non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=dtype_float), non_tgt_mask], axis=-1)
+    if mlen > 0:
+        non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=dtype_float), non_tgt_mask], axis=-1)
     non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=dtype_float)
 else:
     non_tgt_mask = None
...
@@ -621,8 +623,11 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
 # Segment embedding
 if token_type_ids is not None:
     # Convert `token_type_ids` to one-hot `seg_mat`
-    mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
-    cat_ids = tf.concat([mem_pad, token_type_ids], 0)
+    if mlen > 0:
+        mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
+        cat_ids = tf.concat([mem_pad, token_type_ids], 0)
+    else:
+        cat_ids = token_type_ids
     # `1` indicates not in the same segment [qlen x klen x bsz]
     seg_mat = tf.cast(tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])), tf.int32)
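Similarly (again not part of the commit, with toy values assumed), a small sketch of why the new `else: cat_ids = token_type_ids` branch matches the old behavior when `mlen` is 0: prepending a zero-row `mem_pad` leaves `token_type_ids` unchanged, so the guard only skips the unnecessary concat.

    import tensorflow as tf

    # Toy segment ids, shaped [qlen, bsz] as in the layer above.
    token_type_ids = tf.constant([[0], [0], [1], [1]], dtype=tf.int32)
    mlen, bsz = 0, 1

    # Old path: concatenating a 0-row mem_pad is a no-op.
    mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
    cat_ids_old = tf.concat([mem_pad, token_type_ids], 0)

    # New path: skip the concat entirely when there is no memory
    # (written as a conditional expression; the diff uses if/else statements).
    cat_ids_new = token_type_ids if mlen == 0 else tf.concat([mem_pad, token_type_ids], 0)

    # `1` marks token pairs from different segments, [qlen x klen x bsz].
    seg_mat = tf.cast(tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids_new[None, :])), tf.int32)

    print(bool(tf.reduce_all(tf.equal(cat_ids_old, cat_ids_new))))  # True: same ids either way
    print(seg_mat.shape)  # (4, 4, 1)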
...
@@ -640,14 +645,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
 # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
 # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
 if head_mask is not None:
-    if head_mask.dim() == 1:
-        head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
-        head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
-    elif head_mask.dim() == 2:
-        head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
-    head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to fload if need + fp16 compatibility
     raise NotImplementedError
 else:
     head_mask = [None] * self.n_layer
...