Commit 0daa00fb authored by yangql
Browse files

适配在bmz上的mla的kvcache_e5m2和e4m3量化的支持

parent cb1a27d2
......@@ -215,6 +215,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_and_maybe_dequant_weights,
)
from vllm.platforms import current_platform
from vllm.platforms.rocm import get_gcn_arch_name
from vllm.utils.flashinfer import has_nvidia_artifactory
from vllm.utils.math_utils import cdiv, round_down
from vllm.v1.attention.backend import (
......@@ -2115,7 +2116,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
scale=layer._k_scale,
)
if fp8_attention:
if fp8_attention and get_gcn_arch_name() == "gfx938":
kv_cache = kv_cache.view(current_platform.fp8_dtype())
if has_prefill:
......@@ -2185,7 +2186,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)
if fp8_attention:
if fp8_attention and get_gcn_arch_name() == "gfx938":
assert decode_ql_nope.shape[0] == decode_q_pe.shape[0]
assert decode_ql_nope.shape[1] == decode_q_pe.shape[1]
decode_q = self._decode_concat_quant_fp8_op(
......
......@@ -49,7 +49,7 @@ def sparse_attn_indexer(
if not isinstance(attn_metadata, dict):
# Reserve workspace for indexer during profiling run
current_workspace_manager().get_simultaneous(
((total_seq_lens, head_dim), fp8_dtype if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else torch.bfloat16),
((total_seq_lens, head_dim), fp8_dtype if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else k.dtype,),
((total_seq_lens, 4), torch.uint8),
)
return sparse_attn_indexer_fake(
......
......@@ -121,6 +121,10 @@ def on_gfx9() -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950", "gfx928", "gfx936", "gfx938"])
@cache
def get_gcn_arch_name() -> str:
    """Return the base architecture name reported for the "cuda" device.

    Reads ``gcnArchName`` from ``torch.cuda.get_device_properties("cuda")``
    and returns only the text that precedes the first ``:``. If the reported
    name contains no ``:``, it is returned unchanged.

    The function is wrapped in ``@cache`` (presumably ``functools.cache`` --
    confirm against the file's imports), so the device properties are
    expected to be queried only on the first call.
    """
    reported_name = torch.cuda.get_device_properties("cuda").gcnArchName
    # Keep only the leading segment; anything from the first ':' onwards
    # (presumably feature suffixes -- TODO confirm) is discarded.
    base_name, _, _ = reported_name.partition(":")
    return base_name
@cache
def on_gfx942() -> bool:
......
......@@ -310,6 +310,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
causal=True,
descale_q=layer._q_scale.reshape(1),
descale_k=layer._k_scale.reshape(1),
kv_cache_dtype=self.kv_cache_dtype,
)
else:
o, lse = flash_mla_with_kvcache(
......
......@@ -6,7 +6,7 @@ import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms.rocm import get_gcn_arch_name
logger = init_logger(__name__)
if current_platform.is_cuda():
......@@ -136,7 +136,7 @@ def get_mla_metadata_dense_fp8(
cache_seqlens,
num_q_tokens_per_head_k,
num_heads_k,
16,
# 16,
)
else:
return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
......@@ -158,26 +158,43 @@ def flash_mla_with_kvcache_fp8(
causal: bool = False,
descale_q: torch.Tensor | None = None,
descale_k: torch.Tensor | None = None,
kv_cache_dtype: str | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if not _is_flashmla_available()[0]:
_raise_flashmla_unavailable()
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if current_platform.is_rocm():
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
descale_q,
descale_k,
)
if get_gcn_arch_name() == "gfx938":
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
descale_q,
descale_k,
)
else:
out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
descale_k,
kv_cache_dtype,
)
else:
out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
q,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment