Commit 22a46529 authored by zhuwenwen's avatar zhuwenwen
Browse files

增加marlin对cache13的支持,以及新增flash mla的kvcache fp8的支持

parent e70b0ea0
......@@ -207,8 +207,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
if self.kv_cache_dtype != "fp8":
raise NotImplementedError(
"FlashMLA with FP8 KV cache not yet supported")
"FlashMLA with other KV cache not yet supported")
def _forward_decode(
self,
......@@ -216,6 +217,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata,
k_scale = None,
kv_cache_dtype = "auto",
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
......@@ -235,6 +238,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits=decode_meta.decode_num_splits,
softmax_scale=self.scale,
causal=True,
k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype,
)
return self._v_up_proj(o)
......@@ -1396,6 +1396,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
decode_ql_nope = decode_ql_nope.transpose(0, 1)
output[num_prefill_tokens:] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._k_scale, self.kv_cache_dtype)
return output
\ No newline at end of file
......@@ -75,6 +75,8 @@ def flash_mla_with_kvcache(
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
k_scale = None,
kv_cache_dtype = "auto",
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
......@@ -97,6 +99,22 @@ def flash_mla_with_kvcache(
if softmax_scale is None:
softmax_scale = q.shape[-1]**(-0.5)
if current_platform.is_rocm():
if kv_cache_dtype == "fp8":
out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
k_scale,
"fp8_e4m3",
)
return out, softmax_lse
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q,
k_cache,
......
......@@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new, maybe_warn_marlin_atomic_add)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import direct_register_custom_op
from vllm.model_executor.layers.fused_moe.fused_moe import get_moe_cache
def get_scalar_type(num_bits: int, has_zp: bool):
if has_zp:
......@@ -104,7 +105,8 @@ def fused_marlin_moe(
topk = topk_ids.shape[1] # 8
#暂时固定为16384
CHUNK_SIZE = 16384
#CHUNK_SIZE = 16384
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
......@@ -122,18 +124,21 @@ def fused_marlin_moe(
if global_num_experts == -1:
global_num_experts = E
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N),
(M * topk, N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
if envs.VLLM_USE_GLOBAL_CACHE13:
intermediate_cache13 = get_moe_cache(topk, N, K, device=hidden_states.device, dtype=hidden_states.dtype)
else:
intermediate_cache13 = torch.empty(
(M * topk_ids.shape[1] * max(2 * N, K), ),
(M * topk * max(2 * N, K), ),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N]
intermediate_cache1 = intermediate_cache13[:M * topk * 2 * N]
intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K]
intermediate_cache3 = intermediate_cache13[:M * topk * K]
intermediate_cache3 = intermediate_cache3.view(-1, K)
use_atomic_add = hidden_states.dtype == torch.half or \
......
......@@ -1325,6 +1325,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
decode_ql_nope = decode_ql_nope.transpose(0, 1)
output[:num_decode_tokens] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._k_scale, self.kv_cache_dtype)
return output_padded
\ No newline at end of file
......@@ -145,8 +145,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
if self.kv_cache_dtype != "fp8":
raise NotImplementedError(
"FlashMLA V1 with FP8 KV cache not yet supported")
"FlashMLA with other KV cache not yet supported")
def _forward_decode(
self,
......@@ -154,6 +155,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata,
k_scale = None,
kv_cache_dtype = "auto",
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
......@@ -172,6 +175,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype,
)
return self._v_up_proj(o)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment