Unverified commit f0653886, authored by fzyzcjy, committed by GitHub

Expert distribution recording without overhead for EPLB (#4957)

parent b1465557
@@ -390,7 +390,7 @@
     "outputs": [],
     "source": [
      "expert_record_server_process, port = launch_server_cmd(\n",
-     "    \"python3 -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0\"\n",
+     "    \"python3 -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0 --expert-distribution-recorder-mode stat\"\n",
      ")\n",
      "\n",
      "wait_for_server(f\"http://localhost:{port}\")"
@@ -415,19 +415,7 @@
      "print_highlight(response)\n",
      "\n",
      "response = requests.post(f\"http://localhost:{port}/dump_expert_distribution_record\")\n",
-     "print_highlight(response)\n",
-     "\n",
-     "import glob\n",
-     "\n",
-     "output_file = glob.glob(\"expert_distribution_*.csv\")[0]\n",
-     "with open(output_file, \"r\") as f:\n",
-     "    print_highlight(\"\\n| Layer ID | Expert ID | Count |\")\n",
-     "    print_highlight(\"|----------|-----------|--------|\")\n",
-     "    next(f)\n",
-     "    for i, line in enumerate(f):\n",
-     "        if i < 9:\n",
-     "            layer_id, expert_id, count = line.strip().split(\",\")\n",
-     "            print_highlight(f\"| {layer_id:8} | {expert_id:9} | {count:6} |\")"
+     "print_highlight(response)"
    ]
   },
   {
...
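The recording workflow this notebook exercises boils down to four HTTP calls against a server launched with `--expert-distribution-recorder-mode stat`. A minimal standalone sketch (the base URL and port are assumptions; use whatever your server prints):

```python
import requests

base = "http://localhost:30000"  # assumed port; replace with your server's port

# Begin recording expert selections.
requests.post(f"{base}/start_expert_distribution_record").raise_for_status()

# Any generation traffic between start and stop is attributed to the record.
requests.post(
    f"{base}/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 32},
    },
).raise_for_status()

# Stop, then dump the accumulated statistics to disk.
requests.post(f"{base}/stop_expert_distribution_record").raise_for_status()
requests.post(f"{base}/dump_expert_distribution_record").raise_for_status()
```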
 import logging
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+from sglang.srt.managers.expert_distribution import (
+    get_global_expert_distribution_recorder,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import DeepEPMode, load_json_config
@@ -326,6 +329,13 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             config=_DeepEPConfig.get_instance().normal_dispatch_config,
         )
+        get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
+            num_recv_tokens_per_expert_list,
+            num_tokens_per_rank=num_tokens_per_rank,
+            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+            num_tokens_per_expert=num_tokens_per_expert,
+        )
         return (
             recv_x,
             recv_topk_idx,
@@ -489,6 +499,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         ):
             hook() if self.return_recv_hook else event.current_stream_wait()
+        get_global_expert_distribution_recorder().on_deepep_dispatch_low_latency(
+            masked_m
+        )
         reorder_topk_ids = seg_indptr = None
         return (
...
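The normal-path hook above hands the recorder `num_recv_tokens_per_expert_list`, a count DeepEP already produces as part of dispatch, which is why recording adds no extra GPU work. The recorder's internals are not part of this diff; a rough sketch of what a `stat`-style gatherer could do with that list (all names here are illustrative):

```python
import torch

class StatGathererSketch:
    """Illustrative accumulator: per-(layer, expert) token counts on CPU."""

    def __init__(self, num_layers: int, num_physical_experts: int):
        self._count = torch.zeros(num_layers, num_physical_experts, dtype=torch.long)

    def on_deepep_dispatch_normal(
        self, layer_id: int, num_recv_tokens_per_expert_list: list
    ):
        # The counts are a byproduct of dispatch, so accumulating them is free.
        self._count[layer_id] += torch.tensor(
            num_recv_tokens_per_expert_list, dtype=torch.long
        )
```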
@@ -18,7 +18,10 @@ from typing import Callable, Optional
 import torch
 import torch.nn.functional as F
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
@@ -31,8 +34,6 @@ if _is_cuda:
 if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax
-expert_distribution_recorder = ExpertDistributionRecorder()
 def fused_topk_native(
     hidden_states: torch.Tensor,
@@ -353,6 +354,6 @@ def select_experts(
         renormalize=renormalize,
     )
-    expert_distribution_recorder.record_new_token(topk_ids)
+    get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)
     return topk_weights, topk_ids
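`select_experts` now reports `topk_ids` through the process-global recorder instead of a module-level singleton. Deriving per-expert hit counts from a `topk_ids` tensor is a single histogram; a hedged sketch (not the recorder's actual implementation):

```python
import torch

def count_expert_hits(topk_ids: torch.Tensor, num_experts: int) -> torch.Tensor:
    # topk_ids: (num_tokens, top_k) expert indices; flatten and histogram.
    return torch.bincount(topk_ids.flatten().to(torch.long), minlength=num_experts)
```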
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import torch
import torch.distributed
import torch.nn.functional as F
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader import get_model_architecture
from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__)
@dataclass
class ExpertLocationMetadata:
physical_to_logical_map: torch.Tensor # (layers, num_physical_experts)
logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X)
logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts)
# -------------------------------- properties ------------------------------------
@property
def num_layers(self) -> int:
return self.physical_to_logical_map.shape[0]
@property
def num_physical_experts(self) -> int:
return self.physical_to_logical_map.shape[1]
@property
def num_local_physical_experts(self) -> int:
ans, remainder = divmod(self.num_physical_experts, self.ep_size)
assert remainder == 0
return ans
@property
def num_logical_experts(self) -> int:
return self.logical_to_all_physical_map.shape[1]
@property
def ep_size(self):
# TODO change when EP size != world size
return torch.distributed.get_world_size()
def __post_init__(self):
num_layers_0, num_physical_experts_0 = self.physical_to_logical_map.shape
num_layers_1, num_logical_experts_0, num_physical_experts_1 = (
self.logical_to_all_physical_map.shape
)
num_layers_2, num_logical_experts_1 = (
self.logical_to_all_physical_map_num_valid.shape
)
# TODO pr-chain: enable this later
# assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3
# assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2
assert num_physical_experts_0 == num_physical_experts_1
# -------------------------------- construction ------------------------------------
@staticmethod
def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
"""Trivial location - logical expert i corresponds to physical expert i"""
common = ExpertLocationMetadata._init_common(server_args, model_config)
num_physical_experts = common["num_physical_experts"]
model_config_for_expert_location = common["model_config_for_expert_location"]
num_layers = model_config_for_expert_location.num_layers
num_logical_experts = model_config_for_expert_location.num_logical_experts
physical_to_logical_map = (
torch.arange(0, num_physical_experts).repeat(num_layers, 1)
% num_logical_experts
)
return ExpertLocationMetadata.init_by_mapping(
server_args,
model_config,
physical_to_logical_map=physical_to_logical_map,
)
@staticmethod
def init_by_mapping(
server_args: ServerArgs,
model_config: ModelConfig,
physical_to_logical_map,
):
if not isinstance(physical_to_logical_map, torch.Tensor):
physical_to_logical_map = torch.tensor(physical_to_logical_map)
physical_to_logical_map = physical_to_logical_map.to(server_args.device)
common = ExpertLocationMetadata._init_common(server_args, model_config)
model_config_for_expert_location = common["model_config_for_expert_location"]
logical_to_all_physical_map = _compute_logical_to_all_physical_map(
physical_to_logical_map,
num_logical_experts=model_config_for_expert_location.num_logical_experts,
)
return ExpertLocationMetadata._init_raw(
ep_size=common["ep_size"],
physical_to_logical_map=physical_to_logical_map,
logical_to_all_physical_map=logical_to_all_physical_map,
)
@staticmethod
def _init_common(server_args: ServerArgs, model_config: ModelConfig):
model_config_for_expert_location = (
ModelConfigForExpertLocation.from_model_config(model_config)
)
num_physical_experts = (
model_config_for_expert_location.num_logical_experts
# TODO pr-chain: enable this later
# + server_args.ep_num_redundant_experts
)
ep_size = server_args.ep_size
assert num_physical_experts % ep_size == 0
num_local_physical_experts = num_physical_experts // ep_size
return dict(
model_config_for_expert_location=model_config_for_expert_location,
num_physical_experts=num_physical_experts,
num_local_physical_experts=num_local_physical_experts,
ep_size=ep_size,
)
@staticmethod
def _init_raw(
ep_size: int,
physical_to_logical_map: torch.Tensor,
logical_to_all_physical_map: torch.Tensor,
):
_, num_physical_experts = physical_to_logical_map.shape
logical_to_all_physical_map_padded = F.pad(
logical_to_all_physical_map,
(0, num_physical_experts - logical_to_all_physical_map.shape[-1]),
value=-1,
)
logical_to_all_physical_map_num_valid = torch.count_nonzero(
logical_to_all_physical_map != -1, dim=-1
)
return ExpertLocationMetadata(
physical_to_logical_map=physical_to_logical_map,
logical_to_all_physical_map=logical_to_all_physical_map_padded,
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
)
_global_expert_location_metadata: Optional[ExpertLocationMetadata] = None
def get_global_expert_location_metadata():
return _global_expert_location_metadata
def set_global_expert_location_metadata(value):
global _global_expert_location_metadata
assert _global_expert_location_metadata is None
_global_expert_location_metadata = value
def _compute_logical_to_all_physical_map(
physical_to_logical_map: torch.Tensor, num_logical_experts: int
):
# This is rarely called, so we use for loops for maximum clarity
num_layers, num_physical_experts = physical_to_logical_map.shape
logical_to_all_physical_map = [
[[] for _ in range(num_logical_experts)] for _ in range(num_layers)
]
for layer_id in range(num_layers):
for physical_expert_id in range(num_physical_experts):
logical_expert_id = physical_to_logical_map[
layer_id, physical_expert_id
].item()
logical_to_all_physical_map[layer_id][logical_expert_id].append(
physical_expert_id
)
logical_to_all_physical_map = _pad_nested_array(
logical_to_all_physical_map, pad_value=-1
)
return torch.tensor(
logical_to_all_physical_map, device=physical_to_logical_map.device
)
def _pad_nested_array(arr, pad_value):
max_len = max(len(inner) for outer in arr for inner in outer)
padded = [
[inner + [pad_value] * (max_len - len(inner)) for inner in outer]
for outer in arr
]
return padded
@dataclass
class ModelConfigForExpertLocation:
num_layers: int
num_logical_experts: int
num_groups: Optional[int] = None
@staticmethod
def init_dummy():
return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)
@staticmethod
def from_model_config(model_config: ModelConfig):
model_class, _ = get_model_architecture(model_config)
if hasattr(model_class, "get_model_config_for_expert_location"):
return model_class.get_model_config_for_expert_location(
model_config.hf_config
)
else:
return ModelConfigForExpertLocation.init_dummy()
def compute_initial_expert_location_metadata(
server_args: ServerArgs, model_config: ModelConfig
) -> ExpertLocationMetadata:
data = server_args.init_expert_location
if data == "trivial":
logger.info("init_expert_location from trivial")
return ExpertLocationMetadata.init_trivial(server_args, model_config)
# TODO unify with the utils function
if data.endswith(".pt"):
data_dict = torch.load(data, weights_only=True)
elif data.endswith(".json"):
data_dict = json.loads(Path(data).read_text())
else:
data_dict = json.loads(data)
if "physical_to_logical_map" in data_dict:
logger.info(
"init_expert_location from init_by_mapping using ServerArgs.init_expert_location"
)
return ExpertLocationMetadata.init_by_mapping(
server_args, model_config, **data_dict
)
elif "logical_count" in data_dict:
# TODO pr-chain: enable this later
raise NotImplementedError
# logger.info(
# "init_expert_location from init_by_eplb using ServerArgs.init_expert_location"
# )
# return ExpertLocationMetadata.init_by_eplb(
# server_args, model_config, logical_count=data_dict["logical_count"]
# )
else:
raise NotImplementedError(
f"Unknown init_expert_location format ({list(data_dict.keys())=})"
)
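To make the mapping above concrete, here is a toy run of `_compute_logical_to_all_physical_map` from this file, with one layer, four physical slots, and three logical experts (logical expert 0 replicated on two slots):

```python
import torch

# Physical slot -> logical expert, shape (layers, num_physical_experts).
physical_to_logical_map = torch.tensor([[0, 1, 2, 0]])

m = _compute_logical_to_all_physical_map(physical_to_logical_map, num_logical_experts=3)
# m[0].tolist() == [[0, 3], [1, -1], [2, -1]]:
# logical expert 0 lives on physical slots 0 and 3; shorter replica lists
# are padded with -1, which is what logical_to_all_physical_map_num_valid counts.
```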
@@ -59,7 +59,10 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
@@ -142,8 +145,6 @@ from sglang.srt.utils import (
 )
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback
-expert_distribution_recorder = ExpertDistributionRecorder()
 logger = logging.getLogger(__name__)
 # Test retract decode for debugging purposes
@@ -2162,11 +2163,11 @@ class Scheduler(
     def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
         if recv_req == ExpertDistributionReq.START_RECORD:
-            expert_distribution_recorder.start_record()
+            get_global_expert_distribution_recorder().start_record()
         elif recv_req == ExpertDistributionReq.STOP_RECORD:
-            expert_distribution_recorder.stop_record()
+            get_global_expert_distribution_recorder().stop_record()
         elif recv_req == ExpertDistributionReq.DUMP_RECORD:
-            expert_distribution_recorder.dump_record()
+            get_global_expert_distribution_recorder().dump_record()
         else:
             raise ValueError("Unrecognized ExpertDistributionReq value")
         return ExpertDistributionReqOutput()
...
@@ -52,6 +52,16 @@ from sglang.srt.layers.quantization.deep_gemm import (
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+    set_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import (
+    compute_initial_expert_location_metadata,
+    get_global_expert_location_metadata,
+    set_global_expert_location_metadata,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     DoubleSparseTokenToKVPool,
@@ -161,6 +171,8 @@ class ModelRunner:
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
+        self.forward_pass_id = 0
         # Model-specific adjustment
         self.model_specific_adjustment()
@@ -219,6 +231,25 @@ class ModelRunner:
             enable=self.server_args.enable_memory_saver
         )
+        if not self.is_draft_worker:
+            set_global_expert_location_metadata(
+                compute_initial_expert_location_metadata(server_args, self.model_config)
+            )
+            if self.tp_rank == 0 and get_bool_env_var(
+                "SGLANG_LOG_EXPERT_LOCATION_METADATA"
+            ):
+                logger.info(
+                    f"Initial expert_location_metadata: {get_global_expert_location_metadata().debug_str()}"
+                )
+
+        set_global_expert_distribution_recorder(
+            ExpertDistributionRecorder.init_new(
+                server_args,
+                get_global_expert_location_metadata(),
+                rank=self.tp_rank,
+            )
+        )
         # Load the model
         self.sampler = Sampler()
         self.load_model()
@@ -1093,6 +1124,22 @@ class ModelRunner:
         forward_batch: ForwardBatch,
         skip_attn_backend_init: bool = False,
         pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
+        self.forward_pass_id += 1
+        with get_global_expert_distribution_recorder().with_forward_pass(
+            self.forward_pass_id,
+            forward_batch,
+        ):
+            return self._forward_raw(
+                forward_batch, skip_attn_backend_init, pp_proxy_tensors
+            )
+
+    def _forward_raw(
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool,
+        pp_proxy_tensors: Optional[PPProxyTensors],
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
...
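Wrapping the forward pass in a context manager is what keeps the disabled path overhead-free: when no recorder mode is configured, the manager can be a no-op. A minimal sketch of that pattern (illustrative only; the real `ExpertDistributionRecorder` is not shown in this diff):

```python
from contextlib import contextmanager, nullcontext

class RecorderSketch:
    def __init__(self, enabled: bool):
        self._enabled = enabled

    def with_forward_pass(self, forward_pass_id, forward_batch):
        if not self._enabled:
            return nullcontext()  # disabled: no bookkeeping at all
        return self._record_pass(forward_pass_id, forward_batch)

    @contextmanager
    def _record_pass(self, forward_pass_id, forward_batch):
        # per-pass setup (e.g., reset per-pass buffers) would go here
        try:
            yield
        finally:
            # per-pass teardown (e.g., flush into the circular buffer)
            pass
```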
@@ -77,7 +77,11 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -109,8 +113,6 @@ if _is_hip:
     decode_attention_fwd_grouped_rope,
 )
-expert_distribution_recorder = ExpertDistributionRecorder()
 logger = logging.getLogger(__name__)
@@ -302,6 +304,7 @@ class DeepseekV2MoE(nn.Module):
     def forward(
         self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:
+        forward_mode = forward_batch.forward_mode
         if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
             forward_mode, hidden_states
         ):
@@ -1278,7 +1281,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
         # Fully Connected
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, forward_batch)
         # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
         # Scatter
@@ -1422,11 +1425,11 @@ class DeepseekV2Model(nn.Module):
         residual = None
         for i in range(len(self.layers)):
-            expert_distribution_recorder.set_current_layer(i)
-            layer = self.layers[i]
-            hidden_states, residual = layer(
-                positions, hidden_states, forward_batch, residual, zero_allocator
-            )
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.layers[i]
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual, zero_allocator
+                )
         if not forward_batch.forward_mode.is_idle():
             if residual is None:
                 hidden_states = self.norm(hidden_states)
@@ -1872,6 +1875,14 @@ class DeepseekV2ForCausalLM(nn.Module):
         torch.cuda.empty_cache()
         torch.cuda.synchronize()
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.n_routed_experts,
+            num_groups=config.n_group,
+        )
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass
...
@@ -59,14 +59,16 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, make_layers
-expert_distribution_recorder = ExpertDistributionRecorder()
 logger = logging.getLogger(__name__)
@@ -591,11 +593,11 @@ class Qwen2MoeModel(nn.Module):
             residual = pp_proxy_tensors["residual"]
         for i in range(self.start_layer, self.end_layer):
-            expert_distribution_recorder.set_current_layer(i)
-            layer = self.layers[i]
-            hidden_states, residual = layer(
-                positions, hidden_states, forward_batch, residual
-            )
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                layer = self.layers[i]
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual
+                )
         if not self.pp_group.is_last_rank:
             return PPProxyTensors(
                 {
@@ -752,5 +754,13 @@ class Qwen2MoeForCausalLM(nn.Module):
         else:
             logger.warning(f"Parameter {name} not found in params_dict")
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.num_experts,
+            num_groups=None,
+        )
 EntryClass = Qwen2MoeForCausalLM
@@ -170,6 +170,11 @@ class ServerArgs:
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
+    init_expert_location: str = "trivial"
+    expert_distribution_recorder_mode: Optional[
+        Literal["stat", "per_pass", "per_token"]
+    ] = None
+    expert_distribution_recorder_buffer_size: Optional[int] = None
     deepep_config: Optional[str] = None
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
@@ -361,6 +366,15 @@ class ServerArgs:
                 "Pipeline parallelism is incompatible with overlap schedule."
             )
+        if self.expert_distribution_recorder_buffer_size is None:
+            # TODO pr-chain: enable this later
+            # if (x := self.eplb_rebalance_num_iterations) is not None:
+            #     self.expert_distribution_recorder_buffer_size = x
+            if False:
+                pass
+            elif self.expert_distribution_recorder_mode is not None:
+                self.expert_distribution_recorder_buffer_size = 1000
         # Speculative Decoding
         if self.speculative_algorithm == "NEXTN":
             # NEXTN shares the same implementation of EAGLE
@@ -1257,6 +1271,24 @@ class ServerArgs:
             default="auto",
             help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
         )
+        parser.add_argument(
+            "--init-expert-location",
+            type=str,
+            default=ServerArgs.init_expert_location,
+            help="Initial location of EP experts.",
+        )
+        parser.add_argument(
+            "--expert-distribution-recorder-mode",
+            type=str,
+            default=ServerArgs.expert_distribution_recorder_mode,
+            help="Mode of expert distribution recorder.",
+        )
+        parser.add_argument(
+            "--expert-distribution-recorder-buffer-size",
+            type=int,
+            default=ServerArgs.expert_distribution_recorder_buffer_size,
+            help="Circular buffer size of expert distribution recorder. Set to -1 to denote infinite buffer.",
+        )
         parser.add_argument(
             "--deepep-config",
             type=str,
...
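The three new flags compose with the usual launch command. A hedged example reusing the notebook's `launch_server_cmd` helper (model and flag values are arbitrary choices):

```python
server_process, port = launch_server_cmd(
    "python3 -m sglang.launch_server"
    " --model-path Qwen/Qwen1.5-MoE-A2.7B"
    " --expert-distribution-recorder-mode stat"
    " --expert-distribution-recorder-buffer-size -1"  # -1 means an infinite buffer
    " --init-expert-location trivial"  # default: logical expert i on physical slot i
)
```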
@@ -46,7 +46,19 @@ from importlib.util import find_spec
 from io import BytesIO
 from multiprocessing.reduction import ForkingPickler
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    Protocol,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+)
 import numpy as np
 import psutil
@@ -2126,3 +2138,25 @@ def load_json_config(data: str):
 def dispose_tensor(x: torch.Tensor):
     x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))
+T = TypeVar("T")
+
+class Withable(Generic[T]):
+    def __init__(self):
+        self._value: Optional[T] = None
+
+    @property
+    def value(self) -> T:
+        return self._value
+
+    @contextmanager
+    def with_value(self, new_value: T):
+        assert self._value is None
+        self._value = new_value
+        try:
+            yield
+        finally:
+            assert self._value is new_value
+            self._value = None
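`Withable` is the small holder backing `with_current_layer` and `with_forward_pass` above: it enforces balanced set/clear via a context manager. Usage:

```python
current_layer = Withable()  # Withable[int]

with current_layer.with_value(3):
    print(current_layer.value)  # 3

assert current_layer.value is None  # cleared on exit

# Re-entering while a value is still held trips the assert in with_value,
# surfacing mismatched enter/exit immediately.
```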
-import csv
-import glob
 import os
+import tempfile
 import unittest
+from pathlib import Path
 import requests
+import torch
 from sglang.srt.utils import kill_process_tree
 from sglang.test.test_utils import (
@@ -16,108 +17,86 @@ from sglang.test.test_utils import (
 class TestExpertDistribution(CustomTestCase):
-    def setUp(self):
-        # Clean up any existing expert distribution files before each test
-        for f in glob.glob("expert_distribution_*.csv"):
-            os.remove(f)
-
-    def tearDown(self):
-        # Clean up any expert distribution files after each test
-        for f in glob.glob("expert_distribution_*.csv"):
-            os.remove(f)
-
     def test_expert_distribution_record(self):
+        # TODO: Add tests for DeepEP gatherer (currently our CI cannot run that)
+        for info in [
+            dict(model_path="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"),
+            dict(model_path="Qwen/Qwen1.5-MoE-A2.7B"),
+            dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", tp_size=2),
+            # TODO enable in next PR
+            # dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_pass"),
+            # dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_token"),
+        ]:
+            with self.subTest(info=info):
+                self._execute_core(**info)
+
+    def _execute_core(self, model_path: str, mode: str = "stat", tp_size: int = 1):
         """Test expert distribution record endpoints"""
-        process = popen_launch_server(
-            # The feature is only implemented in deepseek_v2.py
-            "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
-            DEFAULT_URL_FOR_TEST,
-            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=[
-                "--trust-remote-code",
-            ],
-        )
-
-        try:
-            # Start recording
-            response = requests.post(
-                f"{DEFAULT_URL_FOR_TEST}/start_expert_distribution_record"
-            )
-            self.assertEqual(response.status_code, 200)
-
-            # Make some requests to generate expert distribution data
-            response = requests.post(
-                f"{DEFAULT_URL_FOR_TEST}/generate",
-                json={
-                    "text": "The capital of France is",
-                    "sampling_params": {
-                        "temperature": 0,
-                        "max_new_tokens": 32,
-                    },
-                },
-            )
-            self.assertEqual(response.status_code, 200)
-
-            # Stop recording
-            response = requests.post(
-                f"{DEFAULT_URL_FOR_TEST}/stop_expert_distribution_record"
-            )
-            self.assertEqual(response.status_code, 200)
-
-            # Dump the recorded data
-            response = requests.post(
-                f"{DEFAULT_URL_FOR_TEST}/dump_expert_distribution_record"
-            )
-            self.assertEqual(response.status_code, 200)
-
-            # Verify the dumped file exists and has correct format
-            csv_files = glob.glob("expert_distribution_*.csv")
-            self.assertEqual(
-                len(csv_files),
-                1,
-                f"Expected exactly one expert distribution CSV file {csv_files=}",
-            )
-
-            # Check CSV file format
-            with open(csv_files[0], "r") as f:
-                csv_reader = csv.reader(f)
-
-                # Check header
-                header = next(csv_reader)
-                self.assertEqual(
-                    header,
-                    ["layer_id", "expert_id", "count"],
-                    "CSV header should be 'layer_id,expert_id,count'",
-                )
-
-                # Check data rows
-                rows = list(csv_reader)
-                self.assertGreater(len(rows), 0, "CSV file should contain data rows")
-
-                for row in rows:
-                    # Verify each row has 3 columns
-                    self.assertEqual(
-                        len(row),
-                        3,
-                        "Each row should have layer_id, expert_id and count",
-                    )
-
-                    # Verify data types
-                    layer_id, expert_id, count = row
-                    self.assertTrue(
-                        layer_id.isdigit(),
-                        f"layer_id should be an integer {row=} {rows=}",
-                    )
-                    self.assertTrue(
-                        expert_id.isdigit(),
-                        f"expert_id should be an integer {row=} {rows=}",
-                    )
-                    self.assertTrue(
-                        count.isdigit(), f"count should be an integer {row=} {rows=}"
-                    )
-        finally:
-            kill_process_tree(process.pid)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"] = tmp_dir
+
+            process = popen_launch_server(
+                model_path,
+                DEFAULT_URL_FOR_TEST,
+                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+                other_args=[
+                    "--trust-remote-code",
+                    "--tp-size",
+                    str(tp_size),
+                    "--expert-distribution-recorder-mode",
+                    mode,
+                    "--disable-cuda-graph",
+                    "--disable-overlap-schedule",
+                ],
+            )
+
+            try:
+                # Start recording
+                response = requests.post(
+                    f"{DEFAULT_URL_FOR_TEST}/start_expert_distribution_record"
+                )
+                self.assertEqual(response.status_code, 200)
+
+                # Make some requests to generate expert distribution data
+                response = requests.post(
+                    f"{DEFAULT_URL_FOR_TEST}/generate",
+                    json={
+                        "text": "The capital of France is",
+                        "sampling_params": {
+                            "temperature": 0,
+                            "max_new_tokens": 32,
+                        },
+                    },
+                )
+                self.assertEqual(response.status_code, 200)
+
+                # Stop recording
+                response = requests.post(
+                    f"{DEFAULT_URL_FOR_TEST}/stop_expert_distribution_record"
+                )
+                self.assertEqual(response.status_code, 200)
+
+                # Dump the recorded data
+                response = requests.post(
+                    f"{DEFAULT_URL_FOR_TEST}/dump_expert_distribution_record"
+                )
+                self.assertEqual(response.status_code, 200)
+
+                # Check data rows
+                data = torch.load(
+                    list(Path(tmp_dir).glob("*.pt"))[0], weights_only=True
+                )
+                print(f"{data=}")
+
+                if mode in ["per_pass", "per_token"]:
+                    self.assertGreater(len(data), 0, "Should contain data rows")
+                else:
+                    logical_count = data["logical_count"]
+                    print(f"{logical_count.sum()=} {logical_count=}")
+                    self.assertTrue(logical_count.sum() > 0)
+            finally:
+                kill_process_tree(process.pid)
 if __name__ == "__main__":
...