Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
45de034b
"...composable_kernel_onnx.git" did not exist on "d09ea4f4e5aca0aec89badff827639d998ee1f0b"
Commit
45de034b
authored
Sep 17, 2019
by
thomwolf
Browse files
fix #1223
parent
c88f0516
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
5 deletions
+10
-5
pytorch_transformers/modeling_xlnet.py
pytorch_transformers/modeling_xlnet.py
+10
-5
No files found.
pytorch_transformers/modeling_xlnet.py
View file @
45de034b
...
@@ -743,8 +743,9 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -743,8 +743,9 @@ class XLNetModel(XLNetPreTrainedModel):
if
data_mask
is
not
None
:
if
data_mask
is
not
None
:
# all mems can be attended to
# all mems can be attended to
mems_mask
=
torch
.
zeros
([
data_mask
.
shape
[
0
],
mlen
,
bsz
]).
to
(
data_mask
)
if
mlen
>
0
:
data_mask
=
torch
.
cat
([
mems_mask
,
data_mask
],
dim
=
1
)
mems_mask
=
torch
.
zeros
([
data_mask
.
shape
[
0
],
mlen
,
bsz
]).
to
(
data_mask
)
data_mask
=
torch
.
cat
([
mems_mask
,
data_mask
],
dim
=
1
)
if
attn_mask
is
None
:
if
attn_mask
is
None
:
attn_mask
=
data_mask
[:,
:,
:,
None
]
attn_mask
=
data_mask
[:,
:,
:,
None
]
else
:
else
:
...
@@ -755,7 +756,8 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -755,7 +756,8 @@ class XLNetModel(XLNetPreTrainedModel):
if
attn_mask
is
not
None
:
if
attn_mask
is
not
None
:
non_tgt_mask
=
-
torch
.
eye
(
qlen
).
to
(
attn_mask
)
non_tgt_mask
=
-
torch
.
eye
(
qlen
).
to
(
attn_mask
)
non_tgt_mask
=
torch
.
cat
([
torch
.
zeros
([
qlen
,
mlen
]).
to
(
attn_mask
),
non_tgt_mask
],
dim
=-
1
)
if
mlen
>
0
:
non_tgt_mask
=
torch
.
cat
([
torch
.
zeros
([
qlen
,
mlen
]).
to
(
attn_mask
),
non_tgt_mask
],
dim
=-
1
)
non_tgt_mask
=
((
attn_mask
+
non_tgt_mask
[:,
:,
None
,
None
])
>
0
).
to
(
attn_mask
)
non_tgt_mask
=
((
attn_mask
+
non_tgt_mask
[:,
:,
None
,
None
])
>
0
).
to
(
attn_mask
)
else
:
else
:
non_tgt_mask
=
None
non_tgt_mask
=
None
...
@@ -775,8 +777,11 @@ class XLNetModel(XLNetPreTrainedModel):
...
@@ -775,8 +777,11 @@ class XLNetModel(XLNetPreTrainedModel):
##### Segment embedding
##### Segment embedding
if
token_type_ids
is
not
None
:
if
token_type_ids
is
not
None
:
# Convert `token_type_ids` to one-hot `seg_mat`
# Convert `token_type_ids` to one-hot `seg_mat`
mem_pad
=
torch
.
zeros
([
mlen
,
bsz
],
dtype
=
torch
.
long
,
device
=
device
)
if
mlen
>
0
:
cat_ids
=
torch
.
cat
([
mem_pad
,
token_type_ids
],
dim
=
0
)
mem_pad
=
torch
.
zeros
([
mlen
,
bsz
],
dtype
=
torch
.
long
,
device
=
device
)
cat_ids
=
torch
.
cat
([
mem_pad
,
token_type_ids
],
dim
=
0
)
else
:
cat_ids
=
token_type_ids
# `1` indicates not in the same segment [qlen x klen x bsz]
# `1` indicates not in the same segment [qlen x klen x bsz]
seg_mat
=
(
token_type_ids
[:,
None
]
!=
cat_ids
[
None
,
:]).
long
()
seg_mat
=
(
token_type_ids
[:,
None
]
!=
cat_ids
[
None
,
:]).
long
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment