Commit badaff2d authored by wanghl6's avatar wanghl6
Browse files

添加dspk prefill atten前FUSE_CAT_AND_CAST_FP8

parent 004a1ef4
...@@ -219,6 +219,7 @@ if TYPE_CHECKING: ...@@ -219,6 +219,7 @@ if TYPE_CHECKING:
VLLM_ENABLE_SHARED_EXPERTS_FUSION: bool = False VLLM_ENABLE_SHARED_EXPERTS_FUSION: bool = False
VLLM_USE_MOE_W16A16_TRITON: bool = False VLLM_USE_MOE_W16A16_TRITON: bool = False
VLLM_USE_FUSED_DTBMM: bool = False VLLM_USE_FUSED_DTBMM: bool = False
VLLM_FUSE_CAT_AND_CAST_FP8: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1404,6 +1405,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1404,6 +1405,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSED_DTBMM": "VLLM_USE_FUSED_DTBMM":
lambda: (os.environ.get("VLLM_USE_FUSED_DTBMM", "False").lower() in lambda: (os.environ.get("VLLM_USE_FUSED_DTBMM", "False").lower() in
("true", "1")), ("true", "1")),
"VLLM_FUSE_CAT_AND_CAST_FP8":
lambda: (os.environ.get("VLLM_FUSE_CAT_AND_CAST_FP8", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -1036,33 +1036,44 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1036,33 +1036,44 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
use_flash_fp8_arch = ( \
torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" \
and envs.VLLM_USE_FLASH_ATTN_FP8
)
use_fused_fp8_op = use_flash_fp8_arch and envs.VLLM_FUSE_CAT_AND_CAST_FP8
k_pe_expanded = k_pe.expand(k_pe.shape[0], self.num_heads, k_pe.shape[-1])
if use_fused_fp8_op:
from lightop import op
q, k, v = op.ds_fused_qkv_cast_fp8(
q,
kv_nope,
k_pe_expanded,
self.qk_nope_head_dim,
self.v_head_dim
)
else:
k_nope, v = kv_nope\ k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_OPT_CAT: if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024: if k_nope.shape[0] > 1024:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), k = lightop_concat_prefill_helper(k_nope, k_pe_expanded, dim=2)
dim=2)
else: else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), k = torch.cat((k_nope, k_pe_expanded), dim=-1)
dim=-1)
else: else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), k = torch.cat((k_nope, k_pe_expanded), dim=-1)
dim=-1)
if use_flash_fp8_arch:
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_ATTN_FP8: q_descale = None
q_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device) k_descale = None
k_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device) v_descale = None
v_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device) if not use_fused_fp8_op:
descale_shape = (attn_metadata.prefill.query_start_loc.numel() - 1, q.shape[1])
q_descale = q_descale.expand(descale_shape)
k_descale = k_descale.expand(descale_shape)
v_descale = v_descale.expand(descale_shape)
q = q.to(torch.float8_e4m3fn) q = q.to(torch.float8_e4m3fn)
k = k.to(torch.float8_e4m3fn) k = k.to(torch.float8_e4m3fn)
v = v.to(torch.float8_e4m3fn) v = v.to(torch.float8_e4m3fn)
attn_output, attn_softmax_lse = \ attn_output, attn_softmax_lse = \
self._flash_attn_varlen_diff_headdims( self._flash_attn_varlen_diff_headdims(
q=q, q=q,
...@@ -1134,29 +1145,38 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1134,29 +1145,38 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\ use_flash_fp8_arch = ( \
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" \
and envs.VLLM_USE_FLASH_ATTN_FP8
)
use_fused_fp8_op = use_flash_fp8_arch and envs.VLLM_FUSE_CAT_AND_CAST_FP8
if use_fused_fp8_op:
from lightop import op
k_pe_expanded = k_pe.expand(k_pe.shape[0], self.num_heads, k_pe.shape[-1])
q, k, v = op.ds_fused_qkv_cast_fp8(
q,
kv_nope,
k_pe_expanded,
self.qk_nope_head_dim,
self.v_head_dim
)
else:
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_OPT_CAT: if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024: if k_nope.shape[0] > 1024:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), dim=2)
dim=2)
else: else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
dim=-1)
else: else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_ATTN_FP8: if use_flash_fp8_arch:
q_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device) q_descale = None
k_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device) k_descale = None
v_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device) v_descale = None
descale_shape = (attn_metadata.prefill.query_start_loc.numel() - 1, q.shape[1]) if not use_fused_fp8_op:
q_descale = q_descale.expand(descale_shape)
k_descale = k_descale.expand(descale_shape)
v_descale = v_descale.expand(descale_shape)
q = q.to(torch.float8_e4m3fn) q = q.to(torch.float8_e4m3fn)
k = k.to(torch.float8_e4m3fn) k = k.to(torch.float8_e4m3fn)
v = v.to(torch.float8_e4m3fn) v = v.to(torch.float8_e4m3fn)
...@@ -1270,7 +1290,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1270,7 +1290,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
has_decode = attn_metadata.num_decodes > 0 has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0 has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
prefill_k_pe = k_pe[num_decode_tokens:] prefill_k_pe = k_pe[num_decode_tokens:]
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
decode_q = q[:num_decode_tokens] decode_q = q[:num_decode_tokens]
...@@ -1356,7 +1375,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1356,7 +1375,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
False, False,
1e-6, 1e-6,
) )
if has_prefill: if has_prefill:
if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
prefill_k_c_normed = key_normed[:num_actual_toks, ...] prefill_k_c_normed = key_normed[:num_actual_toks, ...]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment