Commit 1e622f10 authored by zhuwenwen's avatar zhuwenwen
Browse files

update flash_mla_with_kvcache

parent 31a3beb5
......@@ -288,6 +288,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits=decode_meta.decode_num_splits,
softmax_scale=self.scale,
causal=True,
is_fp8_kvcache=False,
indices= None,
k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype,
)
......
......@@ -101,6 +101,8 @@ def flash_mla_with_kvcache(
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
is_fp8_kvcache: bool = False,
indices: Optional[torch.Tensor] = None,
k_scale = None,
kv_cache_dtype = "auto",
) -> Tuple[torch.Tensor, torch.Tensor]:
......@@ -145,7 +147,6 @@ def flash_mla_with_kvcache(
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
......@@ -153,6 +154,8 @@ def flash_mla_with_kvcache(
causal,
tile_scheduler_metadata,
num_splits,
is_fp8_kvcache,
indices,
)
else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
......
......@@ -232,6 +232,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
is_fp8_kvcache=False,
indices= None,
k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment