"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "54369208ee5ec551009bdf747309db85958594cf"
Commit 638fe7f5 authored by Rémi Louf

correct composition of padding and causal masks

parent 4e0f2434
@@ -288,8 +288,8 @@ class BertAttention(nn.Module):
         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
         self.pruned_heads = self.pruned_heads.union(heads)
 
-    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
-        self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)
+    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None):
+        self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_state, encoder_attention_mask)
         attention_output = self.output(self_outputs[0], hidden_states)
         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
         return outputs
@@ -350,7 +350,6 @@ class BertLayer(nn.Module):
         return outputs
 
 
-# NOTE I think we may need to call encoder_hidden_states[i] for each layer
 class BertEncoder(nn.Module):
     def __init__(self, config):
         super(BertEncoder, self).__init__()
@@ -365,7 +364,8 @@ class BertEncoder(nn.Module):
             if self.output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask)
+            encoder_hidden_state = encoder_hidden_states[i]
+            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_state, encoder_attention_mask)
             hidden_states = layer_outputs[0]
 
             if self.output_attentions:
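As context for the `encoder_hidden_states[i]` indexing added in this hunk, here is a hedged sketch of the loop it produces, assuming `encoder_hidden_states` is supplied as one tensor per layer (which is what the removed NOTE suggests); the `nn.Identity` stand-ins below are hypothetical and do not reflect `BertLayer`'s real signature.

import torch
import torch.nn as nn

num_layers, batch_size, enc_len, hidden = 3, 2, 7, 16

# Assumption for illustration: one encoder hidden state per decoder layer.
encoder_hidden_states = tuple(torch.randn(batch_size, enc_len, hidden) for _ in range(num_layers))
layers = nn.ModuleList(nn.Identity() for _ in range(num_layers))  # stand-ins for BertLayer

hidden_states = torch.randn(batch_size, 5, hidden)
for i, layer_module in enumerate(layers):
    encoder_hidden_state = encoder_hidden_states[i]   # the state this layer cross-attends to
    hidden_states = layer_module(hidden_states)       # a real BertLayer also takes the masks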
@@ -607,22 +607,26 @@ class BertModel(BertPreTrainedModel):
             self.encoder.layer[layer].attention.prune_heads(heads)
 
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
-                head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None):
+                head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
         """ Forward pass on the Model.
 
+        The values of the attention matrix (shape [batch_size, seq_length])
+        should be 1.0 for the position we want to attend to and 0. for the ones
+        we do not want to attend to.
+
         The model can behave as an encoder (with only self-attention) as well
         as a decoder, in which case a layer of cross-attention is added between
         ever self-attention layer, following the architecture described in [1].
 
         To behave like as a decoder the model needs to be initialized with the
         `is_decoder` argument of the config set to `True`. An
-        `encoder_hidden_state` is expected as an input to the forward pass.
+        `encoder_hidden_states` is expected as an input to the forward pass.
 
         When a decoder, there are two kinds of attention masks to specify:
         (1) Self-attention masks that need to be causal (only attends to
         previous tokens);
         (2) A cross-attention mask that prevents the module
-        from attending to the encoder' padding tokens.
+        from attending to the encoder's padding tokens.
 
         [1] Vaswani, Ashish, et al. "Attention is all you need." Advances in
         neural information processing systems. 2017.
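Following the docstring above, a minimal sketch of the two masks it describes. The `pad_token_id` and toy id tensors are made up for illustration; only the 1.0/0.0 convention and the argument names come from the code.

import torch

pad_token_id = 0  # hypothetical padding id, for illustration only

decoder_input_ids = torch.tensor([[5, 6, 7, pad_token_id],
                                  [8, 9, pad_token_id, pad_token_id]])
encoder_input_ids = torch.tensor([[3, 4, pad_token_id],
                                  [2, 3, 4]])

# (1) Self-attention padding mask for the decoder inputs: 1.0 = attend, 0.0 = ignore.
#     The causal part is composed on top of this inside the model when
#     config.is_decoder is True (see the next hunk).
attention_mask = (decoder_input_ids != pad_token_id).float()          # [batch_size, seq_length]

# (2) Cross-attention mask that hides the encoder's padding tokens.
encoder_attention_mask = (encoder_input_ids != pad_token_id).float()  # [batch_size, enc_seq_length]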
@@ -632,20 +636,20 @@ class BertModel(BertPreTrainedModel):
         if token_type_ids is None:
             token_type_ids = torch.zeros_like(input_ids)
 
-        # we may want to provide a mask of dimensions [batch_size, from_seq_length, to_seq_length]
-        # ourselves in which case we just make it broadcastable to all heads.
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
         if attention_mask.dim() == 3:
             extended_attention_mask = attention_mask[:, None, :, :]
 
-        # provided a padding mask of dimensions [batch_size, seq_length]
-        # - if encoder, make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
-        # - if decoder, make it causal
+        # Provided a padding mask of dimensions [batch_size, seq_length]
+        # - if the model is a decoder, apply a causal mask in addition to the padding mask
+        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
         if attention_mask.dim() == 2:
             if self.config.is_decoder:
                 batch_size, seq_length = input_ids.size()
                 seq_ids = torch.arange(seq_length)
                 causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
-                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[None, None, :, :]
+                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
             else:
                 extended_attention_mask = attention_mask[:, None, None, :]
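A standalone shape check of the corrected composition (not part of the model code), reusing the same tensor expressions as the hunk above: the padding mask must broadcast over the query axis, not the batch axis.

import torch

batch_size, seq_length = 2, 4
attention_mask = torch.tensor([[1., 1., 1., 0.],
                               [1., 1., 0., 0.]])   # [batch_size, seq_length] padding mask

seq_ids = torch.arange(seq_length)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]

# New composition: padding mask shaped [batch, 1, 1, seq] broadcasts over the query axis.
extended_attention_mask = causal_mask[:, None, :, :].float() * attention_mask[:, None, None, :]
print(extended_attention_mask.shape)   # torch.Size([2, 1, 4, 4]) -> [batch, 1, query, key]

# The old indexing, attention_mask[None, None, :, :], has shape [1, 1, batch_size, seq_length],
# which lines the batch dimension up against the query dimension and only broadcasts at all
# when batch_size happens to equal seq_length.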
@@ -676,7 +680,7 @@ class BertModel(BertPreTrainedModel):
         encoder_outputs = self.encoder(embedding_output,
                                        attention_mask=extended_attention_mask,
                                        head_mask=head_mask,
-                                       encoder_hidden_state=encoder_hidden_state,
+                                       encoder_hidden_states=encoder_hidden_states,
                                        encoder_attention_mask=encoder_attention_mask)
         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output)
...