Unverified Commit ecabbf6d authored by Thomas Wolf, committed by GitHub

Merge pull request #2107 from huggingface/encoder-mask-shape

create encoder attention mask from shape of hidden states
parents 1d189304 3520be78
@@ -692,17 +692,19 @@ class BertModel(BertPreTrainedModel):
         # If a 2D ou 3D attention mask is provided for the cross-attention
         # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
-        if self.config.is_decoder:
+        if self.config.is_decoder and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
-                encoder_attention_mask = torch.ones(input_shape, device=device)
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)

             if encoder_attention_mask.dim() == 3:
                 encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
             elif encoder_attention_mask.dim() == 2:
                 encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
             else:
-                raise ValueError("Wrong shape for input_ids (shape {}) or encoder_attention_mask (shape {})".format(input_shape,
-                                                                                                                     encoder_attention_mask.shape))
+                raise ValueError("Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format(encoder_hidden_shape,
+                                                                                                                               encoder_attention_mask.shape))

             encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
             encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
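The reason for the change: the cross-attention mask is applied over the encoder's positions, whose sequence length can differ from the decoder's input length, so building the default mask from input_shape produces a tensor of the wrong size. Below is a minimal, self-contained sketch (not part of the diff; batch size, lengths, and hidden size are made up for illustration) showing the old vs. new default mask shapes and the same broadcast-to-additive-mask convention used in the code above.

import torch

batch_size, decoder_len, encoder_len, hidden = 2, 5, 7, 16

decoder_input_shape = (batch_size, decoder_len)
encoder_hidden_states = torch.randn(batch_size, encoder_len, hidden)

# Old behaviour: default mask sized after the decoder inputs -> (2, 5),
# which cannot cover the 7 encoder positions seen by cross-attention.
old_mask = torch.ones(decoder_input_shape)

# New behaviour: default mask sized after the encoder hidden states -> (2, 7).
enc_batch, enc_seq, _ = encoder_hidden_states.size()
new_mask = torch.ones((enc_batch, enc_seq))

# Broadcast a 2D mask to [batch_size, num_heads, seq_length, seq_length]
# as in the diff, then convert it to additive form for the attention scores.
extended = new_mask[:, None, None, :]            # shape (2, 1, 1, 7)
extended = (1.0 - extended) * -10000.0           # 0 for kept positions, -10000 for masked

print(old_mask.shape, new_mask.shape, extended.shape)

Running the sketch prints torch.Size([2, 5]), torch.Size([2, 7]) and torch.Size([2, 1, 1, 7]), which makes the shape mismatch the patch fixes easy to see.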