chenpangpang / transformers

Commit 638fe7f5
authored Oct 17, 2019 by Rémi Louf
parent 4e0f2434

correct composition of padding and causal masks
Showing 1 changed file with 18 additions and 14 deletions.

transformers/modeling_bert.py  (+18, -14)
@@ -288,8 +288,8 @@ class BertAttention(nn.Module):
         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
         self.pruned_heads = self.pruned_heads.union(heads)
 
-    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
-        self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)
+    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None):
+        self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_state, encoder_attention_mask)
         attention_output = self.output(self_outputs[0], hidden_states)
         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
         return outputs
@@ -350,7 +350,6 @@ class BertLayer(nn.Module):
         return outputs
 
 
-# NOTE I think we may need to call encoder_hidden_states[i] for each layer
 class BertEncoder(nn.Module):
     def __init__(self, config):
         super(BertEncoder, self).__init__()
@@ -365,7 +364,8 @@ class BertEncoder(nn.Module):
             if self.output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask)
+            encoder_hidden_state = encoder_hidden_states[i]
+            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_state, encoder_attention_mask)
 
             hidden_states = layer_outputs[0]
             if self.output_attentions:
@@ -607,22 +607,26 @@ class BertModel(BertPreTrainedModel):
             self.encoder.layer[layer].attention.prune_heads(heads)
 
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
-                head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None):
+                head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
         """ Forward pass on the Model.
 
         The values of the attention matrix (shape [batch_size, seq_length])
         should be 1.0 for the position we want to attend to and 0. for the ones
         we do not want to attend to.
 
         The model can behave as an encoder (with only self-attention) as well
         as a decoder, in which case a layer of cross-attention is added between
         ever self-attention layer, following the architecture described in [1].
 
         To behave like as a decoder the model needs to be initialized with the
         `is_decoder` argument of the config set to `True`. An
-        `encoder_hidden_state` is expected as an input to the forward pass.
+        `encoder_hidden_states` is expected as an input to the forward pass.
 
         When a decoder, there are two kinds of attention masks to specify:
         (1) Self-attention masks that need to be causal (only attends to
         previous tokens);
         (2) A cross-attention mask that prevents the module
-        from attending to the encoder' padding tokens.
+        from attending to the encoder's padding tokens.
 
         [1] Vaswani, Ashish, et al. "Attention is all you need." Advances in
         neural information processing systems. 2017.
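As a quick illustration of the two mask kinds this docstring describes, here is a minimal sketch in plain PyTorch (editorial, not code from this commit; the shapes and padding pattern are made up): a causal self-attention mask built the same way as in the hunk further down, and a padding-based cross-attention mask over the encoder's tokens.

    import torch

    batch_size, tgt_len, src_len = 2, 5, 7

    # (1) Causal self-attention mask: query position i may only attend to key positions <= i.
    seq_ids = torch.arange(tgt_len)
    causal_mask = (seq_ids[None, None, :].repeat(batch_size, tgt_len, 1)
                   <= seq_ids[None, :, None]).float()      # shape [batch_size, tgt_len, tgt_len]

    # (2) Cross-attention padding mask: 1.0 for real encoder tokens, 0.0 for padding.
    encoder_attention_mask = torch.ones(batch_size, src_len)
    encoder_attention_mask[1, 5:] = 0.0                     # pretend the second example has two padded positions

    print(causal_mask[0])              # lower-triangular matrix of ones
    print(encoder_attention_mask)      # per-example padding pattern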
@@ -632,20 +636,20 @@ class BertModel(BertPreTrainedModel):
         if token_type_ids is None:
             token_type_ids = torch.zeros_like(input_ids)
 
-        # we may want to provide a mask of dimensions [batch_size, from_seq_length, to_seq_length]
-        # ourselves in which case we just make it broadcastable to all heads.
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
         if attention_mask.dim() == 3:
             extended_attention_mask = attention_mask[:, None, :, :]
 
-        # provided a padding mask of dimensions [batch_size, seq_length]
-        # - if encoder, make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
-        # - if decoder, make it causal
+        # Provided a padding mask of dimensions [batch_size, seq_length]
+        # - if the model is a decoder, apply a causal mask in addition to the padding mask
+        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
         if attention_mask.dim() == 2:
             if self.config.is_decoder:
                 batch_size, seq_length = input_ids.size()
                 seq_ids = torch.arange(seq_length)
                 causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
-                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[None, None, :, :]
+                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
             else:
                 extended_attention_mask = attention_mask[:, None, None, :]
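To make the one-line fix above concrete, here is a standalone sketch in plain PyTorch (editorial, not code from this commit; shapes are arbitrary). causal_mask has shape [batch_size, seq_length, seq_length] and attention_mask has shape [batch_size, seq_length], so the new [:, None, None, :] indexing broadcasts each example's padding over its heads and query positions, while the old [None, None, :, :] indexing pushes the batch dimension onto the query-position axis and only broadcasts at all when batch_size happens to equal seq_length.

    import torch

    batch_size, seq_length = 2, 4
    attention_mask = torch.ones(batch_size, seq_length)
    attention_mask[1, 3:] = 0.0                              # second example has one padded token

    seq_ids = torch.arange(seq_length)
    causal_mask = (seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
                   <= seq_ids[None, :, None]).float()        # [batch_size, seq_length, seq_length]

    # Composition after this commit: padded key positions are zeroed per example.
    fixed = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
    print(fixed.shape)        # torch.Size([2, 1, 4, 4])
    print(fixed[1, 0])        # last column zeroed only for the padded example

    # Composition before this commit: the padding mask's batch dimension lands on the
    # query-position axis, so it does not even broadcast unless batch_size == seq_length.
    try:
        broken = causal_mask[:, None, :, :] * attention_mask[None, None, :, :]
    except RuntimeError as err:
        print("old indexing fails to broadcast:", err)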
@@ -676,7 +680,7 @@ class BertModel(BertPreTrainedModel):
         encoder_outputs = self.encoder(embedding_output,
                                        attention_mask=extended_attention_mask,
                                        head_mask=head_mask,
-                                       encoder_hidden_state=encoder_hidden_state,
+                                       encoder_hidden_states=encoder_hidden_states,
                                        encoder_attention_mask=encoder_attention_mask)
         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output)