Unverified Commit 39f8eafc authored by Pavel Belevich, committed by GitHub

Remove device parameter from create_extended_attention_mask_for_decoder (#16894)

parent dd739f70
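
After this change, `get_extended_attention_mask` and `create_extended_attention_mask_for_decoder` infer the device from `attention_mask.device`; passing `device` explicitly still works but emits a `FutureWarning` until the argument is removed in v5. A minimal sketch of the new calling convention (the randomly initialized `BertModel` and the toy tensors below are purely illustrative, not part of this commit):

```python
import torch
from transformers import BertConfig, BertModel

model = BertModel(BertConfig())                       # randomly initialized, for illustration only
attention_mask = torch.ones(2, 8, dtype=torch.long)   # (batch_size, seq_length)
input_shape = attention_mask.shape

# New call: no `device` argument; the mask is built on `attention_mask.device`.
extended_mask = model.get_extended_attention_mask(attention_mask, input_shape)

# Old call: still accepted, but now emits a FutureWarning about the deprecated `device` argument.
extended_mask = model.get_extended_attention_mask(attention_mask, input_shape, attention_mask.device)
```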
@@ -137,7 +137,7 @@ class RetrievalQAEmbedder(nn.Module):
             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
             head_mask = [None] * self.sent_encoder.config.num_hidden_layers
             extended_attention_mask: torch.Tensor = self.sent_encoder.get_extended_attention_mask(
-                attention_mask, input_shape, device
+                attention_mask, input_shape
             )
             # define function for checkpointing
...
@@ -651,7 +651,13 @@ class ModuleUtilsMixin:
         return encoder_extended_attention_mask

     @staticmethod
-    def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device):
+    def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device=None):
+        if device is not None:
+            warnings.warn(
+                "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
+            )
+        else:
+            device = attention_mask.device
         batch_size, seq_length = input_shape
         seq_ids = torch.arange(seq_length, device=device)
         causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
@@ -672,7 +678,9 @@ class ModuleUtilsMixin:
         extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
         return extended_attention_mask

-    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor:
+    def get_extended_attention_mask(
+        self, attention_mask: Tensor, input_shape: Tuple[int], device: device = None
+    ) -> Tensor:
         """
         Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
@@ -681,12 +689,16 @@ class ModuleUtilsMixin:
                 Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
             input_shape (`Tuple[int]`):
                 The shape of the input to the model.
-            device: (`torch.device`):
-                The device of the input to the model.

         Returns:
             `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
         """
+        if not (attention_mask.dim() == 2 and self.config.is_decoder):
+            # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
+            if device is not None:
+                warnings.warn(
+                    "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
+                )
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
         if attention_mask.dim() == 3:
...
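
For the decoder path, the now-static helper builds the causal mask on the mask's own device. A quick usage sketch (the toy mask below is illustrative only):

```python
import torch
from transformers.modeling_utils import ModuleUtilsMixin

attention_mask = torch.tensor([[1, 1, 1, 0]])   # one padded position, shape (1, 4)
input_shape = attention_mask.shape              # (batch_size, seq_length)

# No `device` argument needed: the causal mask is created on `attention_mask.device`.
extended = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(input_shape, attention_mask)
print(extended.shape)  # expected: torch.Size([1, 1, 4, 4]) -- broadcastable causal + padding mask
```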
@@ -982,7 +982,7 @@ class BertModel(BertPreTrainedModel):
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
...
@@ -364,9 +364,7 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
         # ourselves in which case we just need to make it broadcastable to all heads.
         extended_attention_mask = None
         if not use_cache:
-            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
-                attention_mask, input_shape, device
-            )
+            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
...
@@ -2112,9 +2112,7 @@ class BigBirdModel(BigBirdPreTrainedModel):
             to_mask = None
             # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
             # ourselves in which case we just need to make it broadcastable to all heads.
-            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
-                attention_mask, input_shape, device
-            )
+            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
         else:
             raise ValueError(
                 f"attention_type can either be original_full or block_sparse, but is {self.attention_type}"
...
@@ -1130,12 +1130,12 @@ class CanineModel(CaninePreTrainedModel):
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
         molecule_attention_mask = self._downsample_attention_mask(
             attention_mask, downsampling_rate=self.config.downsampling_rate
         )
         extended_molecule_attention_mask: torch.Tensor = self.get_extended_attention_mask(
-            molecule_attention_mask, (batch_size, molecule_attention_mask.shape[-1]), device
+            molecule_attention_mask, (batch_size, molecule_attention_mask.shape[-1])
         )

         # Prepare head mask if needed
...
@@ -833,7 +833,7 @@ class ConvBertModel(ConvBertPreTrainedModel):
         else:
             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

-        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
+        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

         hidden_states = self.embeddings(
...
@@ -820,7 +820,7 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
...
@@ -882,7 +882,7 @@ class ElectraModel(ElectraPreTrainedModel):
         else:
             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

-        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
+        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
...
@@ -814,7 +814,7 @@ class IBertModel(IBertPreTrainedModel):
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
...
@@ -1692,7 +1692,7 @@ class LongformerModel(LongformerPreTrainedModel):
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)[
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)[
             :, 0, 0, :
         ]
...
@@ -940,7 +940,7 @@ class MegatronBertModel(MegatronBertPreTrainedModel):
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
...
@@ -268,7 +268,7 @@ class MMBTModel(nn.Module, ModuleUtilsMixin):
                 [torch.ones(input_modal_shape, device=device), encoder_attention_mask], dim=1
             )

-        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
+        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
         encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
...
@@ -875,9 +875,7 @@ class MobileBertModel(MobileBertPreTrainedModel):
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
-            attention_mask, input_shape, self.device
-        )
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
...
@@ -547,7 +547,7 @@ class MPNetModel(MPNetPreTrainedModel):
         if attention_mask is None:
             attention_mask = torch.ones(input_shape, device=device)
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
         embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds)
...
@@ -624,7 +624,7 @@ class NystromformerModel(NystromformerPreTrainedModel):
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
...
@@ -952,7 +952,7 @@ class QDQBertModel(QDQBertPreTrainedModel):
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
...
@@ -1078,7 +1078,7 @@ class RealmBertModel(RealmPreTrainedModel):
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
...
@@ -857,7 +857,7 @@ class RemBertModel(RemBertPreTrainedModel):
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
...
@@ -117,7 +117,7 @@ class RetriBertModel(RetriBertPreTrainedModel):
             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
             head_mask = [None] * sent_encoder.config.num_hidden_layers
             extended_attention_mask: torch.Tensor = sent_encoder.get_extended_attention_mask(
-                attention_mask, input_shape, device
+                attention_mask, input_shape
             )
             # define function for checkpointing
...