Unverified Commit 63fbed5c authored by Pavel Belevich, committed by GitHub

Make create_extended_attention_mask_for_decoder static method (#16893)

parent fb0ae129
@@ -593,7 +593,8 @@ class ModuleUtilsMixin:
         return encoder_extended_attention_mask
 
-    def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask, device):
+    @staticmethod
+    def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device):
         batch_size, seq_length = input_shape
         seq_ids = torch.arange(seq_length, device=device)
         causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
@@ -638,7 +639,7 @@ class ModuleUtilsMixin:
         # - if the model is a decoder, apply a causal mask in addition to the padding mask
         # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
         if self.config.is_decoder:
-            extended_attention_mask = self.create_extended_attention_mask_for_decoder(
+            extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
                 input_shape, attention_mask, device
             )
         else:
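With the method declared as a @staticmethod, it no longer needs a model instance and can be called directly on the mixin class. Below is a minimal sketch of such a call, assuming the transformers.modeling_utils import path and the three-argument signature shown in the diff; exact behavior may vary across transformers versions.

# Minimal sketch (assumes a transformers version that includes this commit, plus PyTorch).
import torch
from transformers.modeling_utils import ModuleUtilsMixin

batch_size, seq_length = 2, 4
attention_mask = torch.ones(batch_size, seq_length)

# No model instance needed anymore: the causal + padding mask can be built
# straight from the mixin class.
extended_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
    (batch_size, seq_length), attention_mask, attention_mask.device
)
print(extended_mask.shape)  # torch.Size([2, 1, 4, 4]), broadcastable to [batch, heads, seq, seq]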