Unverified commit ac589375, authored by Patrick von Platen, committed by GitHub

[Attention Mask] Refactor all encoder-decoder attention mask (#27086)



* [FA2 Bart] Add FA2 to all Bart-like

* better

* Refactor attention mask

* remove all customized attention logic

* format

* mass rename

* replace _expand_mask

* replace _expand_mask

* mass rename

* add pt files

* mass replace & rename

* mass replace & rename

* mass replace & rename

* mass replace & rename

* Update src/transformers/models/idefics/modeling_idefics.py

* fix more

* clean more

* fix more

* make style

* fix again

* finish

* finish

* finish

* finish

* finish

* finish

* finish

* finish

* finish

* finish

* Apply suggestions from code review

* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* small fix mistral

* finish

* finish

* finish

* finish

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 29c74f58
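The pattern repeated across every file below: each model previously carried a private copy of `_make_causal_mask` / `_expand_mask` (or an inlined `AttnMaskConverter`); those copies are deleted and the call sites are pointed at the shared helpers in `transformers.modeling_attn_mask_utils`. A minimal usage sketch of the two helpers, assuming a transformers version that includes this commit:

import torch
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_attention_mask,
    _prepare_4d_causal_attention_mask,
)

batch_size, seq_len, hidden = 2, 5, 8
inputs_embeds = torch.zeros(batch_size, seq_len, hidden)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
attention_mask[0, -2:] = 0  # pad the tail of the first sequence

# Encoder/cross-attention style: [bsz, seq_len] -> additive [bsz, 1, tgt, src] bias.
enc_bias = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

# Decoder self-attention style: padding mask combined with a causal mask.
dec_bias = _prepare_4d_causal_attention_mask(
    attention_mask, (batch_size, seq_len), inputs_embeds, past_key_values_length=0
)
print(enc_bias.shape, dec_bias.shape)  # torch.Size([2, 1, 5, 5]) twice

Both helpers return float biases that are added to the raw attention scores: 0.0 keeps a position, a large negative value effectively removes it from the softmax.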
@@ -21,6 +21,7 @@ import torch
 from torch import nn
 from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -230,39 +231,6 @@ def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.
     return -input.log_prob(target)


-# Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(
-    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
-):
-    """
-    Make causal mask used for bi-directional self-attention.
-    """
-    bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
-    mask_cond = torch.arange(mask.size(-1), device=device)
-    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
-    mask = mask.to(dtype)
-
-    if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
-    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
-
-
-# Copied from transformers.models.bart.modeling_bart._expand_mask
-def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
-    """
-    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
-    """
-    bsz, src_len = mask.size()
-    tgt_len = tgt_len if tgt_len is not None else src_len
-
-    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-
-    inverted_mask = 1.0 - expanded_mask
-
-    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
-
-
 # Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Informer
 class InformerSinusoidalPositionalEmbedding(nn.Embedding):
     """This module produces sinusoidal positional embeddings of any length."""
@@ -1184,7 +1152,7 @@ class InformerEncoder(InformerPreTrainedModel):
         # expand attention_mask
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

         encoder_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
@@ -1274,29 +1242,6 @@ class InformerDecoder(InformerPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

-    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
-        # create causal mask
-        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        combined_attention_mask = None
-        if input_shape[-1] > 1:
-            combined_attention_mask = _make_causal_mask(
-                input_shape,
-                inputs_embeds.dtype,
-                device=inputs_embeds.device,
-                past_key_values_length=past_key_values_length,
-            )
-
-        if attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
-                inputs_embeds.device
-            )
-            combined_attention_mask = (
-                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
-            )
-
-        return combined_attention_mask
-
     def forward(
         self,
         attention_mask: Optional[torch.Tensor] = None,
@@ -1380,14 +1325,16 @@
         # past_key_values_length
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

-        attention_mask = self._prepare_decoder_attention_mask(
+        attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
         )

         # expand encoder attention mask
         if encoder_hidden_states is not None and encoder_attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            encoder_attention_mask = _prepare_4d_attention_mask(
+                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+            )

         hidden_states = self.value_embedding(inputs_embeds)
         embed_pos = self.embed_positions(inputs_embeds.size(), past_key_values_length=self.config.context_length)
...
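For reference, here is what the deleted `_make_causal_mask` computed, reproduced standalone from the removed lines above. The mask is an additive bias: 0.0 where a query may attend, `torch.finfo(dtype).min` where it may not, and cached past key/value positions are always attendable.

import torch

# Standalone reproduction of the deleted helper's logic (illustration only).
def make_causal_bias(tgt_len: int, past: int, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, dtype=dtype)
    cond = torch.arange(tgt_len)
    # zero out positions at or below the diagonal: query i may see keys 0..i
    mask.masked_fill_(cond < (cond + 1).view(tgt_len, 1), 0)
    if past > 0:
        # cached past tokens are visible to every query
        mask = torch.cat([torch.zeros(tgt_len, past, dtype=dtype), mask], dim=-1)
    return mask[None, None, :, :]  # [1, 1, tgt_len, tgt_len + past]

print((make_causal_bias(3, past=2) == 0).int()[0, 0])
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.int32)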
@@ -26,6 +26,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     Seq2SeqLMOutput,
@@ -74,22 +75,7 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
     return shifted_input_ids


-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
-    """
-    Make causal mask used for bi-directional self-attention.
-    """
-    bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
-    mask_cond = torch.arange(mask.size(-1))
-    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
-    mask = mask.to(dtype)
-
-    if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
-    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
-
-
-def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+def _prepare_4d_attention_mask_inverted(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
     """
     Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
     """
@@ -1838,7 +1824,7 @@ class LEDEncoder(LEDPreTrainedModel):
         # convert attention_mask to float
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, seq_len]; 1 -> 0.0; 0 -> "-inf"
-            attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)[:, 0, 0, :]
+            attention_mask = _prepare_4d_attention_mask_inverted(attention_mask, inputs_embeds.dtype)[:, 0, 0, :]

         # get masking tensors
         is_index_masked = attention_mask < 0
@@ -2077,20 +2063,22 @@
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         combined_attention_mask = None
         if input_shape[-1] > 1:
-            combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(self.device)
+            combined_attention_mask = _create_4d_causal_attention_mask(
+                input_shape, inputs_embeds.dtype, inputs_embeds.device, past_key_values_length=past_key_values_length
+            )

         if attention_mask is not None and combined_attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            combined_attention_mask = combined_attention_mask + _expand_mask(
+            combined_attention_mask = combined_attention_mask + _prepare_4d_attention_mask_inverted(
                 attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
             )

         # expand encoder attention mask
         if encoder_hidden_states is not None and encoder_attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            encoder_attention_mask = _prepare_4d_attention_mask_inverted(
+                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+            )

         # embed positions
         positions = self.embed_positions(input_shape, past_key_values_length)
...
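LED is the one model here that keeps a private helper: its Longformer-style encoder does not consume the 4D bias directly but slices it back to a 2D float row whose sign encodes the token type (negative = masked/padded, zero = local attention, positive = global attention), so the old `_expand_mask` is merely renamed to `_prepare_4d_attention_mask_inverted` rather than replaced. A small sketch of that slicing, assuming the renamed helper keeps the deleted `_expand_mask` behavior:

import torch

mask_2d = torch.tensor([[1, 1, 0]])  # third token is padding
dtype = torch.float32

# the same expansion the LED helper performs
expanded = mask_2d[:, None, None, :].expand(1, 1, 3, 3).to(dtype)
inverted = 1.0 - expanded
bias_4d = inverted.masked_fill(inverted.bool(), torch.finfo(dtype).min)

row = bias_4d[:, 0, 0, :]  # back to [bsz, seq_len]: 0.0 attended, very negative padded
print(row < 0)             # tensor([[False, False,  True]])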
@@ -29,6 +29,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import ALL_LAYERNORM_LAYERS
@@ -66,163 +67,22 @@ def _get_unpad_data(attention_mask):
 def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
     warnings.warn(
-        "Calling `transformers.models.llama.modeling_llama._expand_mask` is deprecated and will be removed in v4.37. Use `transformers.models.llama.modeling_llama.AttnMaskConverter._expand_mask"
+        "Calling `transformers.models.llama.modeling_llama._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils.AttentionMaskConverter._prepare_4d_attention_mask"
     )
-    return AttnMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
+    return AttentionMaskConverter._prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)


 def _make_causal_mask(
     input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
 ):
     warnings.warn(
-        "Calling `transformers.models.llama.modeling_llama._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.llama.modeling_llama.AttnMaskConverter._make_causal_mask"
+        "Calling `transformers.models.llama.modeling_llama._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.llama.modeling_llama.AttentionMaskConverter._make_causal_mask"
     )
-    return AttnMaskConverter._make_causal_mask(
+    return AttentionMaskConverter._make_causal_mask(
         input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
     )


-class AttnMaskConverter:
-    """
-    A utility attention mask class that allows:
-        - Create a causal 4d mask
-        - Create a causal 4d mask with slided window
-        - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
-          key_value_length) that can be multiplied with attention scores
-
-    Parameters:
-        is_causal (`bool`):
-            Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
-
-        sliding_window (`int`, *optional*):
-            Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
-    """
-
-    def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
-        self.is_causal = is_causal
-        self.sliding_window = sliding_window
-
-    def to_causal_4d(
-        self,
-        batch_size: int,
-        query_length: int,
-        key_value_length: int,
-        dtype: torch.dtype = torch.float32,
-        device: Union[torch.device, "str"] = "cpu",
-    ) -> torch.Tensor:
-        """
-        Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
-        bias to upper right hand triangular matrix (causal mask).
-        """
-        if not self.is_causal:
-            raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
-
-        # If shape is not cached, create a new causal mask and cache it
-        input_shape = (batch_size, query_length)
-        past_key_values_length = key_value_length - query_length
-
-        # create causal mask
-        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        causal_4d_mask = None
-        if input_shape[-1] > 1 or self.sliding_window is not None:
-            past_key_values_length = key_value_length - query_length
-            causal_4d_mask = self._make_causal_mask(
-                input_shape,
-                dtype,
-                device=device,
-                past_key_values_length=past_key_values_length,
-                sliding_window=self.sliding_window,
-            )
-
-        return causal_4d_mask
-
-    def to_4d(
-        self,
-        attention_mask_2d: torch.Tensor,
-        query_length: int,
-        key_value_length: Optional[int] = None,
-        dtype: torch.dtype = torch.float32,
-    ) -> torch.Tensor:
-        """
-        Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
-        key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
-        causal, a causal mask will be added.
-        """
-        input_shape = (attention_mask_2d.shape[0], query_length)
-
-        # create causal mask
-        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        causal_4d_mask = None
-        if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
-            if key_value_length is None:
-                raise ValueError(
-                    "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
-                )
-
-            past_key_values_length = key_value_length - query_length
-            causal_4d_mask = self._make_causal_mask(
-                input_shape,
-                dtype,
-                device=attention_mask_2d.device,
-                past_key_values_length=past_key_values_length,
-                sliding_window=self.sliding_window,
-            )
-        elif self.sliding_window is not None:
-            raise NotImplementedError("Sliding window is currently only implemented for causal masking")
-
-        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
-            attention_mask_2d.device
-        )
-        expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask
-
-        return expanded_4d_mask
-
-    @staticmethod
-    def _make_causal_mask(
-        input_ids_shape: torch.Size,
-        dtype: torch.dtype,
-        device: torch.device,
-        past_key_values_length: int = 0,
-        sliding_window: Optional[int] = None,
-    ):
-        """
-        Make causal mask used for bi-directional self-attention.
-        """
-        bsz, tgt_len = input_ids_shape
-        mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
-        mask_cond = torch.arange(mask.size(-1), device=device)
-        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
-        mask = mask.to(dtype)
-
-        if past_key_values_length > 0:
-            mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
-
-        # add lower triangular sliding window mask if necessary
-        if sliding_window is not None:
-            diagonal = past_key_values_length - sliding_window + 1
-
-            context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
-            mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
-
-        return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
-
-    @staticmethod
-    def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
-        """
-        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
-        """
-        bsz, src_len = mask.size()
-        tgt_len = tgt_len if tgt_len is not None else src_len
-
-        expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-
-        inverted_mask = 1.0 - expanded_mask
-
-        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
-
-
 class LlamaRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """
@@ -933,8 +793,6 @@ class LlamaModel(LlamaPreTrainedModel):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size

-        self.attn_mask_converter = AttnMaskConverter(is_causal=True)
-
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -998,16 +856,10 @@
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
         else:
-            key_value_length = seq_length + past_key_values_length
             # 4d mask is passed through the layers
-            if attention_mask is not None:
-                attention_mask = self.attn_mask_converter.to_4d(
-                    attention_mask, seq_length, key_value_length, dtype=inputs_embeds.dtype
-                )
-            else:
-                attention_mask = self.attn_mask_converter.to_causal_4d(
-                    batch_size, seq_length, key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
-                )
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+            )

         # embed positions
         hidden_states = inputs_embeds
...
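The Llama hunks also show the migration strategy for public module-level names: `_expand_mask` and `_make_causal_mask` stay importable for one deprecation cycle but warn and delegate to the shared converter. The pattern in isolation (with a hypothetical `_new_impl` standing in for the shared function; not a transformers API):

import warnings

def _new_impl(x):
    return x + 1

def _old_name(x):
    warnings.warn(
        "Calling `_old_name` is deprecated and will be removed in a future version. "
        "Use `_new_impl` instead.",
        FutureWarning,
    )
    return _new_impl(x)

assert _old_name(1) == 2  # the old name still works, but emits a FutureWarning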
@@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss
 from ...activations import ACT2FN
 from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -71,39 +72,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
     return shifted_input_ids


-# Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(
-    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
-):
-    """
-    Make causal mask used for bi-directional self-attention.
-    """
-    bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
-    mask_cond = torch.arange(mask.size(-1), device=device)
-    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
-    mask = mask.to(dtype)
-
-    if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
-    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
-
-
-# Copied from transformers.models.bart.modeling_bart._expand_mask
-def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
-    """
-    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
-    """
-    bsz, src_len = mask.size()
-    tgt_len = tgt_len if tgt_len is not None else src_len
-
-    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-
-    inverted_mask = 1.0 - expanded_mask
-
-    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
-
-
 def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
     """
     Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
@@ -790,7 +758,7 @@ class M2M100Encoder(M2M100PreTrainedModel):
         # expand attention_mask
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

         encoder_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
@@ -991,25 +959,16 @@ class M2M100Decoder(M2M100PreTrainedModel):
         # create causal mask
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        combined_attention_mask = None
-        if input_shape[-1] > 1:
-            combined_attention_mask = _make_causal_mask(
-                input_shape,
-                inputs_embeds.dtype,
-                device=inputs_embeds.device,
-                past_key_values_length=past_key_values_length,
-            )
-
-        if attention_mask is not None and combined_attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            combined_attention_mask = combined_attention_mask + _expand_mask(
-                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
-            )
+        combined_attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask, input_shape, inputs_embeds, past_key_values_length
+        )

         # expand encoder attention mask
         if encoder_hidden_states is not None and encoder_attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            encoder_attention_mask = _prepare_4d_attention_mask(
+                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+            )

         # embed positions
         positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length)
...
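The M2M100 decoder hunk collapses the old two-step construction (build a causal bias, then add the expanded padding bias) into a single helper call. The combination the helper performs internally is plain addition, so a position is dropped if either source masks it; a sketch of that semantics on a toy example:

import torch

neg = torch.finfo(torch.float32).min
tgt = 3

# causal part: query i may see keys 0..i
causal = torch.full((tgt, tgt), neg)
causal.masked_fill_(torch.arange(tgt) < (torch.arange(tgt) + 1).view(tgt, 1), 0)

# padding part: last token of the (only) sequence is padding
padding = torch.tensor([[1, 1, 0]], dtype=torch.float32)
pad_bias = (1.0 - padding)[:, None, None, :].expand(1, 1, tgt, tgt).clone()
pad_bias = pad_bias.masked_fill(pad_bias.bool(), neg)

combined = causal[None, None] + pad_bias
print((combined == 0).int()[0, 0])
# tensor([[1, 0, 0],
#         [1, 1, 0],
#         [1, 1, 0]], dtype=torch.int32)  # attendable = causal AND not-padded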
@@ -26,6 +26,7 @@ from torch import nn
 from torch.nn import CrossEntropyLoss
 from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -73,39 +74,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
     return shifted_input_ids


-# Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(
-    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
-):
-    """
-    Make causal mask used for bi-directional self-attention.
-    """
-    bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
-    mask_cond = torch.arange(mask.size(-1), device=device)
-    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
-    mask = mask.to(dtype)
-
-    if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
-    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
-
-
-# Copied from transformers.models.bart.modeling_bart._expand_mask
-def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
-    """
-    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
-    """
-    bsz, src_len = mask.size()
-    tgt_len = tgt_len if tgt_len is not None else src_len
-
-    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-
-    inverted_mask = 1.0 - expanded_mask
-
-    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
-
-
 class MarianSinusoidalPositionalEmbedding(nn.Embedding):
     """This module produces sinusoidal positional embeddings of any length."""
@@ -760,7 +728,7 @@ class MarianEncoder(MarianPreTrainedModel):
         # expand attention_mask
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

         encoder_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
@@ -851,30 +819,6 @@ class MarianDecoder(MarianPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value

-    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
-    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
-        # create causal mask
-        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        combined_attention_mask = None
-        if input_shape[-1] > 1:
-            combined_attention_mask = _make_causal_mask(
-                input_shape,
-                inputs_embeds.dtype,
-                device=inputs_embeds.device,
-                past_key_values_length=past_key_values_length,
-            )
-
-        if attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
-                inputs_embeds.device
-            )
-            combined_attention_mask = (
-                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
-            )
-
-        return combined_attention_mask
-
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -979,14 +923,16 @@
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

-        attention_mask = self._prepare_decoder_attention_mask(
+        attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
         )

         # expand encoder attention mask
         if encoder_hidden_states is not None and encoder_attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            encoder_attention_mask = _prepare_4d_attention_mask(
+                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+            )

         # embed positions
         positions = self.embed_positions(input_shape, past_key_values_length)
...
@@ -245,21 +245,6 @@ class Mask2FormerForUniversalSegmentationOutput(ModelOutput):
     attentions: Optional[Tuple[torch.FloatTensor]] = None


-# Copied from transformers.models.detr.modeling_detr._expand_mask
-def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, target_len: Optional[int] = None):
-    """
-    Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, target_seq_len, source_seq_len]`.
-    """
-    batch_size, source_len = mask.size()
-    target_len = target_len if target_len is not None else source_len
-
-    expanded_mask = mask[:, None, None, :].expand(batch_size, 1, target_len, source_len).to(dtype)
-
-    inverted_mask = 1.0 - expanded_mask
-
-    return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
-
-
 # Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
 def sample_point(
     input_features: torch.Tensor, point_coordinates: torch.Tensor, add_dim=False, **kwargs
...
@@ -25,6 +25,7 @@ from torch import Tensor, nn
 from ... import AutoBackbone
 from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_outputs import BaseModelOutputWithCrossAttentions
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
@@ -702,21 +703,6 @@ class DetrDecoderLayer(nn.Module):
         return outputs


-# Copied from transformers.models.detr.modeling_detr._expand_mask
-def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, target_len: Optional[int] = None):
-    """
-    Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, target_seq_len, source_seq_len]`.
-    """
-    batch_size, source_len = mask.size()
-    target_len = target_len if target_len is not None else source_len
-
-    expanded_mask = mask[:, None, None, :].expand(batch_size, 1, target_len, source_len).to(dtype)
-
-    inverted_mask = 1.0 - expanded_mask
-
-    return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
-
-
 class DetrDecoder(nn.Module):
     """
     Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetrDecoderLayer`].
@@ -817,18 +803,12 @@ class DetrDecoder(nn.Module):
         hidden_states = inputs_embeds
         input_shape = inputs_embeds.size()[:-1]

-        combined_attention_mask = None
-
-        if attention_mask is not None and combined_attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            combined_attention_mask = combined_attention_mask + _expand_mask(
-                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
-            )
-
         # expand encoder attention mask
         if encoder_hidden_states is not None and encoder_attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            encoder_attention_mask = _prepare_4d_attention_mask(
+                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+            )

         # optional intermediate hidden states
         intermediate = () if self.config.auxiliary_loss else None
@@ -851,7 +831,7 @@ class DetrDecoder(nn.Module):
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
-                    combined_attention_mask,
+                    None,
                     encoder_hidden_states,
                     encoder_attention_mask,
                     None,
@@ -860,7 +840,7 @@
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
-                    attention_mask=combined_attention_mask,
+                    attention_mask=None,
                     object_queries=object_queries,
                     query_position_embeddings=query_position_embeddings,
                     encoder_hidden_states=encoder_hidden_states,
...
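Note what changes in this decoder compared with the text models: the self-attention mask is dropped entirely (passed as `None`), because the old `combined_attention_mask` was initialized to `None` and never set, so the dead branch is removed. DETR-style object queries are a fixed, unpadded set that attends bidirectionally; only the cross-attention mask over image features is still expanded, with different target and source lengths. A minimal shape check using the shared helper:

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

num_queries, num_pixels = 4, 6
pixel_mask = torch.ones(1, num_pixels, dtype=torch.long)  # which image tokens are real

cross_bias = _prepare_4d_attention_mask(pixel_mask, torch.float32, tgt_len=num_queries)
print(cross_bias.shape)  # torch.Size([1, 1, 4, 6]): queries x image tokens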
@@ -23,6 +23,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -78,39 +79,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
     return prev_output_tokens


-# Copied from transformers.models.bart.modeling_bart._make_causal_mask
-def _make_causal_mask(
-    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
-):
-    """
-    Make causal mask used for bi-directional self-attention.
-    """
-    bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
-    mask_cond = torch.arange(mask.size(-1), device=device)
-    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
-    mask = mask.to(dtype)
-
-    if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
-    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
-
-
-# Copied from transformers.models.bart.modeling_bart._expand_mask
-def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
-    """
-    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
-    """
-    bsz, src_len = mask.size()
-    tgt_len = tgt_len if tgt_len is not None else src_len
-
-    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-
-    inverted_mask = 1.0 - expanded_mask
-
-    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
-
-
 # Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MBart
 class MBartLearnedPositionalEmbedding(nn.Embedding):
     """
@@ -798,7 +766,7 @@ class MBartEncoder(MBartPreTrainedModel):
         # expand attention_mask
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

         encoder_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
@@ -896,30 +864,6 @@ class MBartDecoder(MBartPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value

-    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
-    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
-        # create causal mask
-        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        combined_attention_mask = None
-        if input_shape[-1] > 1:
-            combined_attention_mask = _make_causal_mask(
-                input_shape,
-                inputs_embeds.dtype,
-                device=inputs_embeds.device,
-                past_key_values_length=past_key_values_length,
-            )
-
-        if attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
-                inputs_embeds.device
-            )
-            combined_attention_mask = (
-                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
-            )
-
-        return combined_attention_mask
-
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -1026,14 +970,16 @@
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

-        attention_mask = self._prepare_decoder_attention_mask(
+        attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
         )

         # expand encoder attention mask
         if encoder_hidden_states is not None and encoder_attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            encoder_attention_mask = _prepare_4d_attention_mask(
+                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+            )

         # embed positions
         positions = self.embed_positions(input, past_key_values_length)
...
@@ -30,6 +30,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
+from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
@@ -54,148 +55,6 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "MistralConfig"


-# Copied from transformers.models.llama.modeling_llama.AttnMaskConverter
-class AttnMaskConverter:
-    """
-    A utility attention mask class that allows:
-        - Create a causal 4d mask
-        - Create a causal 4d mask with slided window
-        - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
-          key_value_length) that can be multiplied with attention scores
-
-    Parameters:
-        is_causal (`bool`):
-            Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
-
-        sliding_window (`int`, *optional*):
-            Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
-    """
-
-    def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
-        self.is_causal = is_causal
-        self.sliding_window = sliding_window
-
-    def to_causal_4d(
-        self,
-        batch_size: int,
-        query_length: int,
-        key_value_length: int,
-        dtype: torch.dtype = torch.float32,
-        device: Union[torch.device, "str"] = "cpu",
-    ) -> torch.Tensor:
-        """
-        Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
-        bias to upper right hand triangular matrix (causal mask).
-        """
-        if not self.is_causal:
-            raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
-
-        # If shape is not cached, create a new causal mask and cache it
-        input_shape = (batch_size, query_length)
-        past_key_values_length = key_value_length - query_length
-
-        # create causal mask
-        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        causal_4d_mask = None
-        if input_shape[-1] > 1 or self.sliding_window is not None:
-            past_key_values_length = key_value_length - query_length
-            causal_4d_mask = self._make_causal_mask(
-                input_shape,
-                dtype,
-                device=device,
-                past_key_values_length=past_key_values_length,
-                sliding_window=self.sliding_window,
-            )
-
-        return causal_4d_mask
-
-    def to_4d(
-        self,
-        attention_mask_2d: torch.Tensor,
-        query_length: int,
-        key_value_length: Optional[int] = None,
-        dtype: torch.dtype = torch.float32,
-    ) -> torch.Tensor:
-        """
-        Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
-        key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
-        causal, a causal mask will be added.
-        """
-        input_shape = (attention_mask_2d.shape[0], query_length)
-
-        # create causal mask
-        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        causal_4d_mask = None
-        if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
-            if key_value_length is None:
-                raise ValueError(
-                    "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
-                )
-
-            past_key_values_length = key_value_length - query_length
-            causal_4d_mask = self._make_causal_mask(
-                input_shape,
-                dtype,
-                device=attention_mask_2d.device,
-                past_key_values_length=past_key_values_length,
-                sliding_window=self.sliding_window,
-            )
-        elif self.sliding_window is not None:
-            raise NotImplementedError("Sliding window is currently only implemented for causal masking")
-
-        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
-            attention_mask_2d.device
-        )
-        expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask
-
-        return expanded_4d_mask
-
-    @staticmethod
-    def _make_causal_mask(
-        input_ids_shape: torch.Size,
-        dtype: torch.dtype,
-        device: torch.device,
-        past_key_values_length: int = 0,
-        sliding_window: Optional[int] = None,
-    ):
-        """
-        Make causal mask used for bi-directional self-attention.
-        """
-        bsz, tgt_len = input_ids_shape
-        mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
-        mask_cond = torch.arange(mask.size(-1), device=device)
-        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
-        mask = mask.to(dtype)
-
-        if past_key_values_length > 0:
-            mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
-
-        # add lower triangular sliding window mask if necessary
-        if sliding_window is not None:
-            diagonal = past_key_values_length - sliding_window + 1
-
-            context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
-            mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
-
-        return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
-
-    @staticmethod
-    def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
-        """
-        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
-        """
-        bsz, src_len = mask.size()
-        tgt_len = tgt_len if tgt_len is not None else src_len
-
-        expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-
-        inverted_mask = 1.0 - expanded_mask
-
-        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
-
-
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
@@ -209,21 +68,6 @@ def _get_unpad_data(attention_mask):
     )


-# Copied from transformers.models.bart.modeling_bart._expand_mask
-def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
-    """
-    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
-    """
-    bsz, src_len = mask.size()
-    tgt_len = tgt_len if tgt_len is not None else src_len
-
-    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-
-    inverted_mask = 1.0 - expanded_mask
-
-    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
-
-
 # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
 class MistralRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
@@ -908,8 +752,6 @@ class MistralModel(MistralPreTrainedModel):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size

-        self.attn_mask_converter = AttnMaskConverter(is_causal=True, sliding_window=config.sliding_window)
-
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         self.layers = nn.ModuleList([MistralDecoderLayer(config) for _ in range(config.num_hidden_layers)])
         self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -978,7 +820,6 @@
             attention_mask is not None
             and hasattr(self.config, "_flash_attn_2_enabled")
             and self.config._flash_attn_2_enabled
-            and past_key_values is not None
         ):
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
             if is_padding_right:
@@ -992,16 +833,14 @@
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
         else:
-            key_value_length = seq_length + past_key_values_length
             # 4d mask is passed through the layers
-            if attention_mask is not None:
-                attention_mask = self.attn_mask_converter.to_4d(
-                    attention_mask, seq_length, key_value_length, dtype=inputs_embeds.dtype
-                )
-            else:
-                attention_mask = self.attn_mask_converter.to_causal_4d(
-                    batch_size, seq_length, key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
-                )
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+                sliding_window=self.config.sliding_window,
+            )

         hidden_states = inputs_embeds
...
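Mistral is the one decoder-only model here whose mask is not purely causal: tokens further back than `config.sliding_window` are masked as well, which is why its call passes `sliding_window=self.config.sliding_window` to the shared helper. A standalone reconstruction of the sliding-window logic from the deleted `AttnMaskConverter._make_causal_mask` above (no cache, for brevity):

import torch

def sliding_causal(tgt_len: int, window: int, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    neg = torch.finfo(dtype).min
    mask = torch.full((tgt_len, tgt_len), neg, dtype=dtype)
    cond = torch.arange(tgt_len)
    mask.masked_fill_(cond < (cond + 1).view(tgt_len, 1), 0)  # causal part
    # mask keys older than the window: key j is visible to query i only if i - window < j <= i
    context = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=1 - window)
    mask.masked_fill_(context.bool(), neg)
    return mask

print((sliding_causal(5, window=2) == 0).int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 1, 1, 0, 0],
#         [0, 0, 1, 1, 0],
#         [0, 0, 0, 1, 1]], dtype=torch.int32)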
...@@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss ...@@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from torch.nn import functional as F from torch.nn import functional as F
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions, CausalLMOutputWithCrossAttentions,
...@@ -55,38 +56,6 @@ MPT_PRETRAINED_MODEL_ARCHIVE_LIST = [ ...@@ -55,38 +56,6 @@ MPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
] ]
# Copied from transformers.models.bloom.modeling_bloom._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
"""
Make causal mask used for self-attention.
"""
batch_size, target_length = input_ids_shape
mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
# ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
seq_ids = torch.arange(target_length, device=device)
mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
if past_key_values_length > 0:
mask[:, :past_key_values_length] = False
expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
return expanded_mask
# Copied from transformers.models.bloom.modeling_bloom._expand_mask
def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
"""
Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
"""
batch_size, src_length = mask.shape
tgt_length = tgt_length if tgt_length is not None else src_length
expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
def build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max=8, device=None): def build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max=8, device=None):
r""" r"""
Link to paper: https://arxiv.org/abs/2108.12409 - Alibi tensor is not causal as the original paper mentions, it Link to paper: https://arxiv.org/abs/2108.12409 - Alibi tensor is not causal as the original paper mentions, it
...@@ -412,34 +381,6 @@ class MptModel(MptPreTrainedModel): ...@@ -412,34 +381,6 @@ class MptModel(MptPreTrainedModel):
def build_mpt_alibi_tensor(self, num_heads, sequence_length, alibi_bias_max=8, device=None): def build_mpt_alibi_tensor(self, num_heads, sequence_length, alibi_bias_max=8, device=None):
return build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max, device) return build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max, device)
def _prepare_attn_mask(
self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
) -> torch.BoolTensor:
# create causal mask
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
if input_shape[1] + past_key_values_length != attention_mask.shape[1]:
raise ValueError(
"Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length"
f" {past_key_values_length}."
)
combined_attention_mask = None
device = attention_mask.device
_, src_length = input_shape
if src_length > 1:
combined_attention_mask = _make_causal_mask(
input_shape, device=device, past_key_values_length=past_key_values_length
)
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
)
return combined_attention_mask
    def set_input_embeddings(self, new_embeddings: torch.Tensor):
        self.wte = new_embeddings
...@@ -508,13 +449,12 @@ class MptModel(MptPreTrainedModel):
        alibi = self.build_mpt_alibi_tensor(self.num_heads, self.config.max_seq_len, device=hidden_states.device)

-       causal_mask = self._prepare_attn_mask(
-           attention_mask,
-           input_shape=(batch_size, seq_length),
-           past_key_values_length=past_key_values_length,
-       )
+       causal_mask = _prepare_4d_causal_attention_mask(
+           attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+       )
+       causal_mask = causal_mask.bool()

-       for i, (block, layer_past) in enumerate(zip(self.blocks, past_key_values)):
+       for block, layer_past in zip(self.blocks, past_key_values):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
...
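Worth noting on the MPT hunk above: the shared helper returns an additive float mask (0 where attention is allowed, the dtype minimum where it is not), while MPT's attention expects Bloom-style booleans; the `causal_mask.bool()` line bridges the two, since every nonzero (blocked) entry maps to `True`. A one-line sanity check of that assumption:

import torch

float_mask = torch.tensor([[0.0, torch.finfo(torch.float32).min], [0.0, 0.0]])
bool_mask = float_mask.bool()  # [[False, True], [False, False]] -> True = masked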
...@@ -28,6 +28,7 @@ from ...activations import ACT2FN
from ...generation.configuration_utils import GenerationConfig
from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
from ...generation.stopping_criteria import StoppingCriteriaList
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
...@@ -80,39 +81,6 @@ class MusicgenUnconditionalInput(ModelOutput):
    guidance_scale: float = None
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
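A toy illustration of the additive convention used by these Bart-style helpers and kept by their shared replacements: `0` keeps a position visible, the `dtype` minimum blocks it, and `past_key_values_length` prepends always-visible columns for cached keys:

import torch

dtype, tgt_len, past = torch.float32, 3, 2
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
cond = torch.arange(tgt_len)
mask.masked_fill_(cond < (cond + 1).view(tgt_len, 1), 0)  # zeros on and below the diagonal

# cached positions from earlier steps stay fully visible
mask = torch.cat([torch.zeros(tgt_len, past), mask], dim=-1)  # [tgt_len, tgt_len + past]
causal_4d = mask[None, None, :, :]                            # [1, 1, tgt_len, tgt_len + past]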
# Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
...@@ -706,30 +674,6 @@ class MusicgenDecoder(MusicgenPreTrainedModel):
    def set_input_embeddings(self, value):
        self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
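The removed method above combines the causal mask and the expanded padding mask by plain addition; `_prepare_4d_causal_attention_mask` folds both steps into a single call. A hedged emulation of the combination step (toy sizes; two already-blocked entries may sum to `-inf`, which the softmax treats the same way):

import torch

dtype = torch.float32
min_val = torch.finfo(dtype).min

causal = torch.triu(torch.full((4, 4), min_val), diagonal=1)[None, None]  # [1, 1, 4, 4]
attention_mask = torch.tensor([[1, 1, 1, 0]])                             # last token is padding
padding = (1.0 - attention_mask[:, None, None, :].to(dtype)) * min_val    # [1, 1, 1, 4]

combined = causal + padding  # blocked if in the future or padded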
    @add_start_docstrings_to_model_forward(MUSICGEN_DECODER_INPUTS_DOCSTRING)
    def forward(
        self,
...@@ -773,14 +717,16 @@ class MusicgenDecoder(MusicgenPreTrainedModel):
        if inputs_embeds is None:
            inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])

-       attention_mask = self._prepare_decoder_attention_mask(
+       attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-           encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+           encoder_attention_mask = _prepare_4d_attention_mask(
+               encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+           )

        # embed positions
        positions = self.embed_positions(input, past_key_values_length)
...
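On the encoder-attention part of the hunk above: the decoder expands the encoder's 2D padding mask to 4D with `tgt_len` set to the decoder length, so every decoder position shares the same view of which encoder positions are padding. A small sketch of that expansion under the additive convention:

import torch

encoder_attention_mask = torch.tensor([[1, 1, 0]])  # [bsz, src_len], 0 = encoder padding
tgt_len = 2                                         # decoder length

expanded = encoder_attention_mask[:, None, None, :].expand(1, 1, tgt_len, 3).to(torch.float32)
inverted = 1.0 - expanded
cross_mask = inverted.masked_fill(inverted.bool(), torch.finfo(torch.float32).min)
# [bsz, 1, tgt_len, src_len]: the padded source column is blocked for every decoder step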
...@@ -23,6 +23,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
...@@ -89,39 +90,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
    return shifted_input_ids
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MVP
class MvpLearnedPositionalEmbedding(nn.Embedding):
    """
...@@ -918,7 +886,7 @@ class MvpEncoder(MvpPreTrainedModel):
        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-           attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+           attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
...@@ -1033,27 +1001,6 @@ class MvpDecoder(MvpPreTrainedModel):
    def set_input_embeddings(self, value):
        self.embed_tokens = value
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
    def forward(
        self,
        input_ids: torch.LongTensor = None,
...@@ -1160,14 +1107,16 @@ class MvpDecoder(MvpPreTrainedModel):
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

-       attention_mask = self._prepare_decoder_attention_mask(
+       attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-           encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+           encoder_attention_mask = _prepare_4d_attention_mask(
+               encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+           )

        # embed positions
        positions = self.embed_positions(input, past_key_values_length)
...
...@@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
    MoEModelOutput,
    MoEModelOutputWithPastAndCrossAttentions,
...@@ -75,39 +76,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
    return shifted_input_ids
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    """
...@@ -1125,7 +1093,7 @@ class NllbMoeEncoder(NllbMoePreTrainedModel):
        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-           attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+           attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_router_probs = () if output_router_logits else None
...@@ -1342,25 +1310,16 @@ class NllbMoeDecoder(NllbMoePreTrainedModel):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-       combined_attention_mask = None
-       if input_shape[-1] > 1:
-           combined_attention_mask = _make_causal_mask(
-               input_shape,
-               inputs_embeds.dtype,
-               device=inputs_embeds.device,
-               past_key_values_length=past_key_values_length,
-           )
-
-       if attention_mask is not None and combined_attention_mask is not None:
-           # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-           combined_attention_mask = combined_attention_mask + _expand_mask(
-               attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
-           )
+       combined_attention_mask = _prepare_4d_causal_attention_mask(
+           attention_mask, input_shape, inputs_embeds, past_key_values_length
+       )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-           encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+           encoder_attention_mask = _prepare_4d_attention_mask(
+               encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+           )

        # embed positions
        positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length)
...
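One subtle behavior difference in the NLLB-MoE hunk above: the removed code added the expanded padding mask only when a causal mask already existed (`input_shape[-1] > 1`), so single-token decode steps ran with no self-attention mask at all; the shared helper appears to expand the padding mask even for that step. A toy sketch of where the two paths diverge:

import torch

# cached decoding: one new query token, four keys total, one of them padding
attention_mask = torch.tensor([[1, 1, 0, 1]])  # [bsz, past_key_values_length + 1]

# old path: input_shape[-1] == 1 -> combined_attention_mask stays None
# shared-helper path: the 2D padding mask is still expanded for this step
expanded = attention_mask[:, None, None, :].to(torch.float32)          # [1, 1, 1, 4]
step_mask = (1.0 - expanded) * torch.finfo(torch.float32).min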
...@@ -21,6 +21,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
...@@ -63,38 +64,6 @@ OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class OPTLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
...@@ -525,30 +494,6 @@ class OPTDecoder(OPTPreTrainedModel):
    def set_input_embeddings(self, value):
        self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
    def forward(
        self,
        input_ids: torch.LongTensor = None,
...@@ -643,7 +588,7 @@ class OPTDecoder(OPTPreTrainedModel):
                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
                f"{mask_seq_length} (sum of the lengths of current and past inputs)"
            )
-       causal_attention_mask = self._prepare_decoder_attention_mask(
+       causal_attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )
        pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
...
...@@ -14,7 +14,6 @@
# limitations under the License.
""" PyTorch OWLv2 model."""
import warnings
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
...@@ -25,6 +24,7 @@ import torch.utils.checkpoint
from torch import Tensor, nn
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel
from ...utils import (
...@@ -53,21 +53,6 @@ OWLV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.clip.modeling_clip.contrastive_loss with clip->owlv2
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
...@@ -790,24 +775,6 @@ class Owlv2Encoder(nn.Module):
    )
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.owlvit.modeling_owlvit.OwlViTTextTransformer with OWLVIT->OWLV2,OwlViT->Owlv2
class Owlv2TextTransformer(nn.Module):
    def __init__(self, config: Owlv2TextConfig):
...@@ -845,11 +812,13 @@ class Owlv2TextTransformer(nn.Module):
        # num_samples, seq_len = input_shape where num_samples = batch_size * num_max_text_queries
        # OWLV2's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
-       causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
+       causal_attention_mask = _create_4d_causal_attention_mask(
+           input_shape, hidden_states.dtype, device=hidden_states.device
+       )
        # expand attention_mask
        if attention_mask is not None:
            # [num_samples, seq_len] -> [num_samples, 1, tgt_seq_len, src_seq_len]
-           attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
+           attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
...
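The OWLv2/OWL-ViT text towers differ from the decoders above: the causal mask is built unconditionally from the input shape, with no 2D attention mask involved, so the refactor swaps in `_create_4d_causal_attention_mask` for that part and keeps `_prepare_4d_attention_mask` for padding. A sketch of the split, assuming the additive convention:

import torch

bsz, seq_len = 2, 4
dtype = torch.float32

# causal part only: needs nothing but the shape
causal = torch.triu(torch.full((seq_len, seq_len), torch.finfo(dtype).min), diagonal=1)
causal_4d = causal[None, None].expand(bsz, 1, seq_len, seq_len)

# padding handled separately by the 2D -> 4D expansion
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
pad_4d = (1.0 - attention_mask[:, None, None, :].to(dtype)) * torch.finfo(dtype).min

# the attention layers receive both and add them onto the logits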
...@@ -14,7 +14,6 @@
# limitations under the License.
""" PyTorch OWL-ViT model."""
import warnings
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
...@@ -25,6 +24,7 @@ import torch.utils.checkpoint
from torch import Tensor, nn
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel
from ...utils import (
...@@ -54,21 +54,6 @@ OWLVIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.clip.modeling_clip.contrastive_loss with clip->owlvit
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
...@@ -779,24 +764,6 @@ class OwlViTEncoder(nn.Module):
    )
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
class OwlViTTextTransformer(nn.Module):
    def __init__(self, config: OwlViTTextConfig):
        super().__init__()
...@@ -833,11 +800,13 @@ class OwlViTTextTransformer(nn.Module):
        # num_samples, seq_len = input_shape where num_samples = batch_size * num_max_text_queries
        # OWLVIT's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
-       causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
+       causal_attention_mask = _create_4d_causal_attention_mask(
+           input_shape, hidden_states.dtype, device=hidden_states.device
+       )
        # expand attention_mask
        if attention_mask is not None:
            # [num_samples, seq_len] -> [num_samples, 1, tgt_seq_len, src_seq_len]
-           attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
+           attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
...
...@@ -25,6 +25,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
...@@ -72,39 +73,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
    return shifted_input_ids
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Pegasus
class PegasusSinusoidalPositionalEmbedding(nn.Embedding):
    """This module produces sinusoidal positional embeddings of any length."""
...@@ -773,7 +741,7 @@ class PegasusEncoder(PegasusPreTrainedModel):
        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-           attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+           attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
...@@ -871,30 +839,6 @@ class PegasusDecoder(PegasusPreTrainedModel):
    def set_input_embeddings(self, value):
        self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """
        Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
...@@ -1028,14 +972,16 @@ class PegasusDecoder(PegasusPreTrainedModel):
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

-       attention_mask = self._prepare_decoder_attention_mask(
+       attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-           encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+           encoder_attention_mask = _prepare_4d_attention_mask(
+               encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+           )

        # embed positions
        positions = self.embed_positions(input_shape, past_key_values_length)
...
...@@ -25,6 +25,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
...@@ -90,39 +91,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
    return shifted_input_ids
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class PegasusXSinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length."""
...@@ -1138,55 +1106,6 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
    def set_input_embeddings(self, value):
        self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def resize_position_embeddings(self, new_num_position_embeddings: int):
"""
Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
config.max_position_embeddings`.
Arguments:
new_num_position_embeddings (`int`):
The number of new position embeddings. If position embeddings are learned, increasing the size will add
newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
add correct vectors at the end following the position encoding algorithm, whereas reducing the size
will remove vectors from the end.
"""
logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
self.config.max_position_embeddings = new_num_position_embeddings
self.embed_positions = PegasusXSinusoidalPositionalEmbedding(self.config.d_model)
self.embed_positions.to(self.device)
def get_position_embeddings(self) -> nn.Embedding:
"""
Returns the position embeddings matrix
"""
return self.embed_positions
    def forward(
        self,
        input_ids=None,
...@@ -1278,14 +1197,16 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

-       attention_mask = self._prepare_decoder_attention_mask(
+       attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-           encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+           encoder_attention_mask = _prepare_4d_attention_mask(
+               encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+           )

        # embed positions
        positions = self.embed_positions(inputs_embeds, past_key_values_length)
...
...@@ -27,6 +27,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
...@@ -38,39 +39,6 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "PersimmonConfig"
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Persimmon
class PersimmonRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
...@@ -563,30 +531,6 @@ class PersimmonModel(PersimmonPreTrainedModel):
    def set_input_embeddings(self, value):
        self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
    @add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING)
    def forward(
        self,
...@@ -639,7 +583,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
-       attention_mask = self._prepare_decoder_attention_mask(
+       attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )
...
...@@ -23,6 +23,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
...@@ -77,39 +78,6 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
    return prev_output_tokens
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->PLBart
class PLBartLearnedPositionalEmbedding(nn.Embedding):
    """
...@@ -776,7 +744,7 @@ class PLBartEncoder(PLBartPreTrainedModel):
        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-           attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+           attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
...@@ -873,29 +841,6 @@ class PLBartDecoder(PLBartPreTrainedModel):
    def set_input_embeddings(self, value):
        self.embed_tokens = value
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
    def forward(
        self,
        input_ids: torch.LongTensor = None,
...@@ -1002,14 +947,16 @@ class PLBartDecoder(PLBartPreTrainedModel):
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input) * self.embed_scale

-       attention_mask = self._prepare_decoder_attention_mask(
+       attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-           encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+           encoder_attention_mask = _prepare_4d_attention_mask(
+               encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+           )

        # embed positions
        positions = self.embed_positions(input, past_key_values_length)
...
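Taken together, every file in this commit converges on the same shared helpers. A hedged end-to-end sketch of how a Bart-like decoder builds its two masks after this refactor (toy tensors; these are private utilities whose import path and signatures are as they appear in the hunks above and may change):

import torch
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_attention_mask,
    _prepare_4d_causal_attention_mask,
)

bsz, tgt_len, src_len, d_model = 2, 5, 7, 16
inputs_embeds = torch.randn(bsz, tgt_len, d_model)

# decoder self-attention: causal structure + decoder padding in one call
decoder_padding = torch.ones(bsz, tgt_len, dtype=torch.long)
self_attn_mask = _prepare_4d_causal_attention_mask(
    decoder_padding, (bsz, tgt_len), inputs_embeds, 0
)  # [bsz, 1, tgt_len, tgt_len], additive float mask

# cross-attention: expand the encoder padding mask to the decoder's length
encoder_padding = torch.ones(bsz, src_len, dtype=torch.long)
cross_attn_mask = _prepare_4d_attention_mask(
    encoder_padding, inputs_embeds.dtype, tgt_len=tgt_len
)  # [bsz, 1, tgt_len, src_len]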