Unverified Commit fd08c048 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Support custom DeepEP tuning config (#6257)

parent 26ebb849
import logging
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.utils import DeepEPMode from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import DeepEPMode, load_json_config
try: try:
from deep_ep import Buffer from deep_ep import Buffer, Config
from sglang.srt.layers.quantization.fp8_kernel import ( from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8, sglang_per_token_group_quant_fp8,
...@@ -25,6 +28,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import ( ...@@ -25,6 +28,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
) )
from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.model_executor.forward_batch_info import ForwardMode
logger = logging.getLogger(__name__)
class DeepEPDispatchMode(IntEnum): class DeepEPDispatchMode(IntEnum):
NORMAL = auto() NORMAL = auto()
...@@ -32,7 +37,6 @@ class DeepEPDispatchMode(IntEnum): ...@@ -32,7 +37,6 @@ class DeepEPDispatchMode(IntEnum):
class DeepEPBuffer: class DeepEPBuffer:
_buffer = None _buffer = None
_dispatch_mode: Optional[DeepEPDispatchMode] = None _dispatch_mode: Optional[DeepEPDispatchMode] = None
_hidden_size: Optional[int] = None _hidden_size: Optional[int] = None
...@@ -60,8 +64,10 @@ class DeepEPBuffer: ...@@ -60,8 +64,10 @@ class DeepEPBuffer:
if deepep_mode.enable_normal(): if deepep_mode.enable_normal():
hidden_bytes = hidden_size * param_bytes hidden_bytes = hidden_size * param_bytes
for config in ( for config in (
Buffer.get_dispatch_config(group.size()), _DeepEPConfig.get_instance().normal_dispatch_config
Buffer.get_combine_config(group.size()), or Buffer.get_dispatch_config(group.size()),
_DeepEPConfig.get_instance().normal_combine_config
or Buffer.get_combine_config(group.size()),
): ):
num_nvl_bytes = max( num_nvl_bytes = max(
config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), config.get_nvl_buffer_size_hint(hidden_bytes, group.size()),
...@@ -113,6 +119,28 @@ class DeepEPBuffer: ...@@ -113,6 +119,28 @@ class DeepEPBuffer:
cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
class _DeepEPConfig:
    """Process-wide singleton holding optional user-tuned DeepEP configs.

    Reads the ``deepep_config`` server arg: either an inline JSON string or a
    path to a JSON file (both handled by ``load_json_config``). The JSON is
    expected to contain ``normal_dispatch`` and ``normal_combine`` objects of
    keyword arguments for ``deep_ep.Config``. When the arg is unset, both
    config attributes stay ``None`` and DeepEP's built-in defaults are used.
    """

    _instance = None

    def __init__(self):
        config_str = global_server_args_dict["deepep_config"]
        if config_str:
            config_parsed = load_json_config(config_str)
            # Log on rank 0 only to avoid duplicated output across ranks.
            # Guard against the non-distributed case, where get_rank() would
            # raise because the default process group is not initialized.
            if (
                not torch.distributed.is_initialized()
                or torch.distributed.get_rank() == 0
            ):
                logger.info(f"Use DeepEP Config: {config_parsed}")
            self.normal_dispatch_config = Config(**config_parsed["normal_dispatch"])
            self.normal_combine_config = Config(**config_parsed["normal_combine"])
        else:
            # No custom tuning supplied; callers fall back to
            # Buffer.get_dispatch_config / Buffer.get_combine_config.
            self.normal_dispatch_config = None
            self.normal_combine_config = None

    @classmethod
    def get_instance(cls):
        """Lazily create and return the shared config instance."""
        if cls._instance is None:
            cls._instance = _DeepEPConfig()
        return cls._instance
class _DeepEPDispatcherImplBase: class _DeepEPDispatcherImplBase:
def __init__( def __init__(
self, self,
...@@ -295,6 +323,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): ...@@ -295,6 +323,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
async_finish=self.async_finish, async_finish=self.async_finish,
allocate_on_comm_stream=(previous_event is not None) and self.async_finish, allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1, expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
config=_DeepEPConfig.get_instance().normal_dispatch_config,
) )
return ( return (
...@@ -394,6 +423,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): ...@@ -394,6 +423,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
async_finish=self.async_finish, async_finish=self.async_finish,
previous_event=previous_event, previous_event=previous_event,
allocate_on_comm_stream=previous_event is not None, allocate_on_comm_stream=previous_event is not None,
config=_DeepEPConfig.get_instance().normal_combine_config,
) )
return combined_x, event return combined_x, event
......
...@@ -77,6 +77,7 @@ global_server_args_dict = { ...@@ -77,6 +77,7 @@ global_server_args_dict = {
"enable_dp_attention": ServerArgs.enable_dp_attention, "enable_dp_attention": ServerArgs.enable_dp_attention,
"enable_dp_lm_head": ServerArgs.enable_dp_lm_head, "enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
"enable_ep_moe": ServerArgs.enable_ep_moe, "enable_ep_moe": ServerArgs.enable_ep_moe,
"deepep_config": ServerArgs.deepep_config,
"enable_nan_detection": ServerArgs.enable_nan_detection, "enable_nan_detection": ServerArgs.enable_nan_detection,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged, "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
"max_micro_batch_size": ServerArgs.max_micro_batch_size, "max_micro_batch_size": ServerArgs.max_micro_batch_size,
......
...@@ -165,6 +165,7 @@ class ModelRunner: ...@@ -165,6 +165,7 @@ class ModelRunner:
"enable_dp_attention": server_args.enable_dp_attention, "enable_dp_attention": server_args.enable_dp_attention,
"enable_ep_moe": server_args.enable_ep_moe, "enable_ep_moe": server_args.enable_ep_moe,
"enable_deepep_moe": server_args.enable_deepep_moe, "enable_deepep_moe": server_args.enable_deepep_moe,
"deepep_config": server_args.deepep_config,
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged, "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
"moe_dense_tp_size": server_args.moe_dense_tp_size, "moe_dense_tp_size": server_args.moe_dense_tp_size,
"n_share_experts_fusion": server_args.n_share_experts_fusion, "n_share_experts_fusion": server_args.n_share_experts_fusion,
......
...@@ -169,6 +169,7 @@ class ServerArgs: ...@@ -169,6 +169,7 @@ class ServerArgs:
enable_ep_moe: bool = False enable_ep_moe: bool = False
enable_deepep_moe: bool = False enable_deepep_moe: bool = False
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto" deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
deepep_config: Optional[str] = None
enable_torch_compile: bool = False enable_torch_compile: bool = False
torch_compile_max_bs: int = 32 torch_compile_max_bs: int = 32
cuda_graph_max_bs: Optional[int] = None cuda_graph_max_bs: Optional[int] = None
...@@ -1249,6 +1250,12 @@ class ServerArgs: ...@@ -1249,6 +1250,12 @@ class ServerArgs:
default="auto", default="auto",
help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.", help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
) )
parser.add_argument(
"--deepep-config",
type=str,
default=ServerArgs.deepep_config,
help="Tuned DeepEP config suitable for your own cluster.",
)
parser.add_argument( parser.add_argument(
"--n-share-experts-fusion", "--n-share-experts-fusion",
......
...@@ -2102,5 +2102,12 @@ def log_info_on_rank0(logger, msg): ...@@ -2102,5 +2102,12 @@ def log_info_on_rank0(logger, msg):
logger.info(msg) logger.info(msg)
def load_json_config(data: str):
    """Parse *data* as JSON, treating it as inline JSON first, then as a path.

    Args:
        data: Either a JSON document itself or a filesystem path to one.

    Returns:
        The deserialized JSON object.

    Raises:
        OSError: If *data* is not valid JSON and is not a readable file.
        json.JSONDecodeError: If the file's contents are not valid JSON.
    """
    try:
        return json.loads(data)
    except json.JSONDecodeError:
        # Not inline JSON — fall back to interpreting it as a file path.
        return json.loads(Path(data).read_text())
def dispose_tensor(x: torch.Tensor):
    """Release *x*'s storage in place.

    Repoints the tensor at a freshly allocated zero-element tensor with the
    same dtype and device, so the original (potentially large) storage can be
    freed while existing references to the tensor object stay valid.
    """
    placeholder = torch.empty((0,), device=x.device, dtype=x.dtype)
    x.set_(placeholder)
import json
import os
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
...@@ -64,8 +66,34 @@ class TestDPAttn(unittest.TestCase): ...@@ -64,8 +66,34 @@ class TestDPAttn(unittest.TestCase):
"2", "2",
"--enable-dp-attention", "--enable-dp-attention",
"--enable-deepep-moe", "--enable-deepep-moe",
"--deepep-mode",
"normal",
"--disable-cuda-graph", "--disable-cuda-graph",
# Test custom config
"--deepep-config",
json.dumps(
{
"normal_dispatch": {
"num_sms": 20,
"num_max_nvl_chunked_send_tokens": 16,
"num_max_nvl_chunked_recv_tokens": 256,
"num_max_rdma_chunked_send_tokens": 6,
"num_max_rdma_chunked_recv_tokens": 128,
},
"normal_combine": {
"num_sms": 20,
"num_max_nvl_chunked_send_tokens": 6,
"num_max_nvl_chunked_recv_tokens": 256,
"num_max_rdma_chunked_send_tokens": 6,
"num_max_rdma_chunked_recv_tokens": 128,
},
}
),
], ],
env={
"SGL_ENABLE_JIT_DEEPGEMM": "0",
**os.environ,
},
) )
@classmethod @classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment