Unverified Commit 33f98cfd authored by Patrick von Platen, committed by GitHub

Remove ambiguous `padding_mask` and instead use a 2D->4D Attn Mask Mapper (#26792)



* [Attn Mask Converter] refactor attn mask

* up

* Apply suggestions from code review
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>

* improve

* rename

* better cache

* renaming

* improve more

* improve

* fix bug

* finalize

* make style & make fix-copies

* correct more

* start moving attention_mask

* fix llama

* improve falcon

* up

* improve more

* improve more

* Update src/transformers/models/owlv2/modeling_owlv2.py

* make style

* make style

* rename to converter

* Apply suggestions from code review

---------
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
parent f09a081d
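For readers skimming the diff: the core idea of the "2D->4D" conversion is that a `(batch_size, seq_len)` padding mask of 1s and 0s is turned into a `(batch_size, 1, query_length, key_value_length)` float mask whose masked positions hold a large negative bias that is simply added to the attention scores. A minimal sketch of that idea (illustrative only, not the committed implementation):

```python
import torch

# Minimal sketch of the 2D -> 4D idea (illustrative, not the committed code).
mask_2d = torch.tensor([[1, 1, 1, 0]])                 # (batch_size=1, seq_len=4), 0 marks padding
dtype = torch.float32
min_value = torch.finfo(dtype).min

# expand to (batch_size, 1, query_len, key_value_len) and turn the 0s into a large negative bias
mask_4d = mask_2d[:, None, None, :].to(dtype)          # (1, 1, 1, 4)
mask_4d = (1.0 - mask_4d) * min_value                  # 0 where attended, large negative where masked

scores = torch.zeros(1, 1, 1, 4)                       # fake attention scores (bsz, heads, q_len, kv_len)
masked_scores = scores + mask_4d                       # the bias is simply added before the softmax
```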
@@ -560,7 +560,7 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
def set_input_embeddings(self, value):
self.embed_tokens = value
- # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
......
@@ -19,6 +19,7 @@
# limitations under the License.
""" PyTorch LLaMA model."""
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
@@ -51,9 +52,9 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlamaConfig"
- def _get_unpad_data(padding_mask):
- seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
- indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
return (
@@ -63,10 +64,106 @@ def _get_unpad_data(padding_mask):
)
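To make the renamed helper concrete, here is a worked example of the values `_get_unpad_data` produces for a small left-padded 2D `attention_mask` (computed with the same operations as above):

```python
import torch
import torch.nn.functional as F

# Worked example of _get_unpad_data as defined above.
attention_mask = torch.tensor([[0, 1, 1], [1, 1, 1]])  # two sequences, the first has 1 padding token

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)              # tensor([2, 3])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()   # tensor([1, 2, 3, 4, 5])
max_seqlen_in_batch = seqlens_in_batch.max().item()                           # 3
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 2, 5])
# indices selects the non-padded token positions in the flattened batch, and
# cu_seqlens gives the cumulative sequence lengths that variable-length
# attention kernels expect.
```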
- # Copied from transformers.models.bart.modeling_bart._make_causal_mask
- def _make_causal_mask(
- input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
- ):
class AttnMaskConverter:
"""
A utility attention mask class that allows:
- Create a causal 4d mask
- Create a causal 4d mask with sliding window
- Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
key_value_length) that can be multiplied with attention scores
Parameters:
is_causal (`bool`):
Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
sliding_window (`int`, *optional*):
Optionally, sliding-window masks can be created if `sliding_window` is set to a positive integer.
"""
def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
self.is_causal = is_causal
self.sliding_window = sliding_window
def to_causal_4d(
self,
batch_size: int,
query_length: int,
key_value_length: int,
dtype: torch.dtype = torch.float32,
device: Union[torch.device, "str"] = "cpu",
) -> torch.Tensor:
"""
Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
bias to upper right hand triangular matrix (causal mask).
"""
if not self.is_causal:
raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
# create a new causal 4d mask for the requested shape
input_shape = (batch_size, query_length)
past_key_values_length = key_value_length - query_length
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
causal_4d_mask = None
if input_shape[-1] > 1 or self.sliding_window is not None:
past_key_values_length = key_value_length - query_length
causal_4d_mask = self._make_causal_mask(
input_shape,
dtype,
device=device,
past_key_values_length=past_key_values_length,
sliding_window=self.sliding_window,
)
return causal_4d_mask
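A short usage sketch for `to_causal_4d` (it assumes `AttnMaskConverter` is importable from `transformers.models.llama.modeling_llama`, as the tests added below do):

```python
import torch
from transformers.models.llama.modeling_llama import AttnMaskConverter

converter = AttnMaskConverter(is_causal=True)

# Query of length 3 attending over 5 keys (2 cached past tokens + the 3 new ones).
mask = converter.to_causal_4d(batch_size=1, query_length=3, key_value_length=5, dtype=torch.float32)
print(mask.shape)   # torch.Size([1, 1, 3, 5])
# Row i can attend to the 2 past keys and to new keys 0..i; every other
# position holds torch.finfo(torch.float32).min.
```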
def to_4d(
self,
attention_mask_2d: torch.Tensor,
query_length: int,
key_value_length: int,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
key_value_length) shape and by adding a large negative bias to not-attended positions. If the converter was
created with `is_causal=True`, a causal mask is merged in as well.
"""
input_shape = (attention_mask_2d.shape[0], query_length)
past_key_values_length = key_value_length - query_length
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
causal_4d_mask = None
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
past_key_values_length = key_value_length - query_length
causal_4d_mask = self._make_causal_mask(
input_shape,
dtype,
device=attention_mask_2d.device,
past_key_values_length=past_key_values_length,
sliding_window=self.sliding_window,
)
elif self.sliding_window is not None:
raise NotImplementedError("Sliding window is currently only implemented for causal masking")
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
attention_mask_2d.device
)
expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask
return expanded_4d_mask
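And the corresponding sketch for `to_4d`, which merges a 2D padding mask with the causal mask (same import assumption as above):

```python
import torch
from transformers.models.llama.modeling_llama import AttnMaskConverter

converter = AttnMaskConverter(is_causal=True)

# 2D mask: batch of 1, 5 key/value positions, the first position is padding (left padding).
attention_mask_2d = torch.tensor([[0, 1, 1, 1, 1]])

mask_4d = converter.to_4d(attention_mask_2d, query_length=3, key_value_length=5, dtype=torch.float32)
print(mask_4d.shape)  # torch.Size([1, 1, 3, 5])
# Column 0 is masked for every query (padding), and the causal pattern is
# applied on top, so query i additionally cannot see keys beyond position 2 + i.
```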
def _make_causal_mask(
self,
input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
sliding_window: Optional[int] = None,
):
""" """
Make causal mask used for bi-directional self-attention. Make causal mask used for bi-directional self-attention.
""" """
@@ -74,15 +171,22 @@ def _make_causal_mask(
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
- return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# add lower triangular sliding window mask if necessary
if sliding_window is not None:
diagonal = past_key_values_length - sliding_window + 1
context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
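The sliding-window branch above additionally masks every key that is more than `sliding_window` positions behind the query. A standalone sketch of the same `triu` trick with hypothetical small sizes:

```python
import torch

# Illustrative reconstruction of the sliding-window masking trick used above.
tgt_len, past_key_values_length, sliding_window = 4, 0, 2
dtype = torch.float32

mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
mask_cond = torch.arange(tgt_len)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 0)   # plain causal part

# keep only the last `sliding_window` keys for each query
diagonal = past_key_values_length - sliding_window + 1
context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)

print((mask == 0).int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [0, 1, 1, 0],
#         [0, 0, 1, 1]])  -> each query attends to at most the 2 most recent keys
```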
- # Copied from transformers.models.bart.modeling_bart._expand_mask
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
@@ -272,6 +376,7 @@ class LlamaAttention(nn.Module):
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
@@ -322,8 +427,13 @@ class LlamaAttention(nn.Module):
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
- padding_mask: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1:
@@ -420,14 +530,22 @@ class LlamaFlashAttention2(LlamaAttention):
def forward(
self,
hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
- padding_mask: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention2 attention does not support output_attentions
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
# overwrite attention_mask with padding_mask
attention_mask = kwargs.pop("padding_mask")
output_attentions = False
bsz, q_len, _ = hidden_states.size()
@@ -492,7 +610,7 @@ class LlamaFlashAttention2(LlamaAttention):
value_states = value_states.to(target_dtype)
attn_output = self._flash_attention_forward(
- query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -504,7 +622,7 @@ class LlamaFlashAttention2(LlamaAttention):
return attn_output, attn_weights, past_key_value
def _flash_attention_forward(
- self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -517,7 +635,7 @@ class LlamaFlashAttention2(LlamaAttention):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
- padding_mask (`torch.Tensor`):
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`int`, *optional*):
@@ -526,10 +644,10 @@ class LlamaFlashAttention2(LlamaAttention):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
# Contains at least one padding token in the sequence
- if padding_mask is not None:
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
- query_states, key_states, value_states, padding_mask, query_length
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
@@ -545,7 +663,7 @@ class LlamaFlashAttention2(LlamaAttention):
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
- causal=True,
causal=self.is_causal,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
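For intuition, the unpad -> attend -> re-pad round trip around the variable-length kernel can be sketched with plain tensor ops (illustrative only; the real path uses the `flash_attn` helpers shown above):

```python
import torch

# Pure-tensor sketch of the unpad -> attend -> re-pad round trip (illustrative only).
attention_mask = torch.tensor([[0, 1, 1], [1, 1, 1]])   # 2D mask, 0 = padding (left padding)
hidden = torch.randn(2, 3, 4)                           # (batch, seq_len, head_dim)

indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
packed = hidden.view(-1, 4)[indices]                    # (total_valid_tokens, head_dim), padding rows dropped

# ... the variable-length attention kernel runs on `packed` here,
#     with cu_seqlens marking where each sequence starts and ends ...

repadded = torch.zeros(2 * 3, 4)
repadded[indices] = packed                              # scatter results back into the padded layout
repadded = repadded.view(2, 3, 4)
```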
@@ -556,8 +674,8 @@ class LlamaFlashAttention2(LlamaAttention):
return attn_output
- def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis(
@@ -582,8 +700,8 @@ class LlamaFlashAttention2(LlamaAttention):
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
- padding_mask = padding_mask[:, -query_length:]
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask)
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
@@ -616,13 +734,13 @@ class LlamaDecoderLayer(nn.Module):
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
- padding_mask: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
`(batch, sequence_length)` where padding elements are indicated by 0.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
@@ -631,6 +749,10 @@ class LlamaDecoderLayer(nn.Module):
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
residual = hidden_states
@@ -644,7 +766,7 @@ class LlamaDecoderLayer(nn.Module):
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
- padding_mask=padding_mask,
**kwargs,
)
hidden_states = residual + hidden_states
@@ -791,6 +913,10 @@ class LlamaModel(LlamaPreTrainedModel):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
# create an attention mask converter that builds the 4D attention mask
# once per forward pass before it is passed down to each attention layer
self.attn_mask_converter = AttnMaskConverter(is_causal=True)
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -805,30 +931,6 @@ class LlamaModel(LlamaPreTrainedModel):
def set_input_embeddings(self, value):
self.embed_tokens = value
- # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
- # create causal mask
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- combined_attention_mask = None
- if input_shape[-1] > 1:
- combined_attention_mask = _make_causal_mask(
- input_shape,
- inputs_embeds.dtype,
- device=inputs_embeds.device,
- past_key_values_length=past_key_values_length,
- )
- if attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
- inputs_embeds.device
- )
- combined_attention_mask = (
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
- )
- return combined_attention_mask
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
@@ -854,18 +956,15 @@ class LlamaModel(LlamaPreTrainedModel):
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
- batch_size, seq_length = input_ids.shape
batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
- batch_size, seq_length, _ = inputs_embeds.shape
batch_size, seq_length = inputs_embeds.shape[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
- seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
- seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -876,22 +975,23 @@ class LlamaModel(LlamaPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- # embed positions
- if attention_mask is None:
- attention_mask = torch.ones(
- (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
- )
- padding_mask = None
if getattr(self.config, "_flash_attn_2_enabled", False):
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
- if 0 in attention_mask:
- padding_mask = attention_mask
key_value_length = seq_length + past_key_values_length
# 4d mask is passed through the layers
if attention_mask is not None:
attention_mask = self.attn_mask_converter.to_4d(
attention_mask, seq_length, key_value_length, dtype=inputs_embeds.dtype
)
else:
- padding_mask = None
- attention_mask = self._prepare_decoder_attention_mask(
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
attention_mask = self.attn_mask_converter.to_causal_4d(
batch_size, seq_length, key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
# embed positions
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
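Condensed, the mask routing introduced above boils down to the following decision (an illustrative sketch, not the committed code):

```python
# Condensed sketch of the mask routing above (illustrative only).
def route_attention_mask(config, converter, attention_mask, batch_size, seq_length,
                         past_key_values_length, dtype, device):
    key_value_length = seq_length + past_key_values_length
    if getattr(config, "_flash_attn_2_enabled", False):
        # Flash Attention 2 consumes the raw 2D mask, and only if it actually contains padding.
        return attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    if attention_mask is not None:
        # eager path with padding: merge padding + causal information into a 4D float mask
        return converter.to_4d(attention_mask, seq_length, key_value_length, dtype=dtype)
    # eager path without padding: a purely causal 4D mask (None when seq_length is 1)
    return converter.to_causal_4d(batch_size, seq_length, key_value_length, dtype=dtype, device=device)
```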
@@ -917,7 +1017,7 @@ class LlamaModel(LlamaPreTrainedModel):
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
- return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask)
return module(*inputs, past_key_value, output_attentions)
return custom_forward
@@ -932,7 +1032,6 @@ class LlamaModel(LlamaPreTrainedModel):
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
- padding_mask=padding_mask,
)
hidden_states = layer_outputs[0]
......
@@ -548,7 +548,6 @@ class PersimmonModel(PersimmonPreTrainedModel):
config: PersimmonConfig
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaModel.__init__ with LLAMA->PERSIMMON,Llama->Persimmon,PersimmonRMSNorm->nn.LayerNorm,norm->final_layernorm,rms_final_layernorm_eps->layer_norm_eps
def __init__(self, config: PersimmonConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
......
@@ -39,6 +39,149 @@ if is_torch_available():
LlamaModel,
LlamaTokenizer,
)
from transformers.models.llama.modeling_llama import AttnMaskConverter
@require_torch
class AttentionMaskTester(unittest.TestCase):
def check_non_causal(self, bsz, q_len, kv_len, mask_2d, mask_4d):
mask_indices = (mask_2d != 1)[:, None].broadcast_to((bsz, q_len, kv_len))
mask_4d_values = mask_4d[:, 0][mask_indices]
is_inf = mask_4d_values == -float("inf")
is_min = mask_4d_values == torch.finfo(mask_4d.dtype).min
assert torch.logical_or(is_inf, is_min).all()
def check_to_4d(self, mask_converter, q_len, kv_len, additional_mask=None, bsz=3):
mask_2d = torch.ones((bsz, kv_len), device=torch_device, dtype=torch.long)
if additional_mask is not None:
for bsz_idx, seq_idx in additional_mask:
mask_2d[bsz_idx, seq_idx] = 0
mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len)
assert mask_4d.shape == (bsz, 1, q_len, kv_len)
context = mask_converter.sliding_window
if mask_converter.is_causal and context is None:
# k * (k+1) / 2 tokens are masked in triangular masks
num_tokens_masked = bsz * (q_len * (q_len - 1) // 2)
if 0 not in mask_2d:
assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
if 0 in mask_2d:
# at least causal mask + maybe more
assert (mask_4d != 0).sum().cpu().item() >= num_tokens_masked
self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
elif not mask_converter.is_causal and context is None:
if 0 not in mask_2d:
assert (mask_4d != 0).sum().cpu().item() == 0
if 0 in mask_2d:
self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
elif mask_converter.is_causal and context is not None:
# k * (k+1) / 2 tokens are masked in triangular masks
num_tokens_masked = (q_len * (q_len - 1) // 2) + self.compute_num_context_mask(kv_len, context, q_len)
num_tokens_masked = bsz * num_tokens_masked
if 0 not in mask_2d:
assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
if 0 in mask_2d:
# at least causal mask + maybe more
assert (mask_4d != 0).sum().cpu().item() >= num_tokens_masked
self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
def check_to_causal(self, mask_converter, q_len, kv_len, bsz=3):
mask_4d = mask_converter.to_causal_4d(bsz, query_length=q_len, key_value_length=kv_len, device=torch_device)
if q_len == 1 and mask_converter.sliding_window is None:
# no causal mask if q_len is 1
assert mask_4d is None
return
context = mask_converter.sliding_window
if mask_converter.is_causal and context is None:
# k * (k+1) / 2 tokens are masked in triangular masks
num_tokens_masked = bsz * (q_len * (q_len - 1) // 2)
assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
elif not mask_converter.is_causal and context is None:
assert (mask_4d != 0).sum().cpu().item() == 0
elif mask_converter.is_causal and context is not None:
# k * (k+1) / 2 tokens are masked in triangular masks
num_tokens_masked = (q_len * (q_len - 1) // 2) + self.compute_num_context_mask(kv_len, context, q_len)
num_tokens_masked = bsz * num_tokens_masked
assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
def compute_num_context_mask(self, kv_len, context, q_len):
# This function computes the number of extra tokens that the sliding
# window masks on top of the causal triangle
c_mask_len = kv_len - context
num_mask_triangle = c_mask_len * (c_mask_len + 1) // 2
cut_mask_len = max(c_mask_len - q_len, 0)
num_cut_mask = cut_mask_len * (cut_mask_len + 1) // 2
return num_mask_triangle - num_cut_mask
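A quick worked example of the counting helper above, with the numbers plugged in:

```python
# Worked example of compute_num_context_mask for kv_len=7, context=5, q_len=3:
kv_len, context, q_len = 7, 5, 3
c_mask_len = c_mask_len = kv_len - context               # 2
num_mask_triangle = c_mask_len * (c_mask_len + 1) // 2   # 3
cut_mask_len = max(c_mask_len - q_len, 0)                # 0
num_cut_mask = cut_mask_len * (cut_mask_len + 1) // 2    # 0
print(num_mask_triangle - num_cut_mask)                  # 3 extra masked positions per batch element
```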
def test_2d_to_4d_causal(self):
mask_converter = AttnMaskConverter(is_causal=True)
# auto-regressive use case
self.check_to_4d(mask_converter, q_len=1, kv_len=7)
# special auto-regressive case
self.check_to_4d(mask_converter, q_len=3, kv_len=7)
# non auto-regressive case
self.check_to_4d(mask_converter, q_len=7, kv_len=7)
# same with extra attention masks
self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
def test_2d_to_4d(self):
torch.ones((3, 7), device=torch_device, dtype=torch.long)
mask_converter = AttnMaskConverter(is_causal=False)
# non auto-regressive case
self.check_to_4d(mask_converter, q_len=7, kv_len=7)
# same with extra attention masks
self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
def test_2d_to_4d_causal_sliding(self):
torch.ones((3, 7), device=torch_device, dtype=torch.long)
mask_converter = AttnMaskConverter(is_causal=True, sliding_window=5)
# auto-regressive use case
self.check_to_4d(mask_converter, q_len=1, kv_len=7)
# special auto-regressive case
self.check_to_4d(mask_converter, q_len=3, kv_len=7)
# non auto-regressive case
self.check_to_4d(mask_converter, q_len=7, kv_len=7)
# same with extra attention masks
self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
def test_causal_mask(self):
mask_converter = AttnMaskConverter(is_causal=True)
# auto-regressive use case
self.check_to_causal(mask_converter, q_len=1, kv_len=7)
# special auto-regressive case
self.check_to_causal(mask_converter, q_len=3, kv_len=7)
# non auto-regressive case
self.check_to_causal(mask_converter, q_len=7, kv_len=7)
def test_causal_mask_sliding(self):
mask_converter = AttnMaskConverter(is_causal=True, sliding_window=3)
# auto-regressive use case
self.check_to_causal(mask_converter, q_len=1, kv_len=7)
# special auto-regressive case
self.check_to_causal(mask_converter, q_len=3, kv_len=7)
# non auto-regressive case
self.check_to_causal(mask_converter, q_len=7, kv_len=7)
class LlamaModelTester:
......