Commit 07520696 authored by Rémi Louf

adapt attention masks for the decoder case

The introduction of a decoder brings two changes:
- We need to be able to specify a separate mask in the cross-attention
to mask the positions corresponding to padding tokens in the encoder
state.
- The self-attention in the decoder needs to be causal on top of not
attending to padding tokens (see the sketch below).
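For reference, a minimal, self-contained sketch of how a 2D padding mask can be combined with a causal mask into the broadcastable 4D mask that self-attention expects. It mirrors the new broadcast logic in BertModel.forward below; it is not part of the diff, and the helper name is made up for illustration:

import torch

def build_decoder_self_attention_mask(padding_mask):
    """Combine a [batch_size, seq_length] padding mask with a causal mask
    into a [batch_size, 1, seq_length, seq_length] mask that broadcasts
    over the attention heads."""
    batch_size, seq_length = padding_mask.size()
    seq_ids = torch.arange(seq_length, device=padding_mask.device)
    # causal_mask[b, i, j] == 1 iff position i is allowed to attend to position j (j <= i)
    causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
    # a position is kept only if it is both causally visible and not a padding token
    return causal_mask[:, None, :, :].to(padding_mask.dtype) * padding_mask[:, None, None, :]

padding_mask = torch.tensor([[1, 1, 1, 0]])                   # one sequence, last token is padding
print(build_decoder_self_attention_mask(padding_mask).shape)  # torch.Size([1, 1, 4, 4])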
parent c5a94a61
@@ -198,12 +198,16 @@ class BertSelfAttention(nn.Module):
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
mixed_query_layer = self.query(hidden_states)
# if this attention module is used as an encoder-decoder cross-attention module,
# the keys & values are given by the encoder; the attention mask
# needs to be such that there is no attention on the encoder's padding tokens.
if encoder_hidden_states is not None:
mixed_key_layer = self.key(encoder_hidden_states)
mixed_value_layer = self.value(encoder_hidden_states)
attention_mask = encoder_attention_mask
else:
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
@@ -284,8 +288,8 @@ class BertAttention(nn.Module):
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states)
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
self_outputs = self.self(hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
@@ -330,13 +334,13 @@ class BertLayer(nn.Module):
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None):
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None):
self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
if self.is_decoder and encoder_hidden_state is not None:
cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state)
cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state, encoder_attention_mask)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
@@ -346,6 +350,7 @@ class BertLayer(nn.Module):
return outputs
# NOTE I think we may need to call encoder_hidden_states[i] for each layer
class BertEncoder(nn.Module):
def __init__(self, config):
super(BertEncoder, self).__init__()
@@ -353,14 +358,14 @@ class BertEncoder(nn.Module):
self.output_hidden_states = config.output_hidden_states
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states)
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask)
hidden_states = layer_outputs[0]
if self.output_attentions:
@@ -579,6 +584,7 @@ class BertModel(BertPreTrainedModel):
"""
def __init__(self, config):
super(BertModel, self).__init__(config)
self.config = config
self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder(config)
@@ -601,18 +607,47 @@ class BertModel(BertPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
head_mask=None, encoder_hidden_state=None):
head_mask=None, encoder_hidden_state=None, encoder_attention_mask=None):
""" Forward pass on the Model.
The model can behave as an encoder (with only self-attention) as well
as a decoder, in which case a layer of cross-attention is added after
every self-attention layer, following the architecture described in [1].
To behave as a decoder the model needs to be initialized with the
`is_decoder` argument of the config set to `True`; an
`encoder_hidden_state` is then expected as an input to the forward pass.
When used as a decoder, there are two kinds of attention masks to specify:
(1) a self-attention mask, which needs to be causal (each position only
attends to previous positions);
(2) a cross-attention mask, which prevents the module from attending to
the encoder's padding tokens.
[1] Vaswani, Ashish, et al. "Attention is all you need." Advances in
neural information processing systems. 2017.
"""
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
# We create a 4D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# so we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length].
# This attention mask is simpler than the triangular masking of causal attention
# used in OpenAI GPT; we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# we may want to provide a mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just make it broadcastable to all heads.
if attention_mask.dim() == 3:
extended_attention_mask = attention_mask[:, None, :, :]
# provided a padding mask of dimensions [batch_size, seq_length]
# - if encoder, make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
# - if decoder, make it causal
if attention_mask.dim() == 2:
if self.config.is_decoder:
batch_size, seq_length = input_ids.size()
seq_ids = torch.arange(seq_length, device=input_ids.device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
else:
extended_attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
@@ -641,7 +676,8 @@ class BertModel(BertPreTrainedModel):
encoder_outputs = self.encoder(embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_state=encoder_hidden_state)
encoder_hidden_states=encoder_hidden_state,
encoder_attention_mask=encoder_attention_mask)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
......
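For context, a rough usage sketch of how the decoder path described in the docstring is meant to be driven. It is not part of the commit; it assumes BertConfig and BertModel are importable from the transformers package on this branch and that the rest of the forward pass is unchanged. The tensors are toy inputs and only illustrate how the new arguments are passed:

import torch
from transformers import BertConfig, BertModel

config = BertConfig()          # default BERT-base sizes, purely for illustration
config.is_decoder = True       # the flag checked in BertModel.forward above, set before building the model

decoder = BertModel(config)

decoder_input_ids = torch.randint(0, config.vocab_size, (1, 5))    # [batch_size, tgt_len]
encoder_hidden_state = torch.randn(1, 7, config.hidden_size)       # encoder output, [batch_size, src_len, hidden]
encoder_attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0]])     # zeros mark the encoder's padding tokens

outputs = decoder(decoder_input_ids,
                  encoder_hidden_state=encoder_hidden_state,
                  encoder_attention_mask=encoder_attention_mask)
sequence_output = outputs[0]   # [batch_size, tgt_len, hidden]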