    DeepseekV2 flash attention module. This module inherits from `DeepseekV2Attention` as the weights of the module stay
    untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is a bottom-right alignment,
        # which became the default for flash_attn>=2.1. This attribute is used to handle the difference.
        # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
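        # A minimal sketch of how the attribute mentioned above is typically set, assuming the
        # `is_flash_attn_greater_or_equal_2_10` helper from `transformers.utils` is importable:
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()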
        # TODO: These transposes are quite inefficient but Flash Attention requires the layout
        # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
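        # A minimal sketch of the layout change described above, assuming the projections produce the
        # usual [batch_size, num_heads, sequence_length, head_dim] tensors (not shown in this fragment):
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)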
        # In PEFT, we usually cast the layer norms to float32 for training stability reasons,
        # therefore the input hidden states get silently cast to float32. Hence, we need to
        # cast them back to the correct dtype just to be sure everything works as expected.
        # This might slow down training & inference, so it is recommended not to cast the LayerNorms
        # to fp32. (DeepseekV2RMSNorm handles it correctly)
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            # Handle the case where the model is quantized
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            elif torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            else:
                target_dtype = self.q_a_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seem to be silently cast to float32; this might be related to"
                f" the fact that you have upcast embedding or layer norm layers to float32. We will cast back the input to"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)
        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            softmax_scale=self.softmax_scale,
        )
        if self.q_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        position_ids,
        dropout=0.0,
        softmax_scale=None,
    ):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`int`, *optional*):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
            # For details, please see the comment in DeepseekV2FlashAttention2 __init__.
            causal = self.is_causal and query_length != 1
        # Contains at least one padding token in the sequence
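        # A minimal sketch of the usual unpad -> varlen attention -> re-pad flow, assuming the standard
        # `_upad_input` helper and flash-attn's public `flash_attn_varlen_func` / `flash_attn_func` /
        # `pad_input` API (the original body is not shown in this excerpt):
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            # Re-pad the output back to (batch_size, query_length, num_heads, head_dim)
            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
            )

        return attn_output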
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 with Qwen2->Qwen2Moe
class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
"""
Qwen2Moe flash attention module, following Qwen2Moe attention module. This module inherits from `Qwen2MoeAttention`
as the weights of the module stays untouched. The only required change would be on the forward pass
where it needs to correctly call the public API of flash attention and deal with padding tokens
in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
config.max_window_layers layers.
"""
    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is a bottom-right alignment,
        # which became the default for flash_attn>=2.1. This attribute is used to handle the difference.
        # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
        use_sliding_windows=False,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
        first unpad the input, then compute the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to the Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to the Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to the Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
            use_sliding_windows (`bool`, *optional*):
                Whether to activate sliding window attention.
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
            # For details, please see the comment in LlamaFlashAttention2 __init__.
            causal = self.is_causal and query_length != 1
        # Decide whether to use SWA or not by layer index.
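        # A minimal sketch of the layer-index check described above, assuming `self.layer_idx`,
        # `self.config.max_window_layers`, and a previously computed `use_sliding_windows` flag
        # (the surrounding code is not shown in this excerpt):
        if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
            use_sliding_windows = False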
        # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
        logger.warning_once(
            "Qwen2MoeModel is using Qwen2MoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
            'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
        )
    def __getattr__(self, name: str):
        # __getattr__ in nn.Module doesn't call super().__getattribute__ when name is not in nn.Module.__dict__,
        # but __setattr__ in nn.Module does call super().__setattr__ in that case. As a result, some attributes
        # can be set yet cannot be fetched through an explicit __getattr__ call; these are typically built-in
        # attributes of the class, for which `instance.attr` does not go through __getattr__ at all.
        # Example:
        #     import torch
        #     l = torch.nn.Linear(100, 200)
        #     l.out_features                   # 200
        #     l.__getattr__("out_features")    # AttributeError: 'Linear' object has no attribute 'out_features'
        try:
            return object.__getattribute__(self, name)  # if this attr belongs to BaseInjectedModule
        except:
            if name == "orig_module":
                return nn.Module.__getattr__(self, "orig_module")
            try:
                return nn.Module.__getattr__(self, "orig_module").__getattr__(name)  # if this attr belongs to orig_module
            except:
                # if this attr belongs to orig_module but is not in nn.Module.__dict__
                return super(nn.Module, nn.Module.__getattr__(self, "orig_module")).__getattribute__(name)
print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using QuantizedLinearTorch instead.")
print(f"This linear module's in_features or out_features is not divisible by GPTQ_MARLIN_MIN_THREAD_N({GPTQ_MARLIN_MIN_THREAD_N}), using QuantizedLinearTorch instead.")