Unverified Commit 63fbed5c authored by Pavel Belevich, committed by GitHub

Make create_extended_attention_mask_for_decoder static method (#16893)

parent fb0ae129
@@ -593,7 +593,8 @@ class ModuleUtilsMixin:
         return encoder_extended_attention_mask
-    def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask, device):
+    @staticmethod
+    def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device):
         batch_size, seq_length = input_shape
         seq_ids = torch.arange(seq_length, device=device)
         causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
@@ -638,7 +639,7 @@ class ModuleUtilsMixin:
         # - if the model is a decoder, apply a causal mask in addition to the padding mask
         # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
         if self.config.is_decoder:
-            extended_attention_mask = self.create_extended_attention_mask_for_decoder(
+            extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
                 input_shape, attention_mask, device
             )
         else:
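Not part of the commit itself, but as a minimal sketch of what the change enables: with the helper now a staticmethod, it can be called directly on ModuleUtilsMixin (imported here from transformers.modeling_utils) without instantiating or configuring a model, which is exactly how the updated call site in get_extended_attention_mask uses it. The batch size, sequence length, and mask values below are arbitrary example inputs, and the output shape reflects the full method body, only part of which is visible in the truncated hunk above.

import torch
from transformers.modeling_utils import ModuleUtilsMixin

batch_size, seq_length = 2, 4
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])  # 1 = real token, 0 = padding

# The core of the first hunk: a causal mask built by comparing position ids,
# so position i may only attend to positions j <= i (lower-triangular).
seq_ids = torch.arange(seq_length)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
print(causal_mask[0].int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)

# With the commit applied, the helper is a @staticmethod and can be called on
# the class itself, matching the updated call site in get_extended_attention_mask.
extended_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
    (batch_size, seq_length), attention_mask, torch.device("cpu")
)
print(extended_mask.shape)  # torch.Size([2, 1, 4, 4]): causal mask combined with padding mask
print(extended_mask[0, 0])
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 0]])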