Unverified Commit d7cb5e13 authored by Patrick von Platen, committed by GitHub

[Llama FA2] Re-add _expand_attention_mask and clean a couple things (#27074)

* clean

* clean llama

* fix more

* make style

* Apply suggestions from code review

* Apply suggestions from code review

* Update src/transformers/models/llama/modeling_llama.py

* Update src/transformers/models/llama/modeling_llama.py

* Apply suggestions from code review

* finish

* make style
parent 4864d08d
src/transformers/models/falcon/modeling_falcon.py
@@ -148,7 +148,7 @@ class AttnMaskConverter:
self,
attention_mask_2d: torch.Tensor,
query_length: int,
key_value_length: int,
key_value_length: Optional[int] = None,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
@@ -157,12 +157,16 @@ class AttnMaskConverter:
causal, a causal mask will be added.
"""
input_shape = (attention_mask_2d.shape[0], query_length)
past_key_values_length = key_value_length - query_length
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
causal_4d_mask = None
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
if key_value_length is None:
raise ValueError(
"This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
)
past_key_values_length = key_value_length - query_length
causal_4d_mask = self._make_causal_mask(
input_shape,
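
The hunk above makes `key_value_length` optional in `to_4d` and moves the `past_key_values_length` computation behind a guard that raises when the converter is causal but no key/value length was passed. Below is a minimal, self-contained sketch of that guard logic (simplified names and conditions, not the exact transformers implementation):

```python
from typing import Optional

import torch


def to_4d_sketch(
    attention_mask_2d: torch.Tensor,
    query_length: int,
    is_causal: bool,
    key_value_length: Optional[int] = None,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Simplified stand-in for AttnMaskConverter.to_4d illustrating the new guard."""
    bsz = attention_mask_2d.shape[0]
    min_value = torch.finfo(dtype).min

    causal_4d_mask = None
    if query_length > 1 and is_causal:
        # `key_value_length` is only needed when a causal mask is actually built,
        # hence the check now lives inside this branch.
        if key_value_length is None:
            raise ValueError(
                "This attention mask converter is causal. Make sure to pass "
                "`key_value_length` to correctly create a causal mask."
            )
        past_key_values_length = key_value_length - query_length
        # query i (absolute position past_key_values_length + i) may attend to keys j <= that position
        allowed = torch.arange(key_value_length)[None, :] <= (
            torch.arange(query_length)[:, None] + past_key_values_length
        )
        causal_4d_mask = torch.full((query_length, key_value_length), min_value, dtype=dtype)
        causal_4d_mask = causal_4d_mask.masked_fill(allowed, 0.0)[None, None, :, :]

    # expand the [bsz, key_value_length] padding mask to an additive 4D mask
    expanded_4d_mask = (1.0 - attention_mask_2d[:, None, None, :].to(dtype)) * min_value
    if causal_4d_mask is not None:
        expanded_4d_mask = expanded_4d_mask + causal_4d_mask  # broadcasts to [bsz, 1, q_len, kv_len]
    return expanded_4d_mask


padding_mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding
print(to_4d_sketch(padding_mask, query_length=4, is_causal=True, key_value_length=4).shape)
# torch.Size([1, 1, 4, 4])
```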
@@ -182,8 +186,8 @@ class AttnMaskConverter:
return expanded_4d_mask
@staticmethod
def _make_causal_mask(
self,
input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
@@ -212,7 +216,8 @@ class AttnMaskConverter:
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
@staticmethod
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
@@ -837,7 +842,7 @@ class FalconFlashAttention2(FalconAttention):
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
)
return attn_output
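
The FlashAttention classes now forward `self.is_causal` instead of a hard-coded `causal=True`. A toy illustration of the same pattern, using PyTorch's `scaled_dot_product_attention` so it runs without the flash-attn package (the real code calls `flash_attn_func`):

```python
import torch
import torch.nn.functional as F


class TinyAttention(torch.nn.Module):
    # Causality is a property of the module; forward() no longer hard-codes it.
    is_causal = True

    def forward(self, query, key, value, dropout_p=0.0):
        # query/key/value: [batch, num_heads, seq_len, head_dim]
        return F.scaled_dot_product_attention(
            query, key, value, dropout_p=dropout_p, is_causal=self.is_causal
        )


q = k = v = torch.randn(1, 2, 4, 8)
print(TinyAttention()(q, k, v).shape)  # torch.Size([1, 2, 4, 8])
```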
@@ -1154,8 +1159,6 @@ class FalconModel(FalconPreTrainedModel):
# Embedding + LN Embedding
self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
# create attention mask cache that trickles down to each attention layer
# so that the attention_mask cache can be shared among layers
self.attn_mask_converter = AttnMaskConverter(is_causal=True)
# Transformer blocks
......
src/transformers/models/llama/modeling_llama.py
@@ -64,6 +64,24 @@ def _get_unpad_data(attention_mask):
)
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
warnings.warn(
"Calling `transformers.models.llama.modeling_llama._expand_mask` is deprecated and will be removed in v4.37. Use `transformers.models.llama.modeling_llama.AttnMaskConverter._expand_mask"
)
return AttnMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
warnings.warn(
"Calling `transformers.models.llama.modeling_llama._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.llama.modeling_llama.AttnMaskConverter._make_causal_mask"
)
return AttnMaskConverter._make_causal_mask(
input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
)
class AttnMaskConverter:
"""
A utility attention mask class that allows:
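
The two module-level helpers are re-added purely as backwards-compatibility shims: they warn and delegate to the (now static) methods on `AttnMaskConverter`. A self-contained sketch of that pattern with a toy converter (names here are illustrative, not the transformers API):

```python
import warnings
from typing import Optional

import torch


class MaskConverter:
    """Toy stand-in: the canonical helper lives on the class as a staticmethod."""

    @staticmethod
    def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
        bsz, src_len = mask.shape
        tgt_len = tgt_len if tgt_len is not None else src_len
        expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
        inverted = 1.0 - expanded
        return inverted.masked_fill(inverted.bool(), torch.finfo(dtype).min)


def _expand_mask(mask, dtype, tgt_len=None):
    # Deprecated module-level entry point kept only for old imports, mirroring the hunk above.
    warnings.warn("`_expand_mask` is deprecated, use `MaskConverter._expand_mask` instead.")
    return MaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = _expand_mask(torch.tensor([[1, 1, 0]]), torch.float32)

print(out.shape, len(caught))  # torch.Size([1, 1, 3, 3]) 1
```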
@@ -122,7 +140,7 @@ class AttnMaskConverter:
self,
attention_mask_2d: torch.Tensor,
query_length: int,
key_value_length: int,
key_value_length: Optional[int] = None,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
@@ -131,12 +149,16 @@ class AttnMaskConverter:
causal, a causal mask will be added.
"""
input_shape = (attention_mask_2d.shape[0], query_length)
past_key_values_length = key_value_length - query_length
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
causal_4d_mask = None
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
if key_value_length is None:
raise ValueError(
"This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
)
past_key_values_length = key_value_length - query_length
causal_4d_mask = self._make_causal_mask(
input_shape,
@@ -156,8 +178,8 @@ class AttnMaskConverter:
return expanded_4d_mask
@staticmethod
def _make_causal_mask(
self,
input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
@@ -186,7 +208,8 @@ class AttnMaskConverter:
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
@staticmethod
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
@@ -555,7 +578,7 @@ class LlamaFlashAttention2(LlamaAttention):
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dime x hidden_dim
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -669,7 +692,7 @@ class LlamaFlashAttention2(LlamaAttention):
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
)
return attn_output
@@ -739,8 +762,9 @@ class LlamaDecoderLayer(nn.Module):
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
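
The updated docstring reflects that the flash-attention path consumes the raw 2D padding mask while the default (eager) path expects an additive 4D mask. Hand-built examples of the two shapes (for illustration only; the model constructs these through its mask converter):

```python
import torch

batch_size, query_len, kv_len = 2, 4, 4
min_value = torch.finfo(torch.float32).min

# flash attention path: (batch_size, sequence_length), 1 = real token, 0 = padding
flash_mask = torch.tensor([[1, 1, 1, 1],
                           [1, 1, 1, 0]])

# default attention path: additive float mask of shape
# (batch_size, 1, query_sequence_length, key_sequence_length)
padding_part = (1.0 - flash_mask[:, None, None, :].float()) * min_value
causal_part = torch.triu(torch.full((query_len, kv_len), min_value), diagonal=1)
eager_mask = padding_part + causal_part  # broadcasts to (2, 1, 4, 4)

print(flash_mask.shape)  # torch.Size([2, 4])
print(eager_mask.shape)  # torch.Size([2, 1, 4, 4])
```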
@@ -914,8 +938,6 @@ class LlamaModel(LlamaPreTrainedModel):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
# create attention mask cache that trickles down to each attention layer
# so that the attention_mask cache can be shared among layers
self.attn_mask_converter = AttnMaskConverter(is_causal=True)
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
......
src/transformers/models/mistral/modeling_mistral.py
@@ -113,7 +113,7 @@ class AttnMaskConverter:
self,
attention_mask_2d: torch.Tensor,
query_length: int,
key_value_length: int,
key_value_length: Optional[int] = None,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
@@ -122,12 +122,16 @@ class AttnMaskConverter:
causal, a causal mask will be added.
"""
input_shape = (attention_mask_2d.shape[0], query_length)
past_key_values_length = key_value_length - query_length
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
causal_4d_mask = None
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
if key_value_length is None:
raise ValueError(
"This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
)
past_key_values_length = key_value_length - query_length
causal_4d_mask = self._make_causal_mask(
input_shape,
@@ -147,8 +151,8 @@ class AttnMaskConverter:
return expanded_4d_mask
@staticmethod
def _make_causal_mask(
self,
input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
@@ -177,7 +181,8 @@ class AttnMaskConverter:
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
@staticmethod
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
@@ -645,7 +650,12 @@ class MistralFlashAttention2(MistralAttention):
else:
if not use_sliding_windows:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
query_states,
key_states,
value_states,
dropout,
softmax_scale=softmax_scale,
causal=self.is_causal,
)
else:
attn_output = flash_attn_func(
@@ -654,7 +664,7 @@ class MistralFlashAttention2(MistralAttention):
value_states,
dropout,
softmax_scale=softmax_scale,
causal=True,
causal=self.is_causal,
window_size=(self.config.sliding_window, self.config.sliding_window),
)
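
For Mistral, the sliding-window branch also passes `window_size` alongside the causal flag. What the fused kernel approximates can be spelled out as an explicit boolean mask; a pure-PyTorch sketch follows (exact boundary conventions may differ by one position from flash-attn's `window_size` semantics):

```python
import torch

seq_len, sliding_window = 6, 3

idx = torch.arange(seq_len)
causal = idx[None, :] <= idx[:, None]                       # key j is not in the future
in_window = (idx[:, None] - idx[None, :]) < sliding_window  # key j is among the most recent tokens
allowed = causal & in_window

print(allowed.int())
# each row i has ones only for the last `sliding_window` positions up to and including i
```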
@@ -903,8 +913,6 @@ class MistralModel(MistralPreTrainedModel):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
# create attention mask cache that trickles down to each attention layer
# so that the attention_mask cache can be shared among layers
self.attn_mask_converter = AttnMaskConverter(is_causal=True, sliding_window=config.sliding_window)
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
......