Commit fd8e4a76 authored by zhuwenwen's avatar zhuwenwen
Browse files

update flash_mla_with_kvcache

set VLLM_USE_PIECEWISE=0
parent 1871c26c
...@@ -160,8 +160,8 @@ def flash_mla_with_kvcache( ...@@ -160,8 +160,8 @@ def flash_mla_with_kvcache(
else: else:
if current_platform.is_rocm(): if current_platform.is_rocm():
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q, k_cache, block_table, cache_seqlens, head_dim_v, tile_scheduler_metadata, q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
num_splits, softmax_scale, causal, is_fp8_kvcache, causal, tile_scheduler_metadata, num_splits, is_fp8_kvcache,
indices) indices)
else: else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
......
...@@ -1657,7 +1657,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1657,7 +1657,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
("true", "1")), ("true", "1")),
# vLLM will use piecewise # vLLM will use piecewise
"VLLM_USE_PIECEWISE": "VLLM_USE_PIECEWISE":
lambda: (os.environ.get("VLLM_USE_PIECEWISE", "True").lower() in lambda: (os.environ.get("VLLM_USE_PIECEWISE", "False").lower() in
("true", "1")), ("true", "1")),
# vllm will use encoding_dsv32.py for dpsk-v32 # vllm will use encoding_dsv32.py for dpsk-v32
"VLLM_USE_V32_ENCODE": "VLLM_USE_V32_ENCODE":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment