Unverified commit e39628fd, authored by Minglei Zhu and committed by GitHub
Browse files

[2/2] Deepseek deterministic: support deepseek v3 deterministic inference on 8 x H200 (#12095)

parent bacb3825
...@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Tuple ...@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Tuple
import torch import torch
import triton import triton
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import get_device_name, is_hip from sglang.srt.utils import get_device_name, is_hip
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -51,6 +52,11 @@ def get_moe_configs( ...@@ -51,6 +52,11 @@ def get_moe_configs(
kernel on a given batch size bs, the closest batch size in the grid should kernel on a given batch size bs, the closest batch size in the grid should
be picked and the associated configuration chosen to invoke the kernel. be picked and the associated configuration chosen to invoke the kernel.
""" """
if get_global_server_args().enable_deterministic_inference:
logger.warning(
"Deterministic inference is enabled, using default MoE kernel config."
)
return None
# Supported Triton versions, should be sorted from the newest to the oldest # Supported Triton versions, should be sorted from the newest to the oldest
supported_triton_versions = ["3.4.0", "3.3.1", "3.2.0", "3.1.0"] supported_triton_versions = ["3.4.0", "3.3.1", "3.2.0", "3.1.0"]
...@@ -130,6 +136,14 @@ def get_default_config( ...@@ -130,6 +136,14 @@ def get_default_config(
is_marlin: bool, is_marlin: bool,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
) -> Dict[str, int]: ) -> Dict[str, int]:
if get_global_server_args().enable_deterministic_inference:
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
}
return config
if dtype == "fp8_w8a8": if dtype == "fp8_w8a8":
if block_shape is None: if block_shape is None:
config = { config = {
......
...@@ -515,6 +515,9 @@ class MoEGate(nn.Module): ...@@ -515,6 +515,9 @@ class MoEGate(nn.Module):
True, # is_vnni True, # is_vnni
) )
if get_global_server_args().enable_deterministic_inference:
return F.linear(hidden_states, self.weight, None)
# NOTE: For some unknown reason, router_gemm seems to degrade the accept length. # NOTE: For some unknown reason, router_gemm seems to degrade the accept length.
if ( if (
_is_cuda _is_cuda
......
...@@ -193,6 +193,8 @@ class TestFusedMOE(CustomTestCase): ...@@ -193,6 +193,8 @@ class TestFusedMOE(CustomTestCase):
dtypes = [torch.float16, torch.bfloat16] dtypes = [torch.float16, torch.bfloat16]
fp8_modes = [False, True] fp8_modes = [False, True]
set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
# Calculate total number of tests # Calculate total number of tests
total_tests = ( total_tests = (
len(m_values) len(m_values)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment