Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
transformer-xl_pytorch
Commits
dc8094a6
Commit
dc8094a6
authored
Aug 20, 2024
by
yongshk
Browse files
Initial commit
parent
01153e93
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
2 deletions
+10
-2
pytorch/mem_transformer.py
pytorch/mem_transformer.py
+10
-2
No files found.
pytorch/mem_transformer.py
View file @
dc8094a6
...
@@ -667,11 +667,19 @@ class MemTransformerLM(nn.Module):
...
@@ -667,11 +667,19 @@ class MemTransformerLM(nn.Module):
mask_shift_len
=
qlen
-
mask_len
mask_shift_len
=
qlen
-
mask_len
else
:
else
:
mask_shift_len
=
qlen
mask_shift_len
=
qlen
# 在 PyTorch 2.x 中,Byte 类型被废弃,不再作为 mask 的类型。取而代之的是 boolean 类型,这样可以更好地支持布尔运算。
# dec_attn_mask = (torch.triu(all_ones, 1+mlen)
# + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1
dec_attn_mask
=
(
torch
.
triu
(
all_ones
,
1
+
mlen
)
dec_attn_mask
=
(
torch
.
triu
(
all_ones
,
1
+
mlen
)
+
torch
.
tril
(
all_ones
,
-
mask_shift_len
)).
b
yte
()[:,
:,
None
]
# -1
+
torch
.
tril
(
all_ones
,
-
mask_shift_len
)).
b
ool
()[:,
:,
None
]
# -1
else
:
else
:
# 在 PyTorch 2.x 中,Byte 类型被废弃,不再作为 mask 的类型。取而代之的是 boolean 类型,这样可以更好地支持布尔运算。
# dec_attn_mask = torch.triu(
# word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]
dec_attn_mask
=
torch
.
triu
(
dec_attn_mask
=
torch
.
triu
(
word_emb
.
new_ones
(
qlen
,
klen
),
diagonal
=
1
+
mlen
).
byte
()[:,:,
None
]
word_emb
.
new_ones
(
qlen
,
klen
),
diagonal
=
1
+
mlen
).
bool
()[:,:,
None
]
hids
=
[]
hids
=
[]
if
self
.
attn_type
==
0
:
# default
if
self
.
attn_type
==
0
:
# default
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment