Commit ef8f16f4 authored by zhuwenwen's avatar zhuwenwen
Browse files

update v1 mla

parent 660af62e
...@@ -190,6 +190,7 @@ from dataclasses import dataclass ...@@ -190,6 +190,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
import torch import torch
import os
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
...@@ -642,6 +643,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -642,6 +643,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.flash_attn_varlen_func = \ self.flash_attn_varlen_func = \
functools.partial(flash_attn_varlen_func, functools.partial(flash_attn_varlen_func,
fa_version=self.vllm_flash_attn_version) fa_version=self.vllm_flash_attn_version)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
# For MLA the v head dim is smaller than qk head dim so we pad out # For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do # v with 0s to match the qk head dim for attention backends that do
...@@ -649,7 +652,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -649,7 +652,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# We don't need to pad V if we are on a hopper system with FA3 # We don't need to pad V if we are on a hopper system with FA3
self._pad_v = self.vllm_flash_attn_version is None or not ( self._pad_v = self.vllm_flash_attn_version is None or not (
self.vllm_flash_attn_version == 3 self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9) and current_platform.get_device_capability()[0] == 9
and torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120 )
def _flash_attn_varlen_diff_headdims(self, def _flash_attn_varlen_diff_headdims(self,
q, q,
...@@ -660,8 +664,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -660,8 +664,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
**kwargs): **kwargs):
maybe_padded_v = v maybe_padded_v = v
if self._pad_v: if self._pad_v:
# maybe_padded_v = torch.nn.functional.pad(
# v, [0, q.shape[-1] - v.shape[-1]], value=0)
maybe_padded_v = torch.nn.functional.pad( maybe_padded_v = torch.nn.functional.pad(
v, [0, q.shape[-1] - v.shape[-1]], value=0) v, [0, q.shape[-1] - v.shape[-1]- 32], value=0)
maybe_padded_v = maybe_padded_v[..., :-32].reshape(v.shape[0], v.shape[1],v.shape[2])
attn_out = self.flash_attn_varlen_func( attn_out = self.flash_attn_varlen_func(
q=q, q=q,
...@@ -737,7 +744,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -737,7 +744,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# we currently do not have quantized bmm's which are needed for # we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low # the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T if self.use_llama_nn and isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod):
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj)
else:
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == ( assert kv_b_proj_weight.shape == (
self.kv_lora_rank, self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
...@@ -971,4 +981,4 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -971,4 +981,4 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
output[:num_decode_tokens] = self._forward_decode( output[:num_decode_tokens] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
return output_padded return output_padded
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment