Commit b909d6fc authored by zhuwenwen
Browse files

修改flashmla的接口

parent 22a46529
...@@ -100,7 +100,7 @@ def flash_mla_with_kvcache( ...@@ -100,7 +100,7 @@ def flash_mla_with_kvcache(
softmax_scale = q.shape[-1]**(-0.5) softmax_scale = q.shape[-1]**(-0.5)
if current_platform.is_rocm(): if current_platform.is_rocm():
if kv_cache_dtype == "fp8": if kv_cache_dtype == "fp8":
out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla( out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q, q,
k_cache, k_cache,
None, None,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment