Commit b909d6fc authored by zhuwenwen
Browse files

修改flashmla的接口

parent 22a46529
......@@ -100,7 +100,7 @@ def flash_mla_with_kvcache(
softmax_scale = q.shape[-1]**(-0.5)
if current_platform.is_rocm():
if kv_cache_dtype == "fp8":
out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q,
k_cache,
None,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment