"vscode:/vscode.git/clone" did not exist on "4d1e52abea8277e69de281cb23634edb723fcd85"
Unverified commit fd08c048, authored by fzyzcjy, committed by GitHub

Support custom DeepEP tuning config (#6257)

parent 26ebb849
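In short: this commit adds a --deepep-config server argument that lets users supply their own DeepEP normal-mode dispatch/combine tuning parameters, either inline as a JSON string or as a path to a JSON file, instead of always using DeepEP's defaults. A minimal sketch of preparing such a file (file name hypothetical; the sections, keys, and values are copied from the test added in this commit, not recommendations):

    # Sketch only: writing a tuned config file that --deepep-config can point to.
    import json

    deepep_config = {
        "normal_dispatch": {
            "num_sms": 20,
            "num_max_nvl_chunked_send_tokens": 16,
            "num_max_nvl_chunked_recv_tokens": 256,
            "num_max_rdma_chunked_send_tokens": 6,
            "num_max_rdma_chunked_recv_tokens": 128,
        },
        "normal_combine": {
            "num_sms": 20,
            "num_max_nvl_chunked_send_tokens": 6,
            "num_max_nvl_chunked_recv_tokens": 256,
            "num_max_rdma_chunked_send_tokens": 6,
            "num_max_rdma_chunked_recv_tokens": 128,
        },
    }
    with open("deepep_config.json", "w") as f:
        json.dump(deepep_config, f)
    # Launch with: --enable-deepep-moe --deepep-mode normal --deepep-config deepep_config.json
    # (the flag also accepts the JSON inline, as the test below does).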
 import logging
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
-from sglang.srt.utils import DeepEPMode
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.utils import DeepEPMode, load_json_config

 try:
-    from deep_ep import Buffer
+    from deep_ep import Buffer, Config

     from sglang.srt.layers.quantization.fp8_kernel import (
         sglang_per_token_group_quant_fp8,
@@ -25,6 +28,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardMode

+logger = logging.getLogger(__name__)
+
 class DeepEPDispatchMode(IntEnum):
     NORMAL = auto()
@@ -32,7 +37,6 @@ class DeepEPDispatchMode(IntEnum):

 class DeepEPBuffer:
-
     _buffer = None
     _dispatch_mode: Optional[DeepEPDispatchMode] = None
     _hidden_size: Optional[int] = None
@@ -60,8 +64,10 @@ class DeepEPBuffer:
         if deepep_mode.enable_normal():
             hidden_bytes = hidden_size * param_bytes
             for config in (
-                Buffer.get_dispatch_config(group.size()),
-                Buffer.get_combine_config(group.size()),
+                _DeepEPConfig.get_instance().normal_dispatch_config
+                or Buffer.get_dispatch_config(group.size()),
+                _DeepEPConfig.get_instance().normal_combine_config
+                or Buffer.get_combine_config(group.size()),
             ):
                 num_nvl_bytes = max(
                     config.get_nvl_buffer_size_hint(hidden_bytes, group.size()),
@@ -113,6 +119,28 @@ class DeepEPBuffer:
         cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY


+class _DeepEPConfig:
+    _instance = None
+
+    def __init__(self):
+        config_str = global_server_args_dict["deepep_config"]
+        if config_str:
+            config_parsed = load_json_config(config_str)
+            if torch.distributed.get_rank() == 0:
+                logger.info(f"Use DeepEP Config: {config_parsed}")
+            self.normal_dispatch_config = Config(**config_parsed["normal_dispatch"])
+            self.normal_combine_config = Config(**config_parsed["normal_combine"])
+        else:
+            self.normal_dispatch_config = None
+            self.normal_combine_config = None
+
+    @classmethod
+    def get_instance(cls):
+        if cls._instance is None:
+            cls._instance = _DeepEPConfig()
+        return cls._instance
+
+
 class _DeepEPDispatcherImplBase:
     def __init__(
         self,
@@ -295,6 +323,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             async_finish=self.async_finish,
             allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
             expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
+            config=_DeepEPConfig.get_instance().normal_dispatch_config,
         )
         return (
@@ -394,6 +423,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             async_finish=self.async_finish,
             previous_event=previous_event,
             allocate_on_comm_stream=previous_event is not None,
+            config=_DeepEPConfig.get_instance().normal_combine_config,
         )
         return combined_x, event
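For context, the _DeepEPConfig singleton above parses the JSON once (logging it on rank 0 only, so multi-GPU runs do not repeat the message) and leaves both configs as None when no config is given, in which case the `config or Buffer.get_*_config(...)` pattern falls back to DeepEP's defaults. A minimal sketch of that mapping, assuming deep_ep is installed; the keyword names are exactly the JSON keys this commit's test passes through Config(**...):

    from deep_ep import Buffer, Config

    # Tuned values copied from the test below; Config's keyword names must
    # match the JSON keys, since the parser calls Config(**parsed_section).
    tuned = Config(
        num_sms=20,
        num_max_nvl_chunked_send_tokens=16,
        num_max_nvl_chunked_recv_tokens=256,
        num_max_rdma_chunked_send_tokens=6,
        num_max_rdma_chunked_recv_tokens=128,
    )

    group_size = 8  # example world size
    # A user-supplied Config wins; None falls through to DeepEP's default.
    dispatch_config = tuned or Buffer.get_dispatch_config(group_size)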
......
@@ -77,6 +77,7 @@ global_server_args_dict = {
     "enable_dp_attention": ServerArgs.enable_dp_attention,
     "enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
+    "deepep_config": ServerArgs.deepep_config,
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
     "max_micro_batch_size": ServerArgs.max_micro_batch_size,
......
@@ -165,6 +165,7 @@ class ModelRunner:
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "enable_deepep_moe": server_args.enable_deepep_moe,
+                "deepep_config": server_args.deepep_config,
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                 "moe_dense_tp_size": server_args.moe_dense_tp_size,
                 "n_share_experts_fusion": server_args.n_share_experts_fusion,
......
@@ -169,6 +169,7 @@ class ServerArgs:
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
+    deepep_config: Optional[str] = None
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: Optional[int] = None
@@ -1249,6 +1250,12 @@ class ServerArgs:
             default="auto",
             help="Select the mode when enabling DeepEP MoE; can be `normal`, `low_latency`, or `auto`. Default is `auto`, which means `low_latency` for decode batches and `normal` for prefill batches.",
         )
+        parser.add_argument(
+            "--deepep-config",
+            type=str,
+            default=ServerArgs.deepep_config,
+            help="Tuned DeepEP config, as a JSON string or a path to a JSON file, suitable for your own cluster.",
+        )
         parser.add_argument(
             "--n-share-experts-fusion",
......
@@ -2102,5 +2102,12 @@ def log_info_on_rank0(logger, msg):
         logger.info(msg)


+def load_json_config(data: str):
+    try:
+        return json.loads(data)
+    except JSONDecodeError:
+        return json.loads(Path(data).read_text())
+
+
 def dispose_tensor(x: torch.Tensor):
     x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))
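The new helper first tries to parse its argument as inline JSON and, failing that, treats it as a file path (the hunk assumes json, JSONDecodeError, and Path are already imported in utils.py; those imports sit outside this excerpt). A standalone usage sketch, with the helper copied in so it runs by itself:

    import json
    from json import JSONDecodeError
    from pathlib import Path

    def load_json_config(data: str):
        # Copy of the helper above so this demo runs standalone.
        try:
            return json.loads(data)
        except JSONDecodeError:
            return json.loads(Path(data).read_text())

    print(load_json_config('{"normal_dispatch": {"num_sms": 20}}'))  # inline JSON string
    # load_json_config("deepep_config.json")  # non-JSON input is treated as a file path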
+import json
+import os
 import unittest
 from types import SimpleNamespace
@@ -64,8 +66,34 @@ class TestDPAttn(unittest.TestCase):
                 "2",
                 "--enable-dp-attention",
                 "--enable-deepep-moe",
+                "--deepep-mode",
+                "normal",
                 "--disable-cuda-graph",
+                # Test custom config
+                "--deepep-config",
+                json.dumps(
+                    {
+                        "normal_dispatch": {
+                            "num_sms": 20,
+                            "num_max_nvl_chunked_send_tokens": 16,
+                            "num_max_nvl_chunked_recv_tokens": 256,
+                            "num_max_rdma_chunked_send_tokens": 6,
+                            "num_max_rdma_chunked_recv_tokens": 128,
+                        },
+                        "normal_combine": {
+                            "num_sms": 20,
+                            "num_max_nvl_chunked_send_tokens": 6,
+                            "num_max_nvl_chunked_recv_tokens": 256,
+                            "num_max_rdma_chunked_send_tokens": 6,
+                            "num_max_rdma_chunked_recv_tokens": 128,
+                        },
+                    }
+                ),
             ],
+            env={
+                "SGL_ENABLE_JIT_DEEPGEMM": "0",
+                **os.environ,
+            },
         )

     @classmethod
......
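The test drives the feature end to end: it forces --deepep-mode normal (only the normal-mode dispatch and combine calls receive the custom Config) and pins SGL_ENABLE_JIT_DEEPGEMM=0, which per the dispatch hunk above keeps expert_alignment at 1. When adapting the config to your own cluster, a hypothetical fail-fast check like the one below catches typos before launch; Config(**...) would also reject bad keys, but only once the server is already starting up.

    # Hypothetical helper, not part of the commit.
    EXPECTED_KEYS = {
        "num_sms",
        "num_max_nvl_chunked_send_tokens",
        "num_max_nvl_chunked_recv_tokens",
        "num_max_rdma_chunked_send_tokens",
        "num_max_rdma_chunked_recv_tokens",
    }

    def check_deepep_config(cfg: dict) -> None:
        for section in ("normal_dispatch", "normal_combine"):
            keys = set(cfg.get(section, {}))
            assert keys == EXPECTED_KEYS, f"{section}: unexpected keys {keys ^ EXPECTED_KEYS}"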