Commit 8f25283a authored by yangql's avatar yangql
Browse files

增加marlin对cache13的支持,以及新增flash mla的kvcache fp8的支持

parent 693d5ed4
...@@ -211,8 +211,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -211,8 +211,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl") "FlashMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype): if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError( if self.kv_cache_dtype != "fp8":
"FlashMLA with FP8 KV cache not yet supported") raise NotImplementedError(
"FlashMLA with other KV cache not yet supported")
def _forward_decode( def _forward_decode(
self, self,
...@@ -220,6 +221,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -220,6 +221,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q_pe: torch.Tensor, q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata, attn_metadata: FlashMLAMetadata,
k_scale = None,
kv_cache_dtype = "auto",
) -> torch.Tensor: ) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0 assert kv_c_and_k_pe_cache.numel() > 0
...@@ -239,6 +242,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -239,6 +242,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits=decode_meta.decode_num_splits, num_splits=decode_meta.decode_num_splits,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype,
) )
return self._v_up_proj(o) return self._v_up_proj(o)
...@@ -1397,6 +1397,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1397,6 +1397,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
decode_ql_nope = decode_ql_nope.transpose(0, 1) decode_ql_nope = decode_ql_nope.transpose(0, 1)
output[num_prefill_tokens:] = self._forward_decode( output[num_prefill_tokens:] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._k_scale, self.kv_cache_dtype)
return output return output
\ No newline at end of file
...@@ -75,6 +75,8 @@ def flash_mla_with_kvcache( ...@@ -75,6 +75,8 @@ def flash_mla_with_kvcache(
num_splits: torch.Tensor, num_splits: torch.Tensor,
softmax_scale: Optional[float] = None, softmax_scale: Optional[float] = None,
causal: bool = False, causal: bool = False,
k_scale = None,
kv_cache_dtype = "auto",
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Arguments: Arguments:
...@@ -97,6 +99,22 @@ def flash_mla_with_kvcache( ...@@ -97,6 +99,22 @@ def flash_mla_with_kvcache(
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1]**(-0.5) softmax_scale = q.shape[-1]**(-0.5)
if current_platform.is_rocm(): if current_platform.is_rocm():
if kv_cache_dtype == "fp8":
out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
k_scale,
"fp8_e4m3",
)
return out, softmax_lse
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q, q,
k_cache, k_cache,
......
...@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ...@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new, maybe_warn_marlin_atomic_add) marlin_make_workspace_new, maybe_warn_marlin_atomic_add)
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.model_executor.layers.fused_moe.fused_moe import get_moe_cache
def get_scalar_type(num_bits: int, has_zp: bool): def get_scalar_type(num_bits: int, has_zp: bool):
if has_zp: if has_zp:
return scalar_types.uint4 if num_bits == 4 else scalar_types.uint8 return scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
...@@ -104,8 +104,8 @@ def fused_marlin_moe( ...@@ -104,8 +104,8 @@ def fused_marlin_moe(
topk = topk_ids.shape[1] # 8 topk = topk_ids.shape[1] # 8
#暂时固定为16384 #暂时固定为16384
CHUNK_SIZE = 16384 #CHUNK_SIZE = 16384
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE) M = min(num_tokens, CHUNK_SIZE)
if workspace is None: if workspace is None:
...@@ -120,18 +120,21 @@ def fused_marlin_moe( ...@@ -120,18 +120,21 @@ def fused_marlin_moe(
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = E global_num_experts = E
intermediate_cache2 = torch.empty( intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N), (M * topk, N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache13 = torch.empty(
(M * topk_ids.shape[1] * max(2 * N, K), ),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
) )
intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N] if envs.VLLM_USE_GLOBAL_CACHE13:
intermediate_cache13 = get_moe_cache(topk, N, K, device=hidden_states.device, dtype=hidden_states.dtype)
else:
intermediate_cache13 = torch.empty(
(M * topk * max(2 * N, K), ),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache1 = intermediate_cache13[:M * topk * 2 * N]
intermediate_cache1 = intermediate_cache1.view(-1, 2 * N) intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K] intermediate_cache3 = intermediate_cache13[:M * topk * K]
intermediate_cache3 = intermediate_cache3.view(-1, K) intermediate_cache3 = intermediate_cache3.view(-1, K)
use_atomic_add = hidden_states.dtype == torch.half or \ use_atomic_add = hidden_states.dtype == torch.half or \
......
...@@ -1086,6 +1086,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1086,6 +1086,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
decode_ql_nope = decode_ql_nope.transpose(0, 1) decode_ql_nope = decode_ql_nope.transpose(0, 1)
output[:num_decode_tokens] = self._forward_decode( output[:num_decode_tokens] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) decode_ql_nope, decode_q_pe, kv_cache, attn_metadata, layer._k_scale, self.kv_cache_dtype)
return output_padded return output_padded
\ No newline at end of file
...@@ -148,8 +148,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -148,8 +148,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl") "FlashMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype): if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError( if self.kv_cache_dtype != "fp8":
"FlashMLA V1 with FP8 KV cache not yet supported") raise NotImplementedError(
"FlashMLA with other KV cache not yet supported")
def _forward_decode( def _forward_decode(
self, self,
...@@ -157,6 +158,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -157,6 +158,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q_pe: torch.Tensor, q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata, attn_metadata: FlashMLAMetadata,
k_scale = None,
kv_cache_dtype = "auto",
) -> torch.Tensor: ) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0 assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
...@@ -175,6 +178,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -175,6 +178,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits=attn_metadata.decode.num_splits, num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype,
) )
return self._v_up_proj(o) return self._v_up_proj(o)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment