Commit a2e6fa7e authored by youkaichao, committed by GitHub

[bugfix][deepseek] fix flashmla kernel selection (#25956)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 9f1c4eca
@@ -136,7 +136,7 @@ def flash_mla_with_kvcache(
         descale_k is None
     ), "descale_q and descale_k should be both None or both not None"
-    if (descale_q is not None) and (descale_k is not None):
+    if indices is None and q.element_size() == 1:
         out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
             q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
             causal, tile_scheduler_metadata, num_splits, descale_q, descale_k)
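The changed condition can be read as: take the `fwd_kvcache_mla_fp8` path only when no `indices` are passed and the query tensor stores one byte per element. The snippet below is a minimal sketch of the two predicates side by side, not the actual vLLM code. `StubTensor`, `old_selects_fp8`, and `new_selects_fp8` are hypothetical names introduced for illustration; the stub only models `element_size()`, which on a real `torch.Tensor` returns the size of one element in bytes.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class StubTensor:
    """Hypothetical stand-in for torch.Tensor; models element_size() only."""
    bytes_per_element: int

    def element_size(self) -> int:
        # Same meaning as torch.Tensor.element_size(): bytes per element.
        return self.bytes_per_element


def old_selects_fp8(descale_q: Optional[Any], descale_k: Optional[Any]) -> bool:
    # Condition removed by the commit: keyed on both descale factors being set.
    return (descale_q is not None) and (descale_k is not None)


def new_selects_fp8(q: StubTensor, indices: Optional[Any]) -> bool:
    # Condition added by the commit: keyed on the query's element width
    # and on no indices being passed.
    return indices is None and q.element_size() == 1


one_byte_q = StubTensor(bytes_per_element=1)  # assumed: an FP8-like query
two_byte_q = StubTensor(bytes_per_element=2)  # assumed: a bf16/fp16-like query

print(new_selects_fp8(one_byte_q, indices=None))      # one-byte query, no indices
print(new_selects_fp8(two_byte_q, indices=None))      # two-byte query, no indices
print(new_selects_fp8(one_byte_q, indices=object()))  # one-byte query, indices given
```

Note that `element_size() == 1` is a width check, not a dtype check: any one-byte dtype satisfies it, so the sketch assumes callers only pass one-byte queries when they are FP8.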