Unverified commit e39628fd, authored by Minglei Zhu, committed by GitHub

[2/2] DeepSeek deterministic: support DeepSeek V3 deterministic inference on 8 x H200 (#12095)

parent bacb3825
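
For context, a minimal sketch (not part of this diff) of how the flag gated below might be set programmatically. `ServerArgs` and `set_global_server_args_for_scheduler` appear in the test hunk at the end of this commit; the import location and model path are assumptions for illustration.

```python
# Minimal sketch, assuming both names are importable from
# sglang.srt.server_args as the hunks below suggest; the model path
# is a placeholder, not taken from this commit.
from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler

server_args = ServerArgs(
    model_path="deepseek-ai/DeepSeek-V3",  # placeholder for illustration
    enable_deterministic_inference=True,
)
set_global_server_args_for_scheduler(server_args)
```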
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Tuple

 import torch
 import triton
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import get_device_name, is_hip

 logger = logging.getLogger(__name__)
@@ -51,6 +52,11 @@ def get_moe_configs(
     kernel on a given batch size bs, the closest batch size in the grid should
     be picked and the associated configuration chosen to invoke the kernel.
     """
+    if get_global_server_args().enable_deterministic_inference:
+        logger.warning(
+            "Deterministic inference is enabled, using default MoE kernel config."
+        )
+        return None
     # Supported Triton versions, should be sorted from the newest to the oldest
     supported_triton_versions = ["3.4.0", "3.3.1", "3.2.0", "3.1.0"]
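
Why returning None helps: the tuned config tables are keyed by batch size, so nearby batch sizes can select different tile shapes, and different tile shapes change the floating-point reduction order. A toy illustration (not from this commit) of that order sensitivity:

```python
# Toy illustration: the same values summed with different reduction trees can
# differ in the last bits, which is the non-determinism the fallback avoids.
import torch

torch.manual_seed(0)
x = torch.randn(1 << 16, dtype=torch.float32)

flat_sum = x.sum()                                       # one reduction tree
chunked_sum = sum(chunk.sum() for chunk in x.chunk(64))  # a different tree

print((flat_sum - chunked_sum).item())  # often a small non-zero residue
```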
@@ -130,6 +136,14 @@ def get_default_config(
     is_marlin: bool,
     block_shape: Optional[List[int]] = None,
 ) -> Dict[str, int]:
+    if get_global_server_args().enable_deterministic_inference:
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": 64,
+            "BLOCK_SIZE_K": 32,
+            "GROUP_SIZE_M": 8,
+        }
+        return config
     if dtype == "fp8_w8a8":
         if block_shape is None:
             config = {
...
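
With the fixed config above, the kernel's tile grid becomes a pure function of the problem shape, never of a per-batch-size tuning table. A simplified sketch (not the fused-MoE kernel's actual launch code) of that grid computation:

```python
# Simplified sketch, not the actual launch code: the tile grid depends only on
# (M, N) and the fixed block sizes returned above.
import triton

config = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}

def tile_grid(M: int, N: int) -> int:
    return triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(N, config["BLOCK_SIZE_N"])

print(tile_grid(512, 4096))  # same blocking for every batch size
```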
@@ -515,6 +515,9 @@ class MoEGate(nn.Module):
                True,  # is_vnni
            )

+        if get_global_server_args().enable_deterministic_inference:
+            return F.linear(hidden_states, self.weight, None)
+
        # NOTE: For some unknown reason, router_gemm seems to degrade accept length.
        if (
            _is_cuda
...
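
The gate change above routes through plain `F.linear` instead of the custom `router_gemm` kernel when determinism is on. A standalone sketch of that path, with DeepSeek-V3-like shapes assumed for illustration rather than taken from this diff:

```python
# Standalone sketch of the deterministic gate path: a plain dense projection
# from hidden size to the number of routed experts. Shapes are assumed
# (DeepSeek-V3-like), not taken from this diff.
import torch
import torch.nn.functional as F

hidden_size, num_experts = 7168, 256
weight = torch.randn(num_experts, hidden_size, dtype=torch.bfloat16)
hidden_states = torch.randn(8, hidden_size, dtype=torch.bfloat16)

router_logits = F.linear(hidden_states, weight, None)  # (8, num_experts)
```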
@@ -193,6 +193,8 @@ class TestFusedMOE(CustomTestCase):
        dtypes = [torch.float16, torch.bfloat16]
        fp8_modes = [False, True]

+        set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
+
        # Calculate total number of tests
        total_tests = (
            len(m_values)
...
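
Since the test now pins global server args, one natural follow-up check (hypothetical, not part of this diff) is bitwise reproducibility of the op under test:

```python
# Hypothetical determinism check, not part of this diff: identical inputs run
# repeatedly through the op under test should give bitwise-equal outputs.
import torch

def assert_deterministic(fn, *inputs, trials: int = 3):
    reference = fn(*inputs)
    for _ in range(trials - 1):
        assert torch.equal(fn(*inputs), reference), "non-deterministic output"
```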