Commit 1ea9a3f0 authored by wangmin6's avatar wangmin6
Browse files

Merge branch 'v0.15.1-dev_yql_3.18' into 'v0.15.1-dev'

x接入mla_cat算子仅在nmz和kvcache-fp8情况下生效,默认关闭,开启需要export VLLM_USE_CAT_MLA=1

See merge request dcutoolkit/deeplearing/vllm!513
parents cd8563a4 3bff7958
...@@ -296,6 +296,7 @@ if TYPE_CHECKING: ...@@ -296,6 +296,7 @@ if TYPE_CHECKING:
VLLM_USE_TOPK_RENORM: bool = False VLLM_USE_TOPK_RENORM: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT: bool = False VLLM_USE_FUSED_FILL_RMS_CAT: bool = False
VLLM_USE_CAT_MLA: bool = False
VLLM_W8A8_BACKEND: int = 3 VLLM_W8A8_BACKEND: int = 3
VLLM_USE_PP_BALANCE = True VLLM_USE_PP_BALANCE = True
VLLM_MOE_ROUTER_CAPTURE: bool = False VLLM_MOE_ROUTER_CAPTURE: bool = False
...@@ -1820,7 +1821,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1820,7 +1821,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vLLM will use lightop moe_align_block_size # vLLM will use lightop moe_align_block_size
"VLLM_USE_LIGHTOP_MOE_ALIGN": "VLLM_USE_LIGHTOP_MOE_ALIGN":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_ALIGN", "False").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_ALIGN", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use the fused cat + MLA operator (disabled by default)
"VLLM_USE_CAT_MLA":
lambda: (os.getenv('VLLM_USE_CAT_MLA', 'False').lower() in
("true", "1")),
# vLLM will use opt merge_attn_states, not triton # vLLM will use opt merge_attn_states, not triton
"VLLM_USE_MERGE_ATTN_STATES_OPT": "VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
......
...@@ -2355,9 +2355,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -2355,9 +2355,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
if fp8_attention and get_gcn_arch_name() == "gfx938": if fp8_attention and get_gcn_arch_name() == "gfx938":
assert decode_ql_nope.shape[0] == decode_q_pe.shape[0] assert decode_ql_nope.shape[0] == decode_q_pe.shape[0]
assert decode_ql_nope.shape[1] == decode_q_pe.shape[1] assert decode_ql_nope.shape[1] == decode_q_pe.shape[1]
decode_q = self._decode_concat_quant_fp8_op( if envs.VLLM_USE_CAT_MLA:
decode_ql_nope, decode_q_pe, layer._q_scale decode_q = (decode_ql_nope, decode_q_pe)
) else:
decode_q = self._decode_concat_quant_fp8_op(
decode_ql_nope, decode_q_pe, layer._q_scale
)
else: else:
decode_q = (decode_ql_nope, decode_q_pe) decode_q = (decode_ql_nope, decode_q_pe)
if self.dcp_world_size > 1: if self.dcp_world_size > 1:
......
...@@ -229,9 +229,6 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts, SupportsPP): ...@@ -229,9 +229,6 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts, SupportsPP):
self.quant_method = quant_config.get_name() self.quant_method = quant_config.get_name()
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
# The AWQ layer of MTP uses BlockInt8W8A8.
if self.quant_method == "moe_wna16" or self.quant_method == "awq_marlin":
vllm_config.quant_config = BlockInt8Config(is_checkpoint_int8_serialized=True, weight_block_size=[128,128])
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
......
...@@ -28,6 +28,7 @@ from vllm.v1.attention.backend import ( ...@@ -28,6 +28,7 @@ from vllm.v1.attention.backend import (
MultipleOf, MultipleOf,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.rocm import get_gcn_arch_name
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
reshape_attn_output_for_spec_decode, reshape_attn_output_for_spec_decode,
reshape_query_for_spec_decode, reshape_query_for_spec_decode,
...@@ -39,6 +40,7 @@ from vllm.v1.attention.ops.flashmla import ( ...@@ -39,6 +40,7 @@ from vllm.v1.attention.ops.flashmla import (
get_mla_metadata, get_mla_metadata,
get_mla_metadata_dense_fp8, get_mla_metadata_dense_fp8,
is_flashmla_dense_supported, is_flashmla_dense_supported,
flash_mla_with_kvcache_fp8_with_cat
) )
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm import envs from vllm import envs
...@@ -249,6 +251,33 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -249,6 +251,33 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
if isinstance(q, tuple): if isinstance(q, tuple):
q_nope, q_pe = q q_nope, q_pe = q
if envs.VLLM_USE_CAT_MLA and self.kv_cache_dtype.startswith("fp8") and get_gcn_arch_name() == "gfx938":
assert isinstance(q_nope, torch.Tensor)
assert isinstance(q_pe, torch.Tensor)
num_decodes = attn_metadata.num_decodes
q_nope = reshape_query_for_spec_decode(q_nope, num_decodes)
q_pe = reshape_query_for_spec_decode(q_pe, num_decodes)
scheduler_metadata = attn_metadata.decode.scheduler_metadata
assert q_nope.shape[0] == num_decodes
o, lse = flash_mla_with_kvcache_fp8_with_cat(
q_nope=q_nope,
q_pe=q_pe,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).view(torch.float8_e4m3fn), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=scheduler_metadata.tile_scheduler_metadata,
num_splits=scheduler_metadata.num_splits,
softmax_scale=self.scale,
causal=True,
descale_q=layer._q_scale.reshape(1),
descale_k=layer._k_scale.reshape(1),
)
o = reshape_attn_output_for_spec_decode(o)
return o, lse
if envs.VLLM_USE_OPT_CAT and q_nope.shape[0] < 1024: if envs.VLLM_USE_OPT_CAT and q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import ( from vllm.v1.attention.backends.mla.test_concat import (
concat_helper_decode, concat_helper_decode,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py # adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
from typing import Optional, Tuple
import torch import torch
...@@ -211,6 +212,58 @@ def flash_mla_with_kvcache_fp8( ...@@ -211,6 +212,58 @@ def flash_mla_with_kvcache_fp8(
return out, softmax_lse return out, softmax_lse
def flash_mla_with_kvcache_fp8_with_cat(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the fp8 KV-cache MLA kernel with separate no-RoPE / RoPE queries.

    Thin wrapper over ``flash_mla_cuda.fwd_kvcache_mla_fp8_with_cat``: the
    two query parts are handed to the kernel as separate tensors instead of
    being concatenated by the caller first.

    Arguments:
        q_nope: (batch_size, seq_len_q, num_heads_q, 512).
        q_pe: (batch_size, seq_len_q, num_heads_q, 64).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            torch.int32, returned by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, returned by
            get_mla_metadata.
        softmax_scale: Scale applied to QK^T before softmax. When None it
            defaults to ``(q_nope.shape[-1] + q_pe.shape[-1]) ** -0.5``,
            i.e. 1 / sqrt of the combined query head dimension.
        causal: Whether to apply a causal attention mask.
        descale_q: torch.float32 descaling factors for Q, used for fp8
            quantization. NOTE(review): described as (batch_size) in the
            original docs — confirm the shape the kernel actually expects.
        descale_k: torch.float32 descaling factors for K, used for fp8
            quantization. NOTE(review): same shape caveat as descale_q.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    scale = softmax_scale
    if scale is None:
        # Fall back to the scale derived from the combined query head dim.
        combined_head_dim = q_nope.shape[-1] + q_pe.shape[-1]
        scale = combined_head_dim ** (-0.5)

    # The kernel takes its arguments positionally; keep this order intact.
    kernel_args = (
        q_nope,
        q_pe,
        k_cache,
        # No tensor is supplied for this slot (presumably the separate V
        # cache, as in upstream FlashMLA — TODO confirm).
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        descale_q,
        descale_k,
    )
    attn_out, lse = flash_mla_cuda.fwd_kvcache_mla_fp8_with_cat(*kernel_args)
    return attn_out, lse
# #
# TODO: Add fake functions # TODO: Add fake functions
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment