Commit 3520be78 authored by Rémi Louf

create encoder attention mask from shape of hidden states

We currently create the encoder attention mask (when it is not provided)
from the shape of the model's own inputs (input_shape). This is wrong:
the encoder and decoder sequences can have different lengths. We now
create the encoder attention mask from the batch_size and
sequence_length of the encoder hidden states.
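
As a minimal standalone sketch (toy shapes assumed for illustration, not the transformers API), this is the shape mismatch the commit fixes: a default mask built from the decoder's input shape cannot cover a longer encoder sequence, while one built from the encoder hidden states can.

import torch

# Toy shapes (assumed): same batch size, but the decoder and encoder
# sequences have different lengths.
batch_size, decoder_len, encoder_len, hidden_size = 2, 5, 9, 16
decoder_input_shape = (batch_size, decoder_len)
encoder_hidden_states = torch.randn(batch_size, encoder_len, hidden_size)

# Old behaviour: a default mask shaped like the decoder inputs -> (2, 5),
# which cannot mask cross-attention over the 9 encoder positions.
wrong_mask = torch.ones(decoder_input_shape)

# New behaviour: the default mask is derived from the encoder hidden
# states -> (2, 9), one entry per encoder position.
enc_batch, enc_len, _ = encoder_hidden_states.size()
correct_mask = torch.ones((enc_batch, enc_len))

print(wrong_mask.shape, correct_mask.shape)  # torch.Size([2, 5]) torch.Size([2, 9])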
parent 0cb16386
@@ -691,17 +691,19 @@ class BertModel(BertPreTrainedModel):
         # If a 2D ou 3D attention mask is provided for the cross-attention
         # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
-        if self.config.is_decoder:
+        if self.config.is_decoder and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
-                encoder_attention_mask = torch.ones(input_shape, device=device)
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)

             if encoder_attention_mask.dim() == 3:
                 encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
             elif encoder_attention_mask.dim() == 2:
                 encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
             else:
-                raise ValueError("Wrong shape for input_ids (shape {}) or encoder_attention_mask (shape {})".format(input_shape,
+                raise ValueError("Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format(encoder_hidden_shape,
                                                                                                                      encoder_attention_mask.shape))

             encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
             encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
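
For illustration, a small standalone sketch (assumed toy dimensions, independent of the code above) of what the diff does with the 2D mask: it is broadcast to [batch_size, 1, 1, encoder_sequence_length] and turned into an additive mask applied to the cross-attention scores.

import torch

# Assumed toy dimensions, mirroring the diff above.
batch_size, num_heads, decoder_len = 2, 4, 5
encoder_hidden_states = torch.randn(batch_size, 9, 16)

# Default 2D mask from the encoder hidden states: 1.0 = attend to this position.
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_attention_mask = torch.ones((encoder_batch_size, encoder_sequence_length))

# Broadcast to [batch_size, 1, 1, encoder_sequence_length] and convert to an
# additive mask: 0.0 where attention is allowed, -10000.0 where it is masked.
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0

# The extended mask now broadcasts against cross-attention scores of shape
# [batch_size, num_heads, decoder_len, encoder_sequence_length].
scores = torch.randn(batch_size, num_heads, decoder_len, encoder_sequence_length)
masked_scores = scores + encoder_extended_attention_mask
print(masked_scores.shape)  # torch.Size([2, 4, 5, 9])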