Unverified Commit 91caf246 authored by Thomas Wolf, committed by GitHub

Merge pull request #1770 from huggingface/initi-encoder-mask

Only init encoder_attention_mask if stack is decoder
parents 49a69d5b cd286c21
@@ -660,8 +660,6 @@ class BertModel(BertPreTrainedModel):
         if attention_mask is None:
             attention_mask = torch.ones(input_shape, device=device)
-        if encoder_attention_mask is None:
-            encoder_attention_mask = torch.ones(input_shape, device=device)
         if token_type_ids is None:
             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
@@ -692,6 +690,10 @@ class BertModel(BertPreTrainedModel):
         # If a 2D ou 3D attention mask is provided for the cross-attention
         # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
+        if self.config.is_decoder:
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(input_shape, device=device)
         if encoder_attention_mask.dim() == 3:
             encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
         if encoder_attention_mask.dim() == 2:
@@ -699,6 +701,8 @@ class BertModel(BertPreTrainedModel):
             encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
             encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
+        else:
+            encoder_extended_attention_mask = None
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
...
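For readability, here is a rough sketch of how the cross-attention mask handling in BertModel.forward reads once these hunks are applied. It is reconstructed from the diff above rather than copied from modeling_bert.py: the 2D broadcasting branch and the final indentation of the surrounding context lines are collapsed in the diff, so they are filled in here by analogy with the self-attention mask code earlier in forward().

# Sketch of the merged logic (reconstructed from the hunks above, not a verbatim
# copy of modeling_bert.py). `input_shape`, `device`, and the mask arguments are
# locals of BertModel.forward().
if self.config.is_decoder:
    # A cross-attention mask over the encoder outputs is only needed for a
    # decoder stack, so the default all-ones mask is no longer created
    # unconditionally at the top of forward().
    if encoder_attention_mask is None:
        encoder_attention_mask = torch.ones(input_shape, device=device)

    # Broadcast a 2D or 3D mask to [batch_size, num_heads, seq_length, seq_length].
    if encoder_attention_mask.dim() == 3:
        encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
    if encoder_attention_mask.dim() == 2:
        # This branch is collapsed in the diff; assumed to mirror the 3D case.
        encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]

    # Turn the {0, 1} mask into additive values in the model's dtype (fp16 safe):
    # masked positions become -10000.0, kept positions become 0.0.
    encoder_extended_attention_mask = encoder_extended_attention_mask.to(
        dtype=next(self.parameters()).dtype
    )
    encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
else:
    # Encoder-only stacks have no cross-attention, so no encoder mask is built.
    encoder_extended_attention_mask = None

The net effect is that encoder-only models no longer construct (or silently depend on) an all-ones encoder_attention_mask that is only meaningful when the stack actually has cross-attention layers.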