Unverified Commit 506be6b8 authored by JieXin Liang's avatar JieXin Liang Committed by GitHub
Browse files

[fix] fix compile_deep_gemm missing kv_b_proj (#5620)

parent 2343d8df
...@@ -30,6 +30,8 @@ multiprocessing.set_start_method("spawn", force=True) ...@@ -30,6 +30,8 @@ multiprocessing.set_start_method("spawn", force=True)
os.environ["SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE"] = "1" os.environ["SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE"] = "1"
# Force enable deep gemm # Force enable deep gemm
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1" os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"
# Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
os.environ["SGL_CHUNKED_PREFIX_CACHE_THRESHOLD"] = "0"
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -80,7 +80,15 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder ...@@ -80,7 +80,15 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import BumpAllocator, DeepEPMode, add_prefix, is_cuda, is_hip from sglang.srt.utils import (
BumpAllocator,
DeepEPMode,
add_prefix,
get_bool_env_var,
get_int_env_var,
is_cuda,
is_hip,
)
_is_hip = is_hip() _is_hip = is_hip()
_is_cuda = is_cuda() _is_cuda = is_cuda()
...@@ -549,10 +557,14 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -549,10 +557,14 @@ class DeepseekV2AttentionMLA(nn.Module):
"disable_chunked_prefix_cache" "disable_chunked_prefix_cache"
] ]
self.attention_backend = global_server_args_dict["attention_backend"] self.attention_backend = global_server_args_dict["attention_backend"]
self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1" self.rocm_fused_decode_mla = get_bool_env_var(
"SGLANG_ROCM_FUSED_DECODE_MLA", "false"
)
# TODO: Design a finer way to determine the threshold # TODO: Design a finer way to determine the threshold
self.chunked_prefix_cache_threshold = 8192 self.chunked_prefix_cache_threshold = get_int_env_var(
"SGL_CHUNKED_PREFIX_CACHE_THRESHOLD", 8192
)
def dispatch_attn_forward_method( def dispatch_attn_forward_method(
self, forward_batch: ForwardBatch self, forward_batch: ForwardBatch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment