OpenDAS / Uni-Core / Commits

Commit 49c9895b authored Aug 11, 2022 by Guolin Ke

fix bug in return future mask

parent 689e0b24
Showing 2 changed files with 23 additions and 15 deletions

unicore/modules/transformer_decoder.py  +22 -14
unicore/modules/transformer_encoder.py  +1 -1
unicore/modules/transformer_decoder.py
@@ -11,14 +11,17 @@ import torch.nn.functional as F
from . import TransformerDecoderLayer, LayerNorm
from .transformer_encoder import relative_position_bucket


def fill_with_neg_inf(t):
    return t.fill_(float("-inf"))


def bulid_future_mask(seq_len):
    return torch.triu(
        fill_with_neg_inf(torch.zeros([seq_len, seq_len])), 1
    )


class TransformerDecoder(nn.Module):
    def __init__(
        self,
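Not part of the commit, but as a quick illustration of the two helpers above: a minimal, self-contained sketch of the upper-triangular causal mask that bulid_future_mask produces (seq_len = 4 is an arbitrary choice). Positions strictly above the diagonal are set to -inf, so adding this mask to attention logits blocks attention to future tokens.

import torch

def fill_with_neg_inf(t):
    return t.fill_(float("-inf"))

def bulid_future_mask(seq_len):
    return torch.triu(
        fill_with_neg_inf(torch.zeros([seq_len, seq_len])), 1
    )

print(bulid_future_mask(4))
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])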
@@ -66,7 +69,7 @@ class TransformerDecoder(nn.Module):
                    activation_dropout=activation_dropout,
                    activation_fn=activation_fn,
                    post_ln=post_ln,
                )
                for _ in range(decoder_layers)
            ]
@@ -77,7 +80,8 @@ class TransformerDecoder(nn.Module):
            assert rel_pos_bins % 2 == 0
            self.rel_pos_bins = rel_pos_bins
            self.max_rel_pos = max_rel_pos
            self.relative_attention_bias = nn.Embedding(self.rel_pos_bins, self.attention_heads)
            seq_len = self.max_seq_len
            context_position = torch.arange(seq_len, dtype=torch.long)[:, None]
            memory_position = torch.arange(seq_len, dtype=torch.long)[None, :]
@@ -98,7 +102,7 @@ class TransformerDecoder(nn.Module):
        values = F.embedding(rp_bucket, self.relative_attention_bias.weight)
        values = values.permute([2, 0, 1])
        return values.contiguous()

    def get_future_mask(self, x, attn_mask):
        if not self.auto_regressive:
            return attn_mask
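As an aside on the two hunks above: a minimal sketch of how the relative-position bias ends up shaped [attention_heads, seq_len, seq_len] before it is repeated per batch element. The toy_bucket function and the sizes 6/32/8 are stand-ins; the real code uses relative_position_bucket imported from .transformer_encoder, whose implementation is not shown in this diff.

import torch
import torch.nn as nn
import torch.nn.functional as F

seq_len, rel_pos_bins, attention_heads = 6, 32, 8
relative_attention_bias = nn.Embedding(rel_pos_bins, attention_heads)

context_position = torch.arange(seq_len, dtype=torch.long)[:, None]  # [seq_len, 1]
memory_position = torch.arange(seq_len, dtype=torch.long)[None, :]   # [1, seq_len]
relative_position = memory_position - context_position               # [seq_len, seq_len]

def toy_bucket(relative_position, num_buckets):
    # stand-in for relative_position_bucket: clamp signed offsets into [0, num_buckets)
    return torch.clamp(relative_position + num_buckets // 2, 0, num_buckets - 1)

rp_bucket = toy_bucket(relative_position, rel_pos_bins)              # [seq_len, seq_len]
values = F.embedding(rp_bucket, relative_attention_bias.weight)      # [seq_len, seq_len, heads]
values = values.permute([2, 0, 1]).contiguous()                      # [heads, seq_len, seq_len]
print(values.shape)  # torch.Size([8, 6, 6])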
@@ -108,9 +112,12 @@ class TransformerDecoder(nn.Module):
            self._future_mask = self._future_mask.type_as(x)
        if attn_mask is None:
            ret = self._future_mask[:x.size(1), :x.size(1)]
            ret = ret.contiguous().unsqueeze(0).repeat(x.size(0) * self.attention_heads, 1, 1)
+            return ret
        else:
            assert list(attn_mask.size()) == [x.size(0) * self.attention_heads, x.size(1), x.size(1)]
            return attn_mask + self._future_mask[:x.size(1), :x.size(1)]

    def forward(
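This hunk contains the fix named in the commit message: before the added `return ret`, the `if attn_mask is None` branch of get_future_mask built the repeated causal mask but fell through and implicitly returned None, so no future mask reached the attention layers when the caller passed no mask. A minimal sketch of the fixed branch follows; the standalone function, tensor sizes, and names below are illustrative only, not the library API.

import torch

def future_mask_when_no_attn_mask(x, _future_mask, attention_heads):
    # mirrors the fixed branch of TransformerDecoder.get_future_mask
    ret = _future_mask[:x.size(1), :x.size(1)]
    ret = ret.contiguous().unsqueeze(0).repeat(x.size(0) * attention_heads, 1, 1)
    return ret  # the line added by this commit

bsz, heads, seq_len, dim = 2, 4, 5, 16
x = torch.zeros(bsz, seq_len, dim)
base = torch.triu(torch.zeros(seq_len, seq_len).fill_(float("-inf")), 1)
mask = future_mask_when_no_attn_mask(x, base, heads)
print(mask.shape)  # torch.Size([8, 5, 5]) -> one causal mask per (batch, head)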
@@ -122,16 +129,17 @@ class TransformerDecoder(nn.Module):
        attn_mask: Optional[torch.Tensor] = None,
        encoder_attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        seq_len = emb.size(1)
        x = self.emb_layer_norm(emb)
        x = F.dropout(x, p=self.emb_dropout, training=self.training)

        # account for padding while computing the representation
        if padding_mask is not None:
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        rel_pos_bias = self.get_rel_pos_bias(x).repeat(x.size(0), 1, 1) if self.rel_pos else None
        if attn_mask is None:
            attn_mask = rel_pos_bias
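A small illustration of the padding step in the hunk above, assuming padding_mask holds 1 at padded positions and 0 elsewhere (tensor sizes are arbitrary): the broadcasted multiply zeroes the embedding of every padded position before the decoder layers run.

import torch

bsz, seq_len, dim = 2, 4, 3
x = torch.ones(bsz, seq_len, dim)
padding_mask = torch.tensor([[0, 0, 1, 1],
                             [0, 0, 0, 1]])          # 1 marks padding
x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))  # [bsz, seq_len, 1] broadcast
print(x[0, 2])  # tensor([0., 0., 0.]) -> padded positions are zeroed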
@@ -150,12 +158,12 @@ class TransformerDecoder(nn.Module):
            )
            attn_mask = attn_mask.view(-1, seq_len, seq_len)
            padding_mask = None

        for layer in self.layers:
            x = layer(x, encoder_out=encoder_out, padding_mask=padding_mask, attn_bias=attn_mask,
                      encoder_padding_mask=encoder_padding_mask, encoder_attn_bias=encoder_attn_mask)

-        if self.final_layer_norm != None:
+        if self.final_layer_norm is not None:
            x = self.final_layer_norm(x)
        return x
unicore/modules/transformer_encoder.py
@@ -157,7 +157,7 @@ class TransformerEncoder(nn.Module):
        for layer in self.layers:
            x = layer(x, padding_mask=padding_mask, attn_bias=attn_mask)
-        if self.final_layer_norm != None:
+        if self.final_layer_norm is not None:
            x = self.final_layer_norm(x)
        return x
\ No newline at end of file