Commit 1e622f10 authored by zhuwenwen's avatar zhuwenwen
Browse files

update flash_mla_with_kvcache

parent 31a3beb5
...@@ -288,6 +288,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -288,6 +288,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits=decode_meta.decode_num_splits, num_splits=decode_meta.decode_num_splits,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
is_fp8_kvcache=False,
indices= None,
k_scale = k_scale, k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype, kv_cache_dtype = kv_cache_dtype,
) )
......
...@@ -101,6 +101,8 @@ def flash_mla_with_kvcache( ...@@ -101,6 +101,8 @@ def flash_mla_with_kvcache(
num_splits: torch.Tensor, num_splits: torch.Tensor,
softmax_scale: Optional[float] = None, softmax_scale: Optional[float] = None,
causal: bool = False, causal: bool = False,
is_fp8_kvcache: bool = False,
indices: Optional[torch.Tensor] = None,
k_scale = None, k_scale = None,
kv_cache_dtype = "auto", kv_cache_dtype = "auto",
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
...@@ -145,7 +147,6 @@ def flash_mla_with_kvcache( ...@@ -145,7 +147,6 @@ def flash_mla_with_kvcache(
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q, q,
k_cache, k_cache,
None,
head_dim_v, head_dim_v,
cache_seqlens, cache_seqlens,
block_table, block_table,
...@@ -153,6 +154,8 @@ def flash_mla_with_kvcache( ...@@ -153,6 +154,8 @@ def flash_mla_with_kvcache(
causal, causal,
tile_scheduler_metadata, tile_scheduler_metadata,
num_splits, num_splits,
is_fp8_kvcache,
indices,
) )
else: else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
......
...@@ -232,6 +232,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -232,6 +232,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits=attn_metadata.decode.num_splits, num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
is_fp8_kvcache=False,
indices= None,
k_scale = k_scale, k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype, kv_cache_dtype = kv_cache_dtype,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment