Unverified commit ac589375, authored by Patrick von Platen and committed by GitHub

[Attention Mask] Refactor all encoder-decoder attention mask (#27086)



* [FA2 Bart] Add FA2 to all Bart-like

* better

* Refactor attention mask

* remove all customized attention logic

* format

* mass rename

* replace _expand_mask

* replace _expand_mask

* mass rename

* add pt files

* mass replace & rename

* mass replace & rename

* mass replace & rename

* mass replace & rename

* Update src/transformers/models/idefics/modeling_idefics.py

* fix more

* clean more

* fix more

* make style

* fix again

* finish

* finish

* finish

* finish

* finish

* finish

* finish

* finish

* finish

* finish

* Apply suggestions from code review

* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* small fix mistral

* finish

* finish

* finish

* finish

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 29c74f58
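The diff below applies the same mechanical change to every Bart-like encoder-decoder model: the per-file `_make_causal_mask` / `_expand_mask` helpers and the `_prepare_decoder_attention_mask` method are deleted, and the call sites switch to the shared utilities imported from `transformers.modeling_attn_mask_utils`. As orientation, here is a minimal sketch of how those shared helpers are invoked, based only on the call sites visible in this diff; it assumes an installed transformers version recent enough to ship the `modeling_attn_mask_utils` module.

import torch

# Shared helpers used throughout this refactor; module path taken from the new imports below.
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_attention_mask,
    _prepare_4d_causal_attention_mask,
)

bsz, tgt_len, src_len, hidden = 2, 5, 7, 16
inputs_embeds = torch.randn(bsz, tgt_len, hidden)

# 2D padding masks: 1 = attend, 0 = padding.
decoder_padding_mask = torch.ones(bsz, tgt_len, dtype=torch.long)
encoder_padding_mask = torch.ones(bsz, src_len, dtype=torch.long)

# Decoder self-attention: padding mask + causal structure -> [bsz, 1, tgt_len, tgt_len + past].
causal_4d = _prepare_4d_causal_attention_mask(
    decoder_padding_mask, (bsz, tgt_len), inputs_embeds, 0  # past_key_values_length = 0
)

# Encoder self-attention / cross-attention: expand a 2D mask to additive 4D form.
cross_4d = _prepare_4d_attention_mask(encoder_padding_mask, inputs_embeds.dtype, tgt_len=tgt_len)

print(causal_4d.shape)  # torch.Size([2, 1, 5, 5])
print(cross_4d.shape)   # torch.Size([2, 1, 5, 7])

The resulting shapes follow the `[bsz, 1, tgt_seq_len, src_seq_len]` convention noted in the comments throughout the diff.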
@@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...deepspeed import is_deepspeed_zero3_enabled
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -265,39 +266,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
return shifted_input_ids
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
def _compute_new_attention_mask(hidden_states: torch.Tensor, seq_lens: torch.Tensor):
"""
Computes an attention mask of the form `(batch, seq_len)` with an attention for each element in the batch that
@@ -1002,7 +970,7 @@ class SeamlessM4TConformerAdapterLayer(nn.Module):
hidden_states.device
)
attention_mask = _compute_new_attention_mask(hidden_states=hidden_states, seq_lens=sub_sampled_lengths)
-attention_mask = _expand_mask(
+attention_mask = _prepare_4d_attention_mask(
attention_mask,
hidden_states.dtype,
)
@@ -1819,7 +1787,7 @@ class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel):
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@@ -1934,30 +1902,6 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel):
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids: torch.LongTensor = None,
@@ -2064,14 +2008,16 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
-attention_mask = self._prepare_decoder_attention_mask(
+attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+encoder_attention_mask = _prepare_4d_attention_mask(
+encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+)
# embed positions
positions = self.embed_positions(input, past_key_values_length=past_key_values_length)
...
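The `_expand_mask` helper deleted above is identical in every file that follows, and the new `_prepare_4d_attention_mask` call sites are drop-in replacements for it. For reference, a self-contained copy of the deleted helper under a hypothetical name, showing the additive 0 / dtype-min convention these 4D masks use:

import torch
from typing import Optional

def expand_padding_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor:
    # Same body as the removed per-model `_expand_mask`:
    # [bsz, src_len] of 1s/0s -> additive [bsz, 1, tgt_len, src_len],
    # 0.0 where attention is allowed, dtype-min where it is masked.
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len
    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted_mask = 1.0 - expanded_mask
    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

print(expand_padding_mask(torch.tensor([[1, 1, 0]]), torch.float32))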
@@ -14,7 +14,6 @@
# limitations under the License.
""" PyTorch Speech2Text model."""
import math
from typing import Optional, Tuple, Union
@@ -23,6 +22,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -62,39 +62,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
return shifted_input_ids
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class Conv1dSubsampler(nn.Module):
"""
Convolutional subsampler: a stack of 1D convolution (along temporal dimension) followed by non-linear activation
@@ -788,7 +755,7 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel):
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@@ -883,27 +850,6 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
def set_input_embeddings(self, value):
self.embed_tokens = value
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids=None,
@@ -1008,14 +954,16 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
-attention_mask = self._prepare_decoder_attention_mask(
+attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+encoder_attention_mask = _prepare_4d_attention_mask(
+encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+)
# embed positions
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
...
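On the decoder side, the deleted `_prepare_decoder_attention_mask` built a causal mask with `_make_causal_mask` and then added the expanded padding mask onto it; the new `_prepare_4d_causal_attention_mask` call is expected to cover both steps in one place. A minimal CPU-only sketch of the causal part, adapted from the deleted `_make_causal_mask` (device handling dropped for brevity):

import torch

def make_causal_4d_mask(bsz: int, tgt_len: int, dtype: torch.dtype, past_key_values_length: int = 0) -> torch.Tensor:
    # Upper triangle gets dtype-min (blocked), the diagonal and lower triangle get 0 (visible);
    # cached past positions are always visible.
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
    mask_cond = torch.arange(tgt_len)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 0)
    mask = mask.to(dtype)
    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

# One batch element, three target positions: position i may attend to positions <= i.
print(make_causal_4d_mask(1, 3, torch.float32))

The padding term is then simply summed onto this tensor, which is why the additive 0 / dtype-min convention matters.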
@@ -24,6 +24,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, logging, replace_return_docstrings
@@ -42,39 +43,6 @@ SPEECH_TO_TEXT_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextSinusoidalPositionalEmbedding with Speech2Text->Speech2Text2
class Speech2Text2SinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length."""
@@ -492,27 +460,6 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel):
def set_input_embeddings(self, value):
self.embed_tokens = value
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids=None,
@@ -617,14 +564,16 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
-attention_mask = self._prepare_decoder_attention_mask(
+attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+encoder_attention_mask = _prepare_4d_attention_mask(
+encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+)
# embed positions
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
...
@@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -89,39 +90,6 @@ def shift_spectrograms_right(input_values: torch.Tensor, reduction_factor: int =
return shifted_input_values
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
@@ -1342,7 +1310,7 @@ class SpeechT5Encoder(SpeechT5PreTrainedModel):
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
+attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)
@@ -1539,30 +1507,6 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
hidden_states: Optional[torch.FloatTensor] = None,
@@ -1647,14 +1591,16 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel):
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
-attention_mask = self._prepare_decoder_attention_mask(
+attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, hidden_states, past_key_values_length
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-encoder_attention_mask = _expand_mask(encoder_attention_mask, hidden_states.dtype, tgt_len=input_shape[-1])
+encoder_attention_mask = _prepare_4d_attention_mask(
+encoder_attention_mask, hidden_states.dtype, tgt_len=input_shape[-1]
+)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
...
@@ -23,6 +23,7 @@ import torch
from torch import Tensor, nn
from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
@@ -345,20 +346,6 @@ class TableTransformerConvModel(nn.Module):
return out, pos
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, target_len: Optional[int] = None):
"""
Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, target_seq_len, source_seq_len]`.
"""
batch_size, source_len = mask.size()
target_len = target_len if target_len is not None else source_len
expanded_mask = mask[:, None, None, :].expand(batch_size, 1, target_len, source_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
# Copied from transformers.models.detr.modeling_detr.DetrSinePositionEmbedding with Detr->TableTransformer
class TableTransformerSinePositionEmbedding(nn.Module):
"""
@@ -966,7 +953,7 @@ class TableTransformerEncoder(TableTransformerPreTrainedModel):
# expand attention_mask
if attention_mask is not None:
# [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
-attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@@ -1116,15 +1103,15 @@ class TableTransformerDecoder(TableTransformerPreTrainedModel):
if attention_mask is not None and combined_attention_mask is not None:
# [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
-combined_attention_mask = combined_attention_mask + _expand_mask(
-attention_mask, inputs_embeds.dtype, target_len=input_shape[-1]
+combined_attention_mask = combined_attention_mask + _prepare_4d_attention_mask(
+attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
-encoder_attention_mask = _expand_mask(
-encoder_attention_mask, inputs_embeds.dtype, target_len=input_shape[-1]
+encoder_attention_mask = _prepare_4d_attention_mask(
+encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)
# optional intermediate hidden states
...
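One call-site detail in the Table Transformer hunks above: its local helper took `target_len=`, while the shared helper uses `tgt_len=`, so the keyword argument is renamed as part of the swap. The composition itself is unchanged: the additive padding term is summed onto the already-built combined mask, roughly as in this sketch (again assuming a transformers version that ships the new module):

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

bsz, tgt_len = 1, 4
combined_attention_mask = torch.zeros(bsz, 1, tgt_len, tgt_len)  # stand-in for the existing combined term
attention_mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding

# Same composition as in TableTransformerDecoder: additive 4D masks are simply summed.
combined_attention_mask = combined_attention_mask + _prepare_4d_attention_mask(
    attention_mask, torch.float32, tgt_len=tgt_len
)
print(combined_attention_mask[0, 0])  # last column is dtype-min in every row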
@@ -22,6 +22,7 @@ import torch
from torch import nn
from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -225,39 +226,6 @@ def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor]
return input_tensor.mean(dim=dim)
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->TimeSeries
class TimeSeriesSinusoidalPositionalEmbedding(nn.Embedding):
"""This module produces sinusoidal positional embeddings of any length."""
@@ -915,7 +883,7 @@ class TimeSeriesTransformerEncoder(TimeSeriesTransformerPreTrainedModel):
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@@ -999,29 +967,6 @@ class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
attention_mask: Optional[torch.Tensor] = None,
@@ -1105,14 +1050,16 @@ class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel):
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
-attention_mask = self._prepare_decoder_attention_mask(
+attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+encoder_attention_mask = _prepare_4d_attention_mask(
+encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+)
hidden_states = self.value_embedding(inputs_embeds)
embed_pos = self.embed_positions(inputs_embeds.size(), past_key_values_length=self.config.context_length)
...
@@ -24,6 +24,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, logging, replace_return_docstrings
@@ -42,39 +43,6 @@ TROCR_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->TrOCR
class TrOCRLearnedPositionalEmbedding(nn.Embedding):
"""
@@ -515,27 +483,6 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
def set_input_embeddings(self, value):
self.embed_tokens = value
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids=None,
@@ -655,14 +602,16 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
input_shape = input.shape
-attention_mask = self._prepare_decoder_attention_mask(
+attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+encoder_attention_mask = _prepare_4d_attention_mask(
+encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+)
if self.gradient_checkpointing and self.training:
if use_cache:
...
@@ -25,6 +25,7 @@ from torch import nn
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
ModelOutput,
@@ -113,21 +114,6 @@ class VitsTextEncoderOutput(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, num_channels):
in_act = input_a + input_b
@@ -1150,7 +1136,7 @@ class VitsEncoder(nn.Module):
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
+attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
hidden_states = hidden_states * padding_mask
...
@@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...generation.logits_process import WhisperTimeStampLogitsProcessor
+from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -84,39 +85,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
return shifted_input_ids
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
@@ -1003,28 +971,6 @@ class WhisperDecoder(WhisperPreTrainedModel):
def set_input_embeddings(self, value):
self.embed_tokens = value
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids=None,
@@ -1120,7 +1066,7 @@ class WhisperDecoder(WhisperPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
-attention_mask = self._prepare_decoder_attention_mask(
+attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
...
@@ -24,6 +24,7 @@ import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel
from ...utils import (
@@ -46,21 +47,6 @@ XCLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# contrastive loss function, adapted from
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
@@ -729,24 +715,6 @@ class XCLIPEncoder(nn.Module):
)
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
class XCLIPTextTransformer(nn.Module):
def __init__(self, config: XCLIPTextConfig):
super().__init__()
@@ -787,11 +755,13 @@ class XCLIPTextTransformer(nn.Module):
# X_CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
-causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
+causal_attention_mask = _create_4d_causal_attention_mask(
+input_shape, hidden_states.dtype, device=hidden_states.device
+)
# expand attention_mask
if attention_mask is not None:
# [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
-attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
+attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
...
@@ -24,6 +24,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
@@ -127,39 +128,6 @@ XGLM_INPUTS_DOCSTRING = r"""
"""
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class XGLMSinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length."""
@@ -548,27 +516,6 @@ class XGLMModel(XGLMPreTrainedModel):
def set_input_embeddings(self, value):
self.embed_tokens = value
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
@add_start_docstrings_to_model_forward(XGLM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
@@ -624,14 +571,16 @@ class XGLMModel(XGLMPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
-attention_mask = self._prepare_decoder_attention_mask(
+attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+encoder_attention_mask = _prepare_4d_attention_mask(
+encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+)
hidden_states = inputs_embeds + self.embed_positions(position_ids, past_key_values_length)
hidden_states = nn.functional.dropout(hidden_states, p=float(self.dropout), training=self.training)
...
@@ -1567,6 +1567,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -1608,37 +1609,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
return shifted_input_ids
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
def _expand_mask(
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
class {{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(nn.Embedding): class {{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(nn.Embedding):
""" """
This module learns positional embeddings up to a fixed maximum size. This module learns positional embeddings up to a fixed maximum size.
...@@ -2280,7 +2250,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model ...@@ -2280,7 +2250,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
# expand attention_mask # expand attention_mask
if attention_mask is not None: if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
encoder_states = () if output_hidden_states else None encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
...@@ -2369,25 +2339,6 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model ...@@ -2369,25 +2339,6 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.embed_tokens = value self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
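The per-model method removed above is what a single call to the shared helper now covers; below is a minimal usage sketch, with the signature taken from the call site further down in this diff and the output shape inferred from the removed template helpers, so treat the annotations as illustrative rather than definitive.
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

bsz, tgt_len, past_key_values_length = 2, 4, 3
attention_mask = torch.ones(bsz, past_key_values_length + tgt_len, dtype=torch.long)
inputs_embeds = torch.zeros(bsz, tgt_len, 8)  # only dtype/device are read here

attention_mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask, (bsz, tgt_len), inputs_embeds, past_key_values_length
)
print(attention_mask_4d.shape)  # expected: [2, 1, 4, 7] = [bsz, 1, tgt_len, past + tgt_len]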
def forward( def forward(
self, self,
input_ids=None, input_ids=None,
...@@ -2491,12 +2442,12 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model ...@@ -2491,12 +2442,12 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length)
attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length)
# expand encoder attention mask # expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None: if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
encoder_attention_mask = _prepare_4d_attention_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
# embed positions # embed positions
positions = self.embed_positions(input_shape, past_key_values_length) positions = self.embed_positions(input_shape, past_key_values_length)
......
...@@ -39,149 +39,6 @@ if is_torch_available(): ...@@ -39,149 +39,6 @@ if is_torch_available():
LlamaModel, LlamaModel,
LlamaTokenizer, LlamaTokenizer,
) )
from transformers.models.llama.modeling_llama import AttnMaskConverter
@require_torch
class AttentionMaskTester(unittest.TestCase):
def check_non_causal(self, bsz, q_len, kv_len, mask_2d, mask_4d):
mask_indices = (mask_2d != 1)[:, None].broadcast_to((bsz, q_len, kv_len))
mask_4d_values = mask_4d[:, 0][mask_indices]
is_inf = mask_4d_values == -float("inf")
is_min = mask_4d_values == torch.finfo(mask_4d.dtype).min
assert torch.logical_or(is_inf, is_min).all()
def check_to_4d(self, mask_converter, q_len, kv_len, additional_mask=None, bsz=3):
mask_2d = torch.ones((bsz, kv_len), device=torch_device, dtype=torch.long)
if additional_mask is not None:
for bsz_idx, seq_idx in additional_mask:
mask_2d[bsz_idx, seq_idx] = 0
mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len)
assert mask_4d.shape == (bsz, 1, q_len, kv_len)
context = mask_converter.sliding_window
if mask_converter.is_causal and context is None:
# k * (k+1) / 2 tokens are masked in triangular masks
num_tokens_masked = bsz * (q_len * (q_len - 1) // 2)
if 0 not in mask_2d:
assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
if 0 in mask_2d:
# at least causal mask + maybe more
assert (mask_4d != 0).sum().cpu().item() >= num_tokens_masked
self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
elif not mask_converter.is_causal and context is None:
if 0 not in mask_2d:
assert (mask_4d != 0).sum().cpu().item() == 0
if 0 in mask_2d:
self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
elif mask_converter.is_causal and context is not None:
# k * (k+1) / 2 tokens are masked in triangular masks
num_tokens_masked = (q_len * (q_len - 1) // 2) + self.compute_num_context_mask(kv_len, context, q_len)
num_tokens_masked = bsz * num_tokens_masked
if 0 not in mask_2d:
assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
if 0 in mask_2d:
# at least causal mask + maybe more
assert (mask_4d != 0).sum().cpu().item() >= num_tokens_masked
self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
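check_to_4d relies on the count q_len * (q_len - 1) // 2 for a fully-visible causal mask; a quick worked instance (illustrative, not part of the test) for the q_len=3, kv_len=7 case used in test_2d_to_4d_causal below:
# The three query rows sit at the last three key positions,
# so they mask 2, 1 and 0 future keys respectively.
q_len, kv_len, bsz = 3, 7, 3
per_example = q_len * (q_len - 1) // 2   # 3 masked positions
assert per_example == 2 + 1 + 0
num_tokens_masked = bsz * per_example    # 9 non-zero entries expected in mask_4d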
def check_to_causal(self, mask_converter, q_len, kv_len, bsz=3):
mask_4d = mask_converter.to_causal_4d(bsz, query_length=q_len, key_value_length=kv_len, device=torch_device)
if q_len == 1 and mask_converter.sliding_window is None:
# no causal mask if q_len is 1
assert mask_4d is None
return
context = mask_converter.sliding_window
if mask_converter.is_causal and context is None:
# k * (k+1) / 2 tokens are masked in triangular masks
num_tokens_masked = bsz * (q_len * (q_len - 1) // 2)
assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
elif not mask_converter.is_causal and context is None:
assert (mask_4d != 0).sum().cpu().item() == 0
elif mask_converter.is_causal and context is not None:
# k * (k+1) / 2 tokens are masked in triangular masks
num_tokens_masked = (q_len * (q_len - 1) // 2) + self.compute_num_context_mask(kv_len, context, q_len)
num_tokens_masked = bsz * num_tokens_masked
assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
def compute_num_context_mask(self, kv_len, context, q_len):
# This function computes the number of attention positions that are additionally
# masked by the sliding window
c_mask_len = kv_len - context
num_mask_triangle = c_mask_len * (c_mask_len + 1) // 2
cut_mask_len = max(c_mask_len - q_len, 0)
num_cut_mask = cut_mask_len * (cut_mask_len + 1) // 2
return num_mask_triangle - num_cut_mask
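The arithmetic is easiest to follow with concrete numbers; the sketch below (not part of the test) traces compute_num_context_mask for the kv_len=7, sliding_window=3, q_len=7 case exercised by test_causal_mask_sliding further down.
# compute_num_context_mask(kv_len=7, context=3, q_len=7), step by step:
c_mask_len = 7 - 3                        # 4 key positions fall outside the window
num_mask_triangle = 4 * (4 + 1) // 2      # 10 window-masked (query, key) pairs
cut_mask_len = max(4 - 7, 0)              # 0: nothing to trim when q_len >= c_mask_len
num_cut_mask = 0 * (0 + 1) // 2           # 0
assert num_mask_triangle - num_cut_mask == 10
# check_to_causal adds the causal triangle 7 * 6 // 2 = 21,
# so 31 positions are masked per batch element (93 for bsz=3).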
def test_2d_to_4d_causal(self):
mask_converter = AttnMaskConverter(is_causal=True)
# auto-regressive use case
self.check_to_4d(mask_converter, q_len=1, kv_len=7)
# special auto-regressive case
self.check_to_4d(mask_converter, q_len=3, kv_len=7)
# non auto-regressive case
self.check_to_4d(mask_converter, q_len=7, kv_len=7)
# same with extra attention masks
self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
def test_2d_to_4d(self):
torch.ones((3, 7), device=torch_device, dtype=torch.long)
mask_converter = AttnMaskConverter(is_causal=False)
# non auto-regressive case
self.check_to_4d(mask_converter, q_len=7, kv_len=7)
# same with extra attention masks
self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
def test_2d_to_4d_causal_sliding(self):
torch.ones((3, 7), device=torch_device, dtype=torch.long)
mask_converter = AttnMaskConverter(is_causal=True, sliding_window=5)
# auto-regressive use case
self.check_to_4d(mask_converter, q_len=1, kv_len=7)
# special auto-regressive case
self.check_to_4d(mask_converter, q_len=3, kv_len=7)
# non auto-regressive case
self.check_to_4d(mask_converter, q_len=7, kv_len=7)
# same with extra attention masks
self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
def test_causal_mask(self):
mask_converter = AttnMaskConverter(is_causal=True)
# auto-regressive use case
self.check_to_causal(mask_converter, q_len=1, kv_len=7)
# special auto-regressive case
self.check_to_causal(mask_converter, q_len=3, kv_len=7)
# non auto-regressive case
self.check_to_causal(mask_converter, q_len=7, kv_len=7)
def test_causal_mask_sliding(self):
mask_converter = AttnMaskConverter(is_causal=True, sliding_window=3)
# auto-regressive use case
self.check_to_causal(mask_converter, q_len=1, kv_len=7)
# special auto-regressive case
self.check_to_causal(mask_converter, q_len=3, kv_len=7)
# non auto-regressive case
self.check_to_causal(mask_converter, q_len=7, kv_len=7)
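For orientation, a standalone usage sketch of the converter these tests exercise; the import path follows the one at the top of this test diff and the signatures follow the calls in check_to_4d / check_to_causal, so adjust if the module layout differs.
import torch
from transformers.models.llama.modeling_llama import AttnMaskConverter

converter = AttnMaskConverter(is_causal=True, sliding_window=3)

# 2D padding mask -> additive 4D mask of shape [bsz, 1, q_len, kv_len]
mask_2d = torch.ones((2, 7), dtype=torch.long)
mask_2d[0, 2] = 0  # mark one padded position in the first example
mask_4d = converter.to_4d(mask_2d, query_length=7, key_value_length=7)

# pure causal (+ sliding-window) mask, built without a 2D padding mask
causal_4d = converter.to_causal_4d(2, query_length=7, key_value_length=7, device="cpu")
print(mask_4d.shape, causal_4d.shape)  # both expected to be [2, 1, 7, 7]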
class LlamaModelTester: class LlamaModelTester:
......
...@@ -329,6 +329,7 @@ IGNORE_SUBMODULES = [ ...@@ -329,6 +329,7 @@ IGNORE_SUBMODULES = [
"convert_pytorch_checkpoint_to_tf2", "convert_pytorch_checkpoint_to_tf2",
"modeling_flax_pytorch_utils", "modeling_flax_pytorch_utils",
"models.esm.openfold_utils", "models.esm.openfold_utils",
"modeling_attn_mask_utils",
] ]
......