Unverified Commit 33f98cfd authored by Patrick von Platen, committed by GitHub

Remove ambiguous `padding_mask` and instead use a 2D->4D Attn Mask Mapper (#26792)



* [Attn Mask Converter] refactor attn mask

* up

* Apply suggestions from code review
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>

* improve

* rename

* better cache

* renaming

* improve more

* improve

* fix bug

* finalize

* make style & make fix-copies

* correct more

* start moving attention_mask

* fix llama

* improve falcon

* up

* improve more

* improve more

* Update src/transformers/models/owlv2/modeling_owlv2.py

* make style

* make style

* rename to converter

* Apply suggestions from code review

---------
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
parent f09a081d
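The diff below removes the per-model `_prepare_decoder_attention_mask` helpers and the extra `padding_mask` argument, and routes mask handling through a shared `AttnMaskConverter` that maps a 2D `(batch_size, key_value_length)` mask to the 4D additive mask the attention layers consume. A minimal sketch of the resulting flow, assuming the `AttnMaskConverter` class added further down in this diff is in scope (the shapes and values are toy stand-ins, not part of the change):

```python
import torch

# Toy stand-ins for what a model forward pass would see
batch_size, seq_length, past_key_values_length = 2, 3, 1
key_value_length = seq_length + past_key_values_length
inputs_embeds = torch.zeros(batch_size, seq_length, 8)
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])  # (batch_size, key_value_length), 0 = padding

converter = AttnMaskConverter(is_causal=True)  # created once in the model's __init__
if attention_mask is not None:
    # 0/1 padding mask -> additive float mask of shape (batch_size, 1, seq_length, key_value_length)
    mask_4d = converter.to_4d(attention_mask, seq_length, key_value_length, dtype=inputs_embeds.dtype)
else:
    # purely causal mask when the caller passes no mask at all
    mask_4d = converter.to_causal_4d(
        batch_size, seq_length, key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
    )
print(mask_4d.shape)  # torch.Size([2, 1, 3, 4])
```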
...@@ -560,7 +560,7 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel): ...@@ -560,7 +560,7 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.embed_tokens = value self.embed_tokens = value
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask # create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
"""PyTorch Falcon model.""" """PyTorch Falcon model."""
import math import math
import warnings
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
...@@ -76,9 +77,9 @@ def rotate_half(x): ...@@ -76,9 +77,9 @@ def rotate_half(x):
# Copied from transformers.models.llama.modeling_llama._get_unpad_data # Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(padding_mask): def _get_unpad_data(attention_mask):
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item() max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
return ( return (
...@@ -88,6 +89,143 @@ def _get_unpad_data(padding_mask): ...@@ -88,6 +89,143 @@ def _get_unpad_data(padding_mask):
) )
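For orientation, here is what `_get_unpad_data` computes from a 2D mask; a small hedged example (assumes the function above is in scope, values are arbitrary):

```python
import torch

attention_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])  # batch of 2; the first sequence has one padding token
indices, cu_seqlens, max_seqlen = _get_unpad_data(attention_mask)
print(indices)     # tensor([0, 1, 3, 4, 5]) -> positions of real tokens in the flattened batch
print(cu_seqlens)  # tensor([0, 2, 5], dtype=torch.int32) -> cumulative lengths for the varlen flash-attn kernels
print(max_seqlen)  # 3
```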
# Copied from transformers.models.llama.modeling_llama.AttnMaskConverter
class AttnMaskConverter:
"""
A utility attention mask class that allows:
- Create a causal 4d mask
- Create a causal 4d mask with sliding window
- Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
key_value_length) that can be multiplied with attention scores
Parameters:
is_causal (`bool`):
Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
sliding_window (`int`, *optional*):
Optionally, the sliding window masks can be created if `sliding_window` is set to a positive integer.
"""
def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
self.is_causal = is_causal
self.sliding_window = sliding_window
def to_causal_4d(
self,
batch_size: int,
query_length: int,
key_value_length: int,
dtype: torch.dtype = torch.float32,
device: Union[torch.device, "str"] = "cpu",
) -> torch.Tensor:
"""
Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
bias to upper right hand triangular matrix (causal mask).
"""
if not self.is_causal:
raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
# If shape is not cached, create a new causal mask and cache it
input_shape = (batch_size, query_length)
past_key_values_length = key_value_length - query_length
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
causal_4d_mask = None
if input_shape[-1] > 1 or self.sliding_window is not None:
past_key_values_length = key_value_length - query_length
causal_4d_mask = self._make_causal_mask(
input_shape,
dtype,
device=device,
past_key_values_length=past_key_values_length,
sliding_window=self.sliding_window,
)
return causal_4d_mask
def to_4d(
self,
attention_mask_2d: torch.Tensor,
query_length: int,
key_value_length: int,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
causal, a causal mask will be added.
"""
input_shape = (attention_mask_2d.shape[0], query_length)
past_key_values_length = key_value_length - query_length
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
causal_4d_mask = None
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
past_key_values_length = key_value_length - query_length
causal_4d_mask = self._make_causal_mask(
input_shape,
dtype,
device=attention_mask_2d.device,
past_key_values_length=past_key_values_length,
sliding_window=self.sliding_window,
)
elif self.sliding_window is not None:
raise NotImplementedError("Sliding window is currently only implemented for causal masking")
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
attention_mask_2d.device
)
expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask
return expanded_4d_mask
def _make_causal_mask(
self,
input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
sliding_window: Optional[int] = None,
):
"""
Make causal mask used for uni-directional (causal) self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
# add lower triangular sliding window mask if necessary
if sliding_window is not None:
diagonal = past_key_values_length - sliding_window + 1
context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
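The converter is pure shape bookkeeping, so its behaviour is easy to inspect in isolation. A short hedged example (assumes the class above is in scope; the values are arbitrary):

```python
import torch

converter = AttnMaskConverter(is_causal=True)

# 2D padding mask for one sequence of 4 key/value positions, the last one padded
attn_2d = torch.tensor([[1, 1, 1, 0]])
mask_4d = converter.to_4d(attn_2d, query_length=4, key_value_length=4, dtype=torch.float32)
print(mask_4d.shape)     # torch.Size([1, 1, 4, 4])
print(mask_4d[0, 0, 0])  # only position 0 is 0.0 (attended); the rest carry a large negative bias

# Purely causal mask when no padding mask was passed
print(converter.to_causal_4d(batch_size=1, query_length=4, key_value_length=4).shape)  # torch.Size([1, 1, 4, 4])

# For single-token decoding (query_length == 1, no sliding window) to_causal_4d returns None:
# a single query may attend to the full cache, so no mask is needed.
print(converter.to_causal_4d(batch_size=1, query_length=1, key_value_length=8))  # None
```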
# TODO (joao): Is this the same implementation as in Llama? If so, let's make them the same and add the copy facilities # TODO (joao): Is this the same implementation as in Llama? If so, let's make them the same and add the copy facilities
class FalconRotaryEmbedding(nn.Module): class FalconRotaryEmbedding(nn.Module):
"""Implementation of RotaryEmbedding from GPT-NeoX. """Implementation of RotaryEmbedding from GPT-NeoX.
...@@ -311,6 +449,7 @@ class FalconAttention(nn.Module): ...@@ -311,6 +449,7 @@ class FalconAttention(nn.Module):
self.head_dim = self.hidden_size // self.num_heads self.head_dim = self.hidden_size // self.num_heads
self.split_size = self.hidden_size self.split_size = self.hidden_size
self.hidden_dropout = config.hidden_dropout self.hidden_dropout = config.hidden_dropout
self.is_causal = True
if self.head_dim * self.num_heads != self.hidden_size: if self.head_dim * self.num_heads != self.hidden_size:
raise ValueError( raise ValueError(
...@@ -431,8 +570,13 @@ class FalconAttention(nn.Module): ...@@ -431,8 +570,13 @@ class FalconAttention(nn.Module):
head_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False, use_cache: bool = False,
output_attentions: bool = False, output_attentions: bool = False,
padding_mask: Optional[torch.LongTensor] = None, **kwargs,
): ):
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
# 3 x [batch_size, seq_length, num_heads, head_dim] # 3 x [batch_size, seq_length, num_heads, head_dim]
...@@ -465,9 +609,6 @@ class FalconAttention(nn.Module): ...@@ -465,9 +609,6 @@ class FalconAttention(nn.Module):
else: else:
present = None present = None
float_min = torch.finfo(query_layer.dtype).min
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float_min).to(query_layer.dtype)
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim) query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim) key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim) value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
...@@ -482,16 +623,14 @@ class FalconAttention(nn.Module): ...@@ -482,16 +623,14 @@ class FalconAttention(nn.Module):
) )
attn_output = F.scaled_dot_product_attention( attn_output = F.scaled_dot_product_attention(
query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False query_layer_, key_layer_, value_layer_, attention_mask, 0.0, is_causal=False
) )
attention_scores = None attention_scores = None
else: else:
attention_scores = query_layer_ @ key_layer_.transpose(-1, -2) attention_scores = query_layer_ @ key_layer_.transpose(-1, -2)
attention_scores /= math.sqrt(self.head_dim) attention_scores /= math.sqrt(self.head_dim)
attention_scores = F.softmax(
    attention_scores + attention_mask_float, dim=-1, dtype=hidden_states.dtype
)
attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype)
attn_output = attention_scores @ value_layer_ attn_output = attention_scores @ value_layer_
attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim) attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
...@@ -517,12 +656,12 @@ class FalconAttention(nn.Module): ...@@ -517,12 +656,12 @@ class FalconAttention(nn.Module):
if input_dtype == torch.float16 or input_dtype == torch.bfloat16: if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
attention_scores = attention_scores.to(torch.float32) attention_scores = attention_scores.to(torch.float32)
# Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
# adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically # adding (alibi * self.inv_norm_factor) to attention_mask. I think this would be mathematically
# equivalent and more performant, but there might be a numerical difference. If you're reading this # equivalent and more performant, but there might be a numerical difference. If you're reading this
# and you'd like to experiment and maybe file a PR, feel free! # and you'd like to experiment and maybe file a PR, feel free!
attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
attention_logits *= self.inv_norm_factor attention_logits *= self.inv_norm_factor
attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype) attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype)
# [batch_size, num_heads, q_length, kv_length] # [batch_size, num_heads, q_length, kv_length]
attention_probs = self.attention_dropout(attention_probs) attention_probs = self.attention_dropout(attention_probs)
...@@ -563,8 +702,16 @@ class FalconFlashAttention2(FalconAttention): ...@@ -563,8 +702,16 @@ class FalconFlashAttention2(FalconAttention):
head_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False, use_cache: bool = False,
output_attentions: bool = False, output_attentions: bool = False,
padding_mask: Optional[torch.LongTensor] = None, **kwargs,
): ):
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
# overwrite attention_mask with padding_mask
attention_mask = kwargs.pop("padding_mask")
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
# 3 x [batch_size, seq_length, num_heads, head_dim] # 3 x [batch_size, seq_length, num_heads, head_dim]
...@@ -630,7 +777,7 @@ class FalconFlashAttention2(FalconAttention): ...@@ -630,7 +777,7 @@ class FalconFlashAttention2(FalconAttention):
value_layer = value_layer.to(target_dtype) value_layer = value_layer.to(target_dtype)
attn_output = self._flash_attention_forward( attn_output = self._flash_attention_forward(
query_layer, key_layer, value_layer, padding_mask, query_length, dropout=attn_dropout query_layer, key_layer, value_layer, attention_mask, query_length, dropout=attn_dropout
) )
attn_weights = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) attn_weights = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
...@@ -643,7 +790,7 @@ class FalconFlashAttention2(FalconAttention): ...@@ -643,7 +790,7 @@ class FalconFlashAttention2(FalconAttention):
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward( def _flash_attention_forward(
self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
): ):
""" """
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
...@@ -656,7 +803,7 @@ class FalconFlashAttention2(FalconAttention): ...@@ -656,7 +803,7 @@ class FalconFlashAttention2(FalconAttention):
Input key states to be passed to Flash Attention API Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`): value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API Input value states to be passed to Flash Attention API
padding_mask (`torch.Tensor`): attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens. position of padding tokens and 1 for the position of non-padding tokens.
dropout (`int`, *optional*): dropout (`int`, *optional*):
...@@ -665,10 +812,10 @@ class FalconFlashAttention2(FalconAttention): ...@@ -665,10 +812,10 @@ class FalconFlashAttention2(FalconAttention):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
""" """
# Contains at least one padding token in the sequence # Contains at least one padding token in the sequence
if padding_mask is not None: if attention_mask is not None:
batch_size = query_states.shape[0] batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, padding_mask, query_length query_states, key_states, value_states, attention_mask, query_length
) )
cu_seqlens_q, cu_seqlens_k = cu_seq_lens cu_seqlens_q, cu_seqlens_k = cu_seq_lens
...@@ -684,7 +831,7 @@ class FalconFlashAttention2(FalconAttention): ...@@ -684,7 +831,7 @@ class FalconFlashAttention2(FalconAttention):
max_seqlen_k=max_seqlen_in_batch_k, max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout, dropout_p=dropout,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=True, causal=self.is_causal,
) )
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
...@@ -696,8 +843,8 @@ class FalconFlashAttention2(FalconAttention): ...@@ -696,8 +843,8 @@ class FalconFlashAttention2(FalconAttention):
return attn_output return attn_output
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis( key_layer = index_first_axis(
...@@ -722,8 +869,8 @@ class FalconFlashAttention2(FalconAttention): ...@@ -722,8 +869,8 @@ class FalconFlashAttention2(FalconAttention):
query_layer = query_layer.squeeze(1) query_layer = query_layer.squeeze(1)
else: else:
# The -q_len: slice assumes left padding. # The -q_len: slice assumes left padding.
padding_mask = padding_mask[:, -query_length:] attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return ( return (
query_layer, query_layer,
...@@ -752,7 +899,7 @@ class FalconMLP(nn.Module): ...@@ -752,7 +899,7 @@ class FalconMLP(nn.Module):
class FalconDecoderLayer(nn.Module): class FalconDecoderLayer(nn.Module):
def __init__(self, config: FalconConfig): def __init__(self, config: FalconConfig, attn_mask_converter=None):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
...@@ -786,8 +933,13 @@ class FalconDecoderLayer(nn.Module): ...@@ -786,8 +933,13 @@ class FalconDecoderLayer(nn.Module):
head_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False, use_cache: bool = False,
output_attentions: bool = False, output_attentions: bool = False,
padding_mask: Optional[torch.LongTensor] = None, **kwargs,
): ):
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
residual = hidden_states residual = hidden_states
if self.config.new_decoder_architecture: if self.config.new_decoder_architecture:
...@@ -806,7 +958,7 @@ class FalconDecoderLayer(nn.Module): ...@@ -806,7 +958,7 @@ class FalconDecoderLayer(nn.Module):
head_mask=head_mask, head_mask=head_mask,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
padding_mask=padding_mask, **kwargs,
) )
attention_output = attn_outputs[0] attention_output = attn_outputs[0]
...@@ -1001,6 +1153,10 @@ class FalconModel(FalconPreTrainedModel): ...@@ -1001,6 +1153,10 @@ class FalconModel(FalconPreTrainedModel):
# Embedding + LN Embedding # Embedding + LN Embedding
self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim) self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
# create attention mask cache that trickles down to each attention layer
# so that the attention_mask cache can be shared among layers
self.attn_mask_converter = AttnMaskConverter(is_causal=True)
# Transformer blocks # Transformer blocks
self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)])
...@@ -1015,37 +1171,6 @@ class FalconModel(FalconPreTrainedModel): ...@@ -1015,37 +1171,6 @@ class FalconModel(FalconPreTrainedModel):
def get_input_embeddings(self): def get_input_embeddings(self):
return self.word_embeddings return self.word_embeddings
@staticmethod
def _prepare_attn_mask(
attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
) -> torch.BoolTensor:
# Create a causal mask
# The attention mask we receive as input should cover the whole extended sequence, including any past
# cache, so its shape should be [batch_size, seq_length + past_key_values_length]
# The output shape will be [batch_size, 1, seq_length, seq_length + past_key_values_length]
if input_shape[1] + past_key_values_length != attention_mask.shape[1]:
raise ValueError(
"Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length"
f" {past_key_values_length}."
)
combined_attention_mask = None
device = attention_mask.device
_, seq_length = input_shape
if seq_length > 1:
combined_attention_mask = _make_causal_mask(
input_shape, device=device, past_key_values_length=past_key_values_length
)
# [batch_size, seq_length + past_key_values_length] -> [batch_size, 1, seq_length, seq_length + past_key_values_length]
expanded_attn_mask = _expand_mask(attention_mask, past_key_values_length=past_key_values_length)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
)
return combined_attention_mask
def set_input_embeddings(self, new_embeddings: torch.Tensor): def set_input_embeddings(self, new_embeddings: torch.Tensor):
self.word_embeddings = new_embeddings self.word_embeddings = new_embeddings
...@@ -1114,19 +1239,16 @@ class FalconModel(FalconPreTrainedModel): ...@@ -1114,19 +1239,16 @@ class FalconModel(FalconPreTrainedModel):
past_key_values_length = 0 past_key_values_length = 0
if past_key_values[0] is not None: if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
padding_mask = None
else:
attention_mask = attention_mask.to(hidden_states.device)
if 0 in attention_mask:
padding_mask = attention_mask
else:
padding_mask = None
if self.use_alibi: if self.use_alibi:
alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
mask = (
torch.ones(
(batch_size, seq_length + past_key_values_length), device=inputs_embeds.device, dtype=torch.long
)
if attention_mask is None
else attention_mask
)
alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype)
else: else:
alibi = None alibi = None
if position_ids is None: if position_ids is None:
...@@ -1136,11 +1258,20 @@ class FalconModel(FalconPreTrainedModel): ...@@ -1136,11 +1258,20 @@ class FalconModel(FalconPreTrainedModel):
) )
position_ids = position_ids.unsqueeze(0) position_ids = position_ids.unsqueeze(0)
        causal_mask = self._prepare_attn_mask(
            attention_mask,
            input_shape=(batch_size, seq_length),
            past_key_values_length=past_key_values_length,
        )

        if getattr(self.config, "_flash_attn_2_enabled", False):
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            key_value_length = seq_length + past_key_values_length
            # 4d mask is passed through the layers
            if attention_mask is not None:
                attention_mask = self.attn_mask_converter.to_4d(
                    attention_mask, seq_length, key_value_length, dtype=inputs_embeds.dtype
                )
            else:
                attention_mask = self.attn_mask_converter.to_causal_4d(
                    batch_size, seq_length, key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
                )
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states: if output_hidden_states:
...@@ -1159,22 +1290,20 @@ class FalconModel(FalconPreTrainedModel): ...@@ -1159,22 +1290,20 @@ class FalconModel(FalconPreTrainedModel):
create_custom_forward(block), create_custom_forward(block),
hidden_states, hidden_states,
alibi, alibi,
causal_mask, attention_mask,
position_ids, position_ids,
head_mask[i], head_mask[i],
padding_mask,
) )
else: else:
outputs = block( outputs = block(
hidden_states, hidden_states,
layer_past=layer_past, layer_past=layer_past,
attention_mask=causal_mask, attention_mask=attention_mask,
position_ids=position_ids, position_ids=position_ids,
head_mask=head_mask[i], head_mask=head_mask[i],
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
alibi=alibi, alibi=alibi,
padding_mask=padding_mask,
) )
hidden_states = outputs[0] hidden_states = outputs[0]
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
# limitations under the License. # limitations under the License.
""" PyTorch LLaMA model.""" """ PyTorch LLaMA model."""
import math import math
import warnings
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
...@@ -51,9 +52,9 @@ logger = logging.get_logger(__name__) ...@@ -51,9 +52,9 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlamaConfig" _CONFIG_FOR_DOC = "LlamaConfig"
def _get_unpad_data(padding_mask): def _get_unpad_data(attention_mask):
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item() max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
return ( return (
...@@ -63,37 +64,140 @@ def _get_unpad_data(padding_mask): ...@@ -63,37 +64,140 @@ def _get_unpad_data(padding_mask):
) )
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


class AttnMaskConverter:
    """
    A utility attention mask class that allows:
        - Create a causal 4d mask
        - Create a causal 4d mask with sliding window
        - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
          key_value_length) that can be multiplied with attention scores

    Parameters:
        is_causal (`bool`):
            Whether the attention mask should be a uni-directional (causal) or bi-directional mask.

        sliding_window (`int`, *optional*):
            Optionally, the sliding window masks can be created if `sliding_window` is set to a positive integer.
    """

    def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
        self.is_causal = is_causal
        self.sliding_window = sliding_window

    def to_causal_4d(
        self,
        batch_size: int,
        query_length: int,
        key_value_length: int,
        dtype: torch.dtype = torch.float32,
        device: Union[torch.device, "str"] = "cpu",
    ) -> torch.Tensor:
        """
        Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
        bias to upper right hand triangular matrix (causal mask).
        """
        if not self.is_causal:
            raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")

        # If shape is not cached, create a new causal mask and cache it
        input_shape = (batch_size, query_length)
        past_key_values_length = key_value_length - query_length

        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        causal_4d_mask = None
        if input_shape[-1] > 1 or self.sliding_window is not None:
            past_key_values_length = key_value_length - query_length
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                dtype,
                device=device,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
            )

        return causal_4d_mask

    def to_4d(
        self,
        attention_mask_2d: torch.Tensor,
        query_length: int,
        key_value_length: int,
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """
        Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
        key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
        causal, a causal mask will be added.
        """
        input_shape = (attention_mask_2d.shape[0], query_length)
        past_key_values_length = key_value_length - query_length

        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        causal_4d_mask = None
        if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
            past_key_values_length = key_value_length - query_length
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                dtype,
                device=attention_mask_2d.device,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
            )
        elif self.sliding_window is not None:
            raise NotImplementedError("Sliding window is currently only implemented for causal masking")

        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
            attention_mask_2d.device
        )
        expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask

        return expanded_4d_mask

    def _make_causal_mask(
        self,
        input_ids_shape: torch.Size,
        dtype: torch.dtype,
        device: torch.device,
        past_key_values_length: int = 0,
        sliding_window: Optional[int] = None,
    ):
        """
        Make causal mask used for uni-directional (causal) self-attention.
        """
        bsz, tgt_len = input_ids_shape
        mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
        mask_cond = torch.arange(mask.size(-1), device=device)
        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

        mask = mask.to(dtype)

        if past_key_values_length > 0:
            mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)

        # add lower triangular sliding window mask if necessary
        if sliding_window is not None:
            diagonal = past_key_values_length - sliding_window + 1

            context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
            mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)

        return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

    def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
        """
        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
        """
        bsz, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len

        expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
        inverted_mask = 1.0 - expanded_mask

        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
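Since the converted mask is additive, the attention implementations below simply add it to the raw scores before the softmax. A toy demonstration of that contract (hedged; assumes the `AttnMaskConverter` defined above is in scope):

```python
import torch

converter = AttnMaskConverter(is_causal=True)
mask = converter.to_4d(torch.tensor([[1, 1, 1]]), query_length=3, key_value_length=3)

scores = torch.zeros(1, 1, 3, 3)  # pretend QK^T / sqrt(head_dim) is all zeros
probs = torch.softmax(scores + mask, dim=-1)
print(probs[0, 0])
# tensor([[1.0000, 0.0000, 0.0000],
#         [0.5000, 0.5000, 0.0000],
#         [0.3333, 0.3333, 0.3333]])
```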
class LlamaRMSNorm(nn.Module): class LlamaRMSNorm(nn.Module):
...@@ -272,6 +376,7 @@ class LlamaAttention(nn.Module): ...@@ -272,6 +376,7 @@ class LlamaAttention(nn.Module):
self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta self.rope_theta = config.rope_theta
self.is_causal = True
if (self.head_dim * self.num_heads) != self.hidden_size: if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError( raise ValueError(
...@@ -322,8 +427,13 @@ class LlamaAttention(nn.Module): ...@@ -322,8 +427,13 @@ class LlamaAttention(nn.Module):
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1: if self.config.pretraining_tp > 1:
...@@ -420,14 +530,22 @@ class LlamaFlashAttention2(LlamaAttention): ...@@ -420,14 +530,22 @@ class LlamaFlashAttention2(LlamaAttention):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention2 attention does not support output_attentions # LlamaFlashAttention2 attention does not support output_attentions
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
# overwrite attention_mask with padding_mask
attention_mask = kwargs.pop("padding_mask")
output_attentions = False output_attentions = False
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
...@@ -492,7 +610,7 @@ class LlamaFlashAttention2(LlamaAttention): ...@@ -492,7 +610,7 @@ class LlamaFlashAttention2(LlamaAttention):
value_states = value_states.to(target_dtype) value_states = value_states.to(target_dtype)
attn_output = self._flash_attention_forward( attn_output = self._flash_attention_forward(
query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
) )
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
...@@ -504,7 +622,7 @@ class LlamaFlashAttention2(LlamaAttention): ...@@ -504,7 +622,7 @@ class LlamaFlashAttention2(LlamaAttention):
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, past_key_value
def _flash_attention_forward( def _flash_attention_forward(
self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
): ):
""" """
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
...@@ -517,7 +635,7 @@ class LlamaFlashAttention2(LlamaAttention): ...@@ -517,7 +635,7 @@ class LlamaFlashAttention2(LlamaAttention):
Input key states to be passed to Flash Attention API Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`): value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API Input value states to be passed to Flash Attention API
padding_mask (`torch.Tensor`): attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens. position of padding tokens and 1 for the position of non-padding tokens.
dropout (`int`, *optional*): dropout (`int`, *optional*):
...@@ -526,10 +644,10 @@ class LlamaFlashAttention2(LlamaAttention): ...@@ -526,10 +644,10 @@ class LlamaFlashAttention2(LlamaAttention):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
""" """
# Contains at least one padding token in the sequence # Contains at least one padding token in the sequence
if padding_mask is not None: if attention_mask is not None:
batch_size = query_states.shape[0] batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, padding_mask, query_length query_states, key_states, value_states, attention_mask, query_length
) )
cu_seqlens_q, cu_seqlens_k = cu_seq_lens cu_seqlens_q, cu_seqlens_k = cu_seq_lens
...@@ -545,7 +663,7 @@ class LlamaFlashAttention2(LlamaAttention): ...@@ -545,7 +663,7 @@ class LlamaFlashAttention2(LlamaAttention):
max_seqlen_k=max_seqlen_in_batch_k, max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout, dropout_p=dropout,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=True, causal=self.is_causal,
) )
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
...@@ -556,8 +674,8 @@ class LlamaFlashAttention2(LlamaAttention): ...@@ -556,8 +674,8 @@ class LlamaFlashAttention2(LlamaAttention):
return attn_output return attn_output
def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
key_layer = index_first_axis( key_layer = index_first_axis(
...@@ -582,8 +700,8 @@ class LlamaFlashAttention2(LlamaAttention): ...@@ -582,8 +700,8 @@ class LlamaFlashAttention2(LlamaAttention):
query_layer = query_layer.squeeze(1) query_layer = query_layer.squeeze(1)
else: else:
# The -q_len: slice assumes left padding. # The -q_len: slice assumes left padding.
padding_mask = padding_mask[:, -query_length:] attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return ( return (
query_layer, query_layer,
...@@ -616,13 +734,13 @@ class LlamaDecoderLayer(nn.Module): ...@@ -616,13 +734,13 @@ class LlamaDecoderLayer(nn.Module):
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None, **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, sequence_length)` where padding elements are indicated by 0.
output_attentions (`bool`, *optional*): output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail. returned tensors for more detail.
...@@ -631,6 +749,10 @@ class LlamaDecoderLayer(nn.Module): ...@@ -631,6 +749,10 @@ class LlamaDecoderLayer(nn.Module):
(see `past_key_values`). (see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
""" """
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
residual = hidden_states residual = hidden_states
...@@ -644,7 +766,7 @@ class LlamaDecoderLayer(nn.Module): ...@@ -644,7 +766,7 @@ class LlamaDecoderLayer(nn.Module):
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
padding_mask=padding_mask, **kwargs,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
...@@ -791,6 +913,10 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -791,6 +913,10 @@ class LlamaModel(LlamaPreTrainedModel):
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
# create attention mask cache that trickles down to each attention layer
# so that the attention_mask cache can be shared among layers
self.attn_mask_converter = AttnMaskConverter(is_causal=True)
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -805,30 +931,6 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -805,30 +931,6 @@ class LlamaModel(LlamaPreTrainedModel):
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.embed_tokens = value self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward( def forward(
self, self,
...@@ -854,18 +956,15 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -854,18 +956,15 @@ class LlamaModel(LlamaPreTrainedModel):
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
batch_size, seq_length = input_ids.shape batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None: elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape batch_size, seq_length = inputs_embeds.shape[:2]
else: else:
raise ValueError("You have to specify either input_ids or inputs_embeds") raise ValueError("You have to specify either input_ids or inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0 past_key_values_length = 0
if past_key_values is not None: if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2] past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None: if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device device = input_ids.device if input_ids is not None else inputs_embeds.device
...@@ -876,22 +975,23 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -876,22 +975,23 @@ class LlamaModel(LlamaPreTrainedModel):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
            padding_mask = None
        else:
            if 0 in attention_mask:
                padding_mask = attention_mask
            else:
                padding_mask = None

        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        if getattr(self.config, "_flash_attn_2_enabled", False):
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            key_value_length = seq_length + past_key_values_length
            # 4d mask is passed through the layers
            if attention_mask is not None:
                attention_mask = self.attn_mask_converter.to_4d(
                    attention_mask, seq_length, key_value_length, dtype=inputs_embeds.dtype
                )
            else:
                attention_mask = self.attn_mask_converter.to_causal_4d(
                    batch_size, seq_length, key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
                )

        # embed positions
hidden_states = inputs_embeds hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
...@@ -917,7 +1017,7 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -917,7 +1017,7 @@ class LlamaModel(LlamaPreTrainedModel):
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
# None for past_key_value # None for past_key_value
return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask) return module(*inputs, past_key_value, output_attentions)
return custom_forward return custom_forward
...@@ -932,7 +1032,6 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -932,7 +1032,6 @@ class LlamaModel(LlamaPreTrainedModel):
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
padding_mask=padding_mask,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
""" PyTorch Mistral model.""" """ PyTorch Mistral model."""
import inspect import inspect
import math import math
import warnings
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
...@@ -53,10 +54,147 @@ logger = logging.get_logger(__name__) ...@@ -53,10 +54,147 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MistralConfig" _CONFIG_FOR_DOC = "MistralConfig"
# Copied from transformers.models.llama.modeling_llama.AttnMaskConverter
class AttnMaskConverter:
"""
A utility attention mask class that allows:
- Create a causal 4d mask
- Create a causal 4d mask with sliding window
- Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
key_value_length) that can be multiplied with attention scores
Parameters:
is_causal (`bool`):
Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
sliding_window (`int`, *optional*):
Optionally, the sliding window masks can be created if `sliding_window` is set to a positive integer.
"""
def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
self.is_causal = is_causal
self.sliding_window = sliding_window
def to_causal_4d(
self,
batch_size: int,
query_length: int,
key_value_length: int,
dtype: torch.dtype = torch.float32,
device: Union[torch.device, "str"] = "cpu",
) -> torch.Tensor:
"""
Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
bias to upper right hand triangular matrix (causal mask).
"""
if not self.is_causal:
raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
# If shape is not cached, create a new causal mask and cache it
input_shape = (batch_size, query_length)
past_key_values_length = key_value_length - query_length
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
causal_4d_mask = None
if input_shape[-1] > 1 or self.sliding_window is not None:
past_key_values_length = key_value_length - query_length
causal_4d_mask = self._make_causal_mask(
input_shape,
dtype,
device=device,
past_key_values_length=past_key_values_length,
sliding_window=self.sliding_window,
)
return causal_4d_mask
def to_4d(
self,
attention_mask_2d: torch.Tensor,
query_length: int,
key_value_length: int,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
causal, a causal mask will be added.
"""
input_shape = (attention_mask_2d.shape[0], query_length)
past_key_values_length = key_value_length - query_length
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
causal_4d_mask = None
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
past_key_values_length = key_value_length - query_length
causal_4d_mask = self._make_causal_mask(
input_shape,
dtype,
device=attention_mask_2d.device,
past_key_values_length=past_key_values_length,
sliding_window=self.sliding_window,
)
elif self.sliding_window is not None:
raise NotImplementedError("Sliding window is currently only implemented for causal masking")
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
attention_mask_2d.device
)
expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask
return expanded_4d_mask
def _make_causal_mask(
self,
input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
sliding_window: Optional[int] = None,
):
"""
Make causal mask used for uni-directional (causal) self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
# add lower triangular sliding window mask if necessary
if sliding_window is not None:
diagonal = past_key_values_length - sliding_window + 1
context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
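The Mistral copy is what makes the `sliding_window` argument relevant: with a positive window, `to_causal_4d` produces a banded causal mask. A hedged illustration (assumes the class above is in scope; window size chosen arbitrarily):

```python
import torch

converter = AttnMaskConverter(is_causal=True, sliding_window=3)
mask = converter.to_causal_4d(batch_size=1, query_length=5, key_value_length=5)
print((mask[0, 0] == 0).int())
# Each query attends to itself and at most the two previous positions
# (a band of width `sliding_window`, counting the current token):
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 1, 1, 1, 0],
#         [0, 0, 1, 1, 1]], dtype=torch.int32)
```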
# Copied from transformers.models.llama.modeling_llama._get_unpad_data # Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(padding_mask): def _get_unpad_data(attention_mask):
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item() max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
return ( return (
...@@ -66,33 +204,6 @@ def _get_unpad_data(padding_mask): ...@@ -66,33 +204,6 @@ def _get_unpad_data(padding_mask):
) )
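# Worked example (illustrative) of what `_get_unpad_data` returns for
#   attention_mask = [[1, 1, 0],
#                     [1, 1, 1]]:
#   indices             -> [0, 1, 3, 4, 5]  (flattened non-padding positions)
#   cu_seqlens          -> [0, 2, 5]        (cumulative sequence lengths, int32)
#   max_seqlen_in_batch -> 3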
def _make_sliding_window_causal_mask(
input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
sliding_window: int = 4096,
):
"""
Make causal mask used for sliding window attention
"""
bsz, tgt_len = input_ids_shape
tensor = torch.full(
(tgt_len, tgt_len),
fill_value=1,
device=device,
)
mask = torch.tril(tensor, diagonal=0)
# make the mask banded to account for sliding window
mask = torch.triu(mask, diagonal=-sliding_window)
mask = torch.log(mask).to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
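# For reference, the banded mask built above for tgt_len=4 and sliding_window=2
# (log(1) == 0.0, log(0) == -inf) looks like:
#
#     [[0,    -inf, -inf, -inf],
#      [0,    0,    -inf, -inf],
#      [0,    0,    0,    -inf],
#      [-inf, 0,    0,    0   ]]
#
# i.e. each query attends to at most `sliding_window` past tokens plus itself.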
# Copied from transformers.models.bart.modeling_bart._expand_mask # Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
""" """
...@@ -223,6 +334,7 @@ class MistralAttention(nn.Module): ...@@ -223,6 +334,7 @@ class MistralAttention(nn.Module):
self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta self.rope_theta = config.rope_theta
self.is_causal = True
if (self.head_dim * self.num_heads) != self.hidden_size: if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError( raise ValueError(
...@@ -251,8 +363,12 @@ class MistralAttention(nn.Module): ...@@ -251,8 +363,12 @@ class MistralAttention(nn.Module):
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
padding_mask: Optional[torch.Tensor] = None, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states) query_states = self.q_proj(hidden_states)
...@@ -332,8 +448,15 @@ class MistralFlashAttention2(MistralAttention): ...@@ -332,8 +448,15 @@ class MistralFlashAttention2(MistralAttention):
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, **kwargs,
): ):
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
# overwrite attention_mask with padding_mask
attention_mask = kwargs.pop("padding_mask")
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states) query_states = self.q_proj(hidden_states)
...@@ -385,9 +508,9 @@ class MistralFlashAttention2(MistralAttention): ...@@ -385,9 +508,9 @@ class MistralFlashAttention2(MistralAttention):
past_key_value = (past_key, past_value) past_key_value = (past_key, past_value)
if padding_mask is not None: if attention_mask is not None:
padding_mask = padding_mask[:, slicing_tokens:] attention_mask = attention_mask[:, slicing_tokens:]
padding_mask = torch.cat([padding_mask, torch.ones_like(padding_mask[:, -1:])], dim=-1) attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
key_states = torch.cat([past_key_value[0], key_states], dim=2) key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2)
...@@ -433,7 +556,7 @@ class MistralFlashAttention2(MistralAttention): ...@@ -433,7 +556,7 @@ class MistralFlashAttention2(MistralAttention):
query_states, query_states,
key_states, key_states,
value_states, value_states,
padding_mask, attention_mask,
q_len, q_len,
dropout=dropout_rate, dropout=dropout_rate,
use_sliding_windows=use_sliding_windows, use_sliding_windows=use_sliding_windows,
...@@ -452,7 +575,7 @@ class MistralFlashAttention2(MistralAttention): ...@@ -452,7 +575,7 @@ class MistralFlashAttention2(MistralAttention):
query_states, query_states,
key_states, key_states,
value_states, value_states,
padding_mask, attention_mask,
query_length, query_length,
dropout=0.0, dropout=0.0,
softmax_scale=None, softmax_scale=None,
...@@ -469,7 +592,7 @@ class MistralFlashAttention2(MistralAttention): ...@@ -469,7 +592,7 @@ class MistralFlashAttention2(MistralAttention):
Input key states to be passed to Flash Attention API Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`): value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API Input value states to be passed to Flash Attention API
padding_mask (`torch.Tensor`): attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens. position of padding tokens and 1 for the position of non-padding tokens.
dropout (`int`, *optional*): dropout (`int`, *optional*):
...@@ -480,10 +603,10 @@ class MistralFlashAttention2(MistralAttention): ...@@ -480,10 +603,10 @@ class MistralFlashAttention2(MistralAttention):
Whether to activate sliding window attention. Whether to activate sliding window attention.
""" """
# Contains at least one padding token in the sequence # Contains at least one padding token in the sequence
if padding_mask is not None: if attention_mask is not None:
batch_size = query_states.shape[0] batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, padding_mask, query_length query_states, key_states, value_states, attention_mask, query_length
) )
cu_seqlens_q, cu_seqlens_k = cu_seq_lens cu_seqlens_q, cu_seqlens_k = cu_seq_lens
...@@ -500,7 +623,7 @@ class MistralFlashAttention2(MistralAttention): ...@@ -500,7 +623,7 @@ class MistralFlashAttention2(MistralAttention):
max_seqlen_k=max_seqlen_in_batch_k, max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout, dropout_p=dropout,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=True, causal=self.is_causal,
) )
else: else:
attn_output_unpad = flash_attn_varlen_func( attn_output_unpad = flash_attn_varlen_func(
...@@ -513,7 +636,7 @@ class MistralFlashAttention2(MistralAttention): ...@@ -513,7 +636,7 @@ class MistralFlashAttention2(MistralAttention):
max_seqlen_k=max_seqlen_in_batch_k, max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout, dropout_p=dropout,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=True, causal=self.is_causal,
window_size=(self.config.sliding_window, self.config.sliding_window), window_size=(self.config.sliding_window, self.config.sliding_window),
) )
...@@ -536,16 +659,16 @@ class MistralFlashAttention2(MistralAttention): ...@@ -536,16 +659,16 @@ class MistralFlashAttention2(MistralAttention):
return attn_output return attn_output
def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
# On the first iteration we need to properly re-create the padding mask # On the first iteration we need to properly re-create the padding mask
# by slicing it on the proper place # by slicing it on the proper place
if kv_seq_len != padding_mask.shape[-1]: if kv_seq_len != attention_mask.shape[-1]:
padding_mask_num_tokens = padding_mask.shape[-1] attention_mask_num_tokens = attention_mask.shape[-1]
padding_mask = padding_mask[:, padding_mask_num_tokens - kv_seq_len :] attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
...@@ -566,8 +689,8 @@ class MistralFlashAttention2(MistralAttention): ...@@ -566,8 +689,8 @@ class MistralFlashAttention2(MistralAttention):
query_layer = query_layer.squeeze(1) query_layer = query_layer.squeeze(1)
else: else:
# The -q_len: slice assumes left padding. # The -q_len: slice assumes left padding.
padding_mask = padding_mask[:, -query_length:] attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return ( return (
query_layer, query_layer,
...@@ -600,13 +723,17 @@ class MistralDecoderLayer(nn.Module): ...@@ -600,13 +723,17 @@ class MistralDecoderLayer(nn.Module):
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
padding_mask: Optional[torch.Tensor] = None, **kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
""" """
Args: Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. `(batch, sequence_length)` where padding elements are indicated by 0.
output_attentions (`bool`, *optional*): output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail. returned tensors for more detail.
...@@ -628,7 +755,6 @@ class MistralDecoderLayer(nn.Module): ...@@ -628,7 +755,6 @@ class MistralDecoderLayer(nn.Module):
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
padding_mask=padding_mask,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
...@@ -775,6 +901,10 @@ class MistralModel(MistralPreTrainedModel): ...@@ -775,6 +901,10 @@ class MistralModel(MistralPreTrainedModel):
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
# create attention mask cache that trickles down to each attention layer
# so that the attention_mask cache can be shared among layers
self.attn_mask_converter = AttnMaskConverter(is_causal=True, sliding_window=config.sliding_window)
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([MistralDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.layers = nn.ModuleList([MistralDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -789,32 +919,6 @@ class MistralModel(MistralPreTrainedModel): ...@@ -789,32 +919,6 @@ class MistralModel(MistralPreTrainedModel):
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.embed_tokens = value self.embed_tokens = value
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length, sliding_window
):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_sliding_window_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
sliding_window=sliding_window,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
def forward( def forward(
self, self,
...@@ -865,23 +969,13 @@ class MistralModel(MistralPreTrainedModel): ...@@ -865,23 +969,13 @@ class MistralModel(MistralPreTrainedModel):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
padding_mask = None
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
elif 0 in attention_mask:
padding_mask = attention_mask
if ( if (
padding_mask is not None attention_mask is not None
and hasattr(self.config, "_flash_attn_2_enabled") and hasattr(self.config, "_flash_attn_2_enabled")
and self.config._flash_attn_2_enabled and self.config._flash_attn_2_enabled
and past_key_values is not None and past_key_values is not None
): ):
is_padding_right = padding_mask[:, -1].sum().item() != batch_size is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right: if is_padding_right:
raise ValueError( raise ValueError(
"You are attempting to perform batched generation with padding_side='right'" "You are attempting to perform batched generation with padding_side='right'"
...@@ -889,13 +983,20 @@ class MistralModel(MistralPreTrainedModel): ...@@ -889,13 +983,20 @@ class MistralModel(MistralPreTrainedModel):
" call `tokenizer.padding_side = 'left'` before tokenizing the input. " " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
) )
attention_mask = self._prepare_decoder_attention_mask( if getattr(self.config, "_flash_attn_2_enabled", False):
attention_mask, # 2d mask is passed through the layers
(batch_size, seq_length), attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
inputs_embeds, else:
past_key_values_length, key_value_length = seq_length + past_key_values_length
sliding_window=self.config.sliding_window, # 4d mask is passed through the layers
) if attention_mask is not None:
attention_mask = self.attn_mask_converter.to_4d(
attention_mask, seq_length, key_value_length, dtype=inputs_embeds.dtype
)
else:
attention_mask = self.attn_mask_converter.to_causal_4d(
batch_size, seq_length, key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -922,7 +1023,7 @@ class MistralModel(MistralPreTrainedModel): ...@@ -922,7 +1023,7 @@ class MistralModel(MistralPreTrainedModel):
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
# None for past_key_value # None for past_key_value
return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask) return module(*inputs, past_key_value, output_attentions)
return custom_forward return custom_forward
...@@ -940,7 +1041,6 @@ class MistralModel(MistralPreTrainedModel): ...@@ -940,7 +1041,6 @@ class MistralModel(MistralPreTrainedModel):
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
padding_mask=padding_mask,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
......
...@@ -548,7 +548,6 @@ class PersimmonModel(PersimmonPreTrainedModel): ...@@ -548,7 +548,6 @@ class PersimmonModel(PersimmonPreTrainedModel):
config: PersimmonConfig config: PersimmonConfig
""" """
# Copied from transformers.models.llama.modeling_llama.LlamaModel.__init__ with LLAMA->PERSIMMON,Llama->Persimmon,PersimmonRMSNorm->nn.LayerNorm,norm->final_layernorm,rms_final_layernorm_eps->layer_norm_eps
def __init__(self, config: PersimmonConfig): def __init__(self, config: PersimmonConfig):
super().__init__(config) super().__init__(config)
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
......
...@@ -39,6 +39,149 @@ if is_torch_available(): ...@@ -39,6 +39,149 @@ if is_torch_available():
LlamaModel, LlamaModel,
LlamaTokenizer, LlamaTokenizer,
) )
from transformers.models.llama.modeling_llama import AttnMaskConverter
@require_torch
class AttentionMaskTester(unittest.TestCase):
def check_non_causal(self, bsz, q_len, kv_len, mask_2d, mask_4d):
mask_indices = (mask_2d != 1)[:, None].broadcast_to((bsz, q_len, kv_len))
mask_4d_values = mask_4d[:, 0][mask_indices]
is_inf = mask_4d_values == -float("inf")
is_min = mask_4d_values == torch.finfo(mask_4d.dtype).min
assert torch.logical_or(is_inf, is_min).all()
def check_to_4d(self, mask_converter, q_len, kv_len, additional_mask=None, bsz=3):
mask_2d = torch.ones((bsz, kv_len), device=torch_device, dtype=torch.long)
if additional_mask is not None:
for bsz_idx, seq_idx in additional_mask:
mask_2d[bsz_idx, seq_idx] = 0
mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len)
assert mask_4d.shape == (bsz, 1, q_len, kv_len)
context = mask_converter.sliding_window
if mask_converter.is_causal and context is None:
# q_len * (q_len - 1) // 2 tokens are masked in a triangular (causal) mask
num_tokens_masked = bsz * (q_len * (q_len - 1) // 2)
if 0 not in mask_2d:
assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
if 0 in mask_2d:
# at least causal mask + maybe more
assert (mask_4d != 0).sum().cpu().item() >= num_tokens_masked
self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
elif not mask_converter.is_causal and context is None:
if 0 not in mask_2d:
assert (mask_4d != 0).sum().cpu().item() == 0
if 0 in mask_2d:
self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
elif mask_converter.is_causal and context is not None:
# q_len * (q_len - 1) // 2 tokens are masked in a triangular (causal) mask
num_tokens_masked = (q_len * (q_len - 1) // 2) + self.compute_num_context_mask(kv_len, context, q_len)
num_tokens_masked = bsz * num_tokens_masked
if 0 not in mask_2d:
assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
if 0 in mask_2d:
# at least causal mask + maybe more
assert (mask_4d != 0).sum().cpu().item() >= num_tokens_masked
self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d)
def check_to_causal(self, mask_converter, q_len, kv_len, bsz=3):
mask_4d = mask_converter.to_causal_4d(bsz, query_length=q_len, key_value_length=kv_len, device=torch_device)
if q_len == 1 and mask_converter.sliding_window is None:
# no causal mask if q_len is 1
assert mask_4d is None
return
context = mask_converter.sliding_window
if mask_converter.is_causal and context is None:
# q_len * (q_len - 1) // 2 tokens are masked in a triangular (causal) mask
num_tokens_masked = bsz * (q_len * (q_len - 1) // 2)
assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
elif not mask_converter.is_causal and context is None:
assert (mask_4d != 0).sum().cpu().item() == 0
elif mask_converter.is_causal and context is not None:
# q_len * (q_len - 1) // 2 tokens are masked in a triangular (causal) mask
num_tokens_masked = (q_len * (q_len - 1) // 2) + self.compute_num_context_mask(kv_len, context, q_len)
num_tokens_masked = bsz * num_tokens_masked
assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked
def compute_num_context_mask(self, kv_len, context, q_len):
# This helper computes the number of extra positions that get masked when a
# sliding window of size `context` is applied on top of the causal mask
c_mask_len = kv_len - context
num_mask_triangle = c_mask_len * (c_mask_len + 1) // 2
cut_mask_len = max(c_mask_len - q_len, 0)
num_cut_mask = cut_mask_len * (cut_mask_len + 1) // 2
return num_mask_triangle - num_cut_mask
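# Worked example (illustrative) for the sliding-window configuration used in the
# tests below: kv_len=7, context=3, q_len=7
#   c_mask_len        = 7 - 3 = 4
#   num_mask_triangle = 4 * 5 // 2 = 10
#   cut_mask_len      = max(4 - 7, 0) = 0  ->  num_cut_mask = 0
# so 10 positions are masked by the sliding window, on top of the
# 7 * 6 // 2 = 21 positions masked by causality.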
def test_2d_to_4d_causal(self):
mask_converter = AttnMaskConverter(is_causal=True)
# auto-regressive use case
self.check_to_4d(mask_converter, q_len=1, kv_len=7)
# special auto-regressive case
self.check_to_4d(mask_converter, q_len=3, kv_len=7)
# non auto-regressive case
self.check_to_4d(mask_converter, q_len=7, kv_len=7)
# same with extra attention masks
self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
def test_2d_to_4d(self):
torch.ones((3, 7), device=torch_device, dtype=torch.long)
mask_converter = AttnMaskConverter(is_causal=False)
# non auto-regressive case
self.check_to_4d(mask_converter, q_len=7, kv_len=7)
# same with extra attention masks
self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
def test_2d_to_4d_causal_sliding(self):
torch.ones((3, 7), device=torch_device, dtype=torch.long)
mask_converter = AttnMaskConverter(is_causal=True, sliding_window=5)
# auto-regressive use case
self.check_to_4d(mask_converter, q_len=1, kv_len=7)
# special auto-regressive case
self.check_to_4d(mask_converter, q_len=3, kv_len=7)
# non auto-regressive case
self.check_to_4d(mask_converter, q_len=7, kv_len=7)
# same with extra attention masks
self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)])
def test_causal_mask(self):
mask_converter = AttnMaskConverter(is_causal=True)
# auto-regressive use case
self.check_to_causal(mask_converter, q_len=1, kv_len=7)
# special auto-regressive case
self.check_to_causal(mask_converter, q_len=3, kv_len=7)
# non auto-regressive case
self.check_to_causal(mask_converter, q_len=7, kv_len=7)
def test_causal_mask_sliding(self):
mask_converter = AttnMaskConverter(is_causal=True, sliding_window=3)
# auto-regressive use case
self.check_to_causal(mask_converter, q_len=1, kv_len=7)
# special auto-regressive case
self.check_to_causal(mask_converter, q_len=3, kv_len=7)
# non auto-regressive case
self.check_to_causal(mask_converter, q_len=7, kv_len=7)
class LlamaModelTester: class LlamaModelTester:
......