Unverified Commit f0653886 authored by fzyzcjy, committed by GitHub

Expert distribution recording without overhead for EPLB (#4957)

parent b1465557
...@@ -390,7 +390,7 @@
"outputs": [],
"source": [
"expert_record_server_process, port = launch_server_cmd(\n",
" \"python3 -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0\"\n",
" \"python3 -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0 --expert-distribution-recorder-mode stat\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
...@@ -415,19 +415,7 @@
"print_highlight(response)\n",
"\n",
"response = requests.post(f\"http://localhost:{port}/dump_expert_distribution_record\")\n",
"print_highlight(response)\n",
"\n",
"import glob\n",
"\n",
"output_file = glob.glob(\"expert_distribution_*.csv\")[0]\n",
"with open(output_file, \"r\") as f:\n",
" print_highlight(\"\\n| Layer ID | Expert ID | Count |\")\n",
" print_highlight(\"|----------|-----------|--------|\")\n",
" next(f)\n",
" for i, line in enumerate(f):\n",
" if i < 9:\n",
" layer_id, expert_id, count = line.strip().split(\",\")\n",
" print_highlight(f\"| {layer_id:8} | {expert_id:9} | {count:6} |\")"
"print_highlight(response)"
]
},
{
......
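Since the dump endpoint now produces a torch-saved `.pt` payload instead of the old CSV, a replacement notebook cell for inspecting the record might look like the sketch below. This is a hedged example: the `expert_distribution_recorder_*.pt` filename pattern, the `SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR` environment variable (default `/tmp`), and the `logical_count` layout of shape (buffered steps, layers, logical experts) are taken from the recorder code added in this commit; the choice of layers and top-5 experts is arbitrary.

import glob
import os

import torch

save_dir = os.environ.get("SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR", "/tmp")
output_file = sorted(glob.glob(f"{save_dir}/expert_distribution_recorder_*.pt"))[-1]
data = torch.load(output_file, weights_only=True)

logical_count = data["logical_count"]  # (buffered_steps, num_layers, num_logical_experts)
per_layer = logical_count.sum(dim=0)  # aggregate over the buffered forward passes
for layer_id in range(min(3, per_layer.shape[0])):
    counts, experts = per_layer[layer_id].topk(5)
    print(f"layer {layer_id}: top experts {experts.tolist()} with counts {counts.tolist()}")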
import logging
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.managers.expert_distribution import (
    get_global_expert_distribution_recorder,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import DeepEPMode, load_json_config
...@@ -326,6 +329,13 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): ...@@ -326,6 +329,13 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
config=_DeepEPConfig.get_instance().normal_dispatch_config, config=_DeepEPConfig.get_instance().normal_dispatch_config,
) )
get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
num_recv_tokens_per_expert_list,
num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
num_tokens_per_expert=num_tokens_per_expert,
)
return ( return (
recv_x, recv_x,
recv_topk_idx, recv_topk_idx,
...@@ -489,6 +499,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): ...@@ -489,6 +499,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
): ):
hook() if self.return_recv_hook else event.current_stream_wait() hook() if self.return_recv_hook else event.current_stream_wait()
get_global_expert_distribution_recorder().on_deepep_dispatch_low_latency(
masked_m
)
reorder_topk_ids = seg_indptr = None reorder_topk_ids = seg_indptr = None
return ( return (
......
...@@ -18,7 +18,10 @@ from typing import Callable, Optional
import torch
import torch.nn.functional as F
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.managers.expert_distribution import (
    ExpertDistributionRecorder,
    get_global_expert_distribution_recorder,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
...@@ -31,8 +34,6 @@ if _is_cuda:
if _is_cuda or _is_hip:
    from sgl_kernel import topk_softmax

expert_distribution_recorder = ExpertDistributionRecorder()

def fused_topk_native(
    hidden_states: torch.Tensor,
...@@ -353,6 +354,6 @@ def select_experts(
        renormalize=renormalize,
    )
    expert_distribution_recorder.record_new_token(topk_ids)
    get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)
    return topk_weights, topk_ids
import json
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import logging
import os
import time
from collections import defaultdict
from abc import ABC
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Literal, Optional, Tuple, Type

import torch
import torch.distributed

from sglang.srt.managers.expert_location import ExpertLocationMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import Withable

logger = logging.getLogger(__name__)
# --------------------------------------- Entrypoint -----------------------------------------
_OutputMode = Literal["file", "object"]
class ExpertDistributionRecorder(ABC):
"""Global expert distribution recording"""
@staticmethod
def init_new(
server_args: ServerArgs,
expert_location_metadata: "ExpertLocationMetadata",
rank: int,
):
if server_args.expert_distribution_recorder_mode is not None:
return _ExpertDistributionRecorderReal(
server_args, expert_location_metadata, rank
)
else:
return _ExpertDistributionRecorderNoop()
@contextmanager
def with_current_layer(self, layer_idx):
yield
@contextmanager
def with_debug_name(self, debug_name):
yield
@contextmanager
def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
yield
def on_select_experts(self, topk_ids: torch.Tensor):
pass
def on_deepep_dispatch_normal(
self,
local_physical_count_of_layer: List[int],
num_tokens_per_rank,
num_tokens_per_rdma_rank,
num_tokens_per_expert,
):
pass
def on_deepep_dispatch_low_latency(
self, local_physical_count_of_layer: torch.Tensor
):
pass
def start_record(self):
self._on_not_implemented()
def stop_record(self):
self._on_not_implemented()
def dump_record(self, output_mode: _OutputMode = "file"):
self._on_not_implemented()
def _on_not_implemented(self):
raise Exception(
"Please set ServerArgs.expert_distribution_recorder_mode to use ExpertDistributionRecorder."
)
class _ExpertDistributionRecorderNoop(ExpertDistributionRecorder):
pass
# global expert distribution recording
class ExpertDistributionRecorder:
    # This class is a singleton class
    def __new__(cls):
        if not hasattr(cls, "instance"):
            cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        # the length of the dictionary is the number of layers
        # the length of the list is the number of tokens
        # the length of the tuple is topk's k value
        self._expert_distribution_record: Dict[int, List[Tuple[int]]] = defaultdict(
            list
        )
        self._record = False
        self._current_layer_id = "UNKNOWN"

    def set_current_layer(self, layer_idx):
        self._current_layer_id = layer_idx

    def record_new_token(self, topk_ids):
        if not self._record:
            return
        topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
        torch.cuda.synchronize()
        for i in topk_ids_list:
            self._expert_distribution_record[self._current_layer_id].append(tuple(i))

    def reset(self):
        """Reset the expert distribution recorder."""
        logger.info("Resetting expert distribution record...")
        self._record = False
        self._expert_distribution_record.clear()
        self._current_layer_id = "UNKNOWN"

    def start_record(self):
        """Start recording the expert distribution. Reset the recorder and set the recording flag to True."""
        if self._record == True:
            logger.warning(
                "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
            )
        self.reset()
        self._record = True

    def stop_record(self):
        """Stop recording the expert distribution. Set the recording flag to False."""
        if self._record == False:
            logger.warning(
                "SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?"
            )
        self._record = False

    def dump_record(self):
        """Dump the expert distribution record to a file. Reset the recorder after dumping."""
        results = {}
        for layer_idx, layer_record in self._expert_distribution_record.items():
            results[layer_idx] = defaultdict(int)
            for token_record in layer_record:
                for expert_idx in token_record:
                    results[layer_idx][expert_idx] += 1
        with open(
            f"expert_distribution_rank{torch.distributed.get_rank()}_timestamp{time.time()}.csv",
            "w",
        ) as fd:
            fd.write("layer_id,expert_id,count\n")
            for layer_idx, layer_results in results.items():
                for expert_idx, count in layer_results.items():
                    fd.write(f"{layer_idx},{expert_idx},{count}\n")
        self.reset()


class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
    def __init__(
        self,
        server_args: ServerArgs,
        expert_location_metadata: "ExpertLocationMetadata",
        rank: int,
    ):
        self._server_args = server_args
        self._expert_location_metadata = expert_location_metadata
        self._recording = False
        self._current_forward_pass_id = Withable()
        self._current_layer_idx = Withable()
        self._current_debug_name = Withable()
        self._accumulator = _Accumulator.init_new(
            server_args, expert_location_metadata, rank
        )
        self._single_pass_gatherers = {
            k: _SinglePassGatherer.init_new(server_args, expert_location_metadata, rank)
            for k in self._accumulator.get_single_pass_gatherer_keys()
        }

    def with_current_layer(self, layer_idx):
        return self._current_layer_idx.with_value(layer_idx)

    def with_debug_name(self, debug_name):
        return self._current_debug_name.with_value(debug_name)

    @contextmanager
    def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
        with self._current_forward_pass_id.with_value(forward_pass_id):
            self._on_forward_pass_start(forward_batch)
            try:
                yield
            finally:
                self._on_forward_pass_end(forward_pass_id)

    def _on_forward_pass_start(self, forward_batch: ForwardBatch):
        if not self._recording:
            return
        for gatherer_key, gatherer in self._single_pass_gatherers.items():
            gatherer.reset()
            gatherer.on_forward_pass_start(forward_batch)

    def _on_forward_pass_end(self, forward_pass_id: int):
        if not self._recording:
            return
        for gatherer_key, gatherer in self._single_pass_gatherers.items():
            single_pass_data = gatherer.collect()
            self._accumulator.append(forward_pass_id, gatherer_key, single_pass_data)

    def on_select_experts(self, topk_ids: torch.Tensor):
        self._on_hook("on_select_experts", topk_ids=topk_ids)

    def on_deepep_dispatch_normal(
        self,
        local_physical_count_of_layer: List[int],
        num_tokens_per_rank,
        num_tokens_per_rdma_rank,
        num_tokens_per_expert,
    ):
        self._on_hook(
            "on_deepep_dispatch_normal",
            local_physical_count_of_layer=local_physical_count_of_layer,
            num_tokens_per_rank=num_tokens_per_rank,
            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
            num_tokens_per_expert=num_tokens_per_expert,
        )

    def on_deepep_dispatch_low_latency(
        self, local_physical_count_of_layer: torch.Tensor
    ):
        self._on_hook(
            "on_deepep_dispatch_low_latency",
            local_physical_count_of_layer=local_physical_count_of_layer,
        )

    def _on_hook(self, hook_name: str, **kwargs):
        if not (self._recording or torch.cuda.is_current_stream_capturing()):
            return
        gatherer = self._single_pass_gatherers[
            self._accumulator.get_single_pass_gatherer_key(
                self._current_debug_name.value
            )
        ]
        getattr(gatherer, hook_name)(layer_idx=self._current_layer_idx.value, **kwargs)

    def _reset(self):
        """Reset the expert distribution recorder."""
        logger.info("Resetting ExpertDistributionRecorder...")
        assert (
            self._current_layer_idx.value is None
        ), f"{self._current_layer_idx.value=}"
        for gatherer in self._single_pass_gatherers.values():
            gatherer.reset()
        self._accumulator.reset()

    def start_record(self):
        """Start recording the expert distribution."""
        if self._recording:
            logger.warning(
                "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
            )
        self._reset()
        self._recording = True

    def stop_record(self):
        """Stop recording the expert distribution."""
        if not self._recording:
            logger.warning(
                "SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?"
            )
        self._recording = False

    def dump_record(self, output_mode: _OutputMode = "file"):
        """Dump the expert distribution record and reset the recorder after dumping."""
        output = self._accumulator.dump(output_mode=output_mode)
        self._reset()
        return output


_global_expert_distribution_recorder: Optional[ExpertDistributionRecorder] = (
    _ExpertDistributionRecorderNoop()
)


def get_global_expert_distribution_recorder():
    return _global_expert_distribution_recorder


def set_global_expert_distribution_recorder(value):
    global _global_expert_distribution_recorder
    _global_expert_distribution_recorder = value
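# Usage sketch (not part of this file): how the global recorder is driven by the
# model runner and the MoE models later in this commit. The runner/model objects
# here are placeholders for illustration only.
#
#     recorder = get_global_expert_distribution_recorder()
#     with recorder.with_forward_pass(forward_pass_id, forward_batch):
#         for layer_idx, layer in enumerate(layers):
#             with recorder.with_current_layer(layer_idx):
#                 hidden_states = layer(hidden_states)  # MoE layers fire on_select_experts()
#
# When no recorder mode is configured, the noop recorder above stays installed
# globally, so every hook is a plain virtual call that returns immediately.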
# --------------------------------------- SinglePassGatherer -----------------------------------------
class _SinglePassGatherer(ABC):
@staticmethod
def init_new(
server_args: ServerArgs,
expert_location_metadata: "ExpertLocationMetadata",
rank: int,
) -> "_SinglePassGatherer":
if server_args.expert_distribution_recorder_mode == "per_token":
return _DetailSinglePassGatherer(
server_args, expert_location_metadata, rank
)
if server_args.enable_deepep_moe:
if server_args.deepep_mode == "normal":
return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
elif server_args.deepep_mode == "low_latency":
return _DeepepLowLatencySinglePassGatherer(
expert_location_metadata, rank
)
else:
raise NotImplementedError
return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int):
self._expert_location_metadata = expert_location_metadata
self._rank = rank
def on_forward_pass_start(self, forward_batch: ForwardBatch):
pass
def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
pass
def on_deepep_dispatch_normal(
self,
layer_idx: int,
local_physical_count_of_layer: List[int],
num_tokens_per_rank,
num_tokens_per_rdma_rank,
num_tokens_per_expert,
):
pass
def on_deepep_dispatch_low_latency(
self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
):
pass
def reset(self):
raise NotImplementedError
def collect(self) -> Dict:
raise NotImplementedError
class _LayerBasedSinglePassGatherer(_SinglePassGatherer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._objects_of_layer = {}
def _on_layer_data(self, layer_idx: int, objects: List[int]):
assert 0 <= layer_idx < self._expert_location_metadata.num_layers
if layer_idx in self._objects_of_layer:
self._objects_of_layer[layer_idx] = _list_sum(
self._objects_of_layer[layer_idx], objects
)
else:
self._objects_of_layer[layer_idx] = objects
def reset(self):
self._objects_of_layer.clear()
def _collect_objects(self, pad_len: int) -> torch.Tensor:
data = [
self._objects_of_layer.get(layer_index) or ([0] * pad_len)
for layer_index in range(self._expert_location_metadata.num_layers)
]
return torch.tensor(data)
def _list_sum(a: List, b: List) -> List:
return [x + y for x, y in zip(a, b, strict=True)]
class _SelectExpertsSinglePassGatherer(_LayerBasedSinglePassGatherer):
# pretty slow, but we will use the DeepEP Gatherer in production
def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
torch.cuda.synchronize()
global_physical_count = [
0
] * self._expert_location_metadata.num_physical_experts
for token_record in topk_ids_list:
for global_physical_expert_idx in token_record:
global_physical_count[global_physical_expert_idx] += 1
self._on_layer_data(layer_idx, global_physical_count)
def collect(self) -> Dict:
global_physical_count = super()._collect_objects(
pad_len=self._expert_location_metadata.num_physical_experts
)
return dict(global_physical_count=global_physical_count)
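# Note: the per-token Python loop above is simple but slow (the DeepEP gatherers are
# the intended production path). A vectorized equivalent -- a hedged sketch, not part
# of this PR -- could count tokens per physical expert for one layer with torch.bincount:
#
#     global_physical_count = torch.bincount(
#         topk_ids.flatten().long(),
#         minlength=self._expert_location_metadata.num_physical_experts,
#     ).tolist()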
class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer):
def on_deepep_dispatch_normal(
self,
layer_idx: int,
local_physical_count_of_layer: List[int],
num_tokens_per_rank,
num_tokens_per_rdma_rank,
num_tokens_per_expert,
):
assert isinstance(local_physical_count_of_layer, list)
self._on_layer_data(layer_idx, local_physical_count_of_layer)
def collect(self) -> Dict:
local_physical_count = super()._collect_objects(
pad_len=self._expert_location_metadata.num_local_physical_experts
)
global_physical_count = _convert_local_to_global_physical_count(
local_physical_count,
rank=self._rank,
num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
num_physical_experts=self._expert_location_metadata.num_physical_experts,
)
return dict(global_physical_count=global_physical_count)
class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._data = torch.zeros(
(
self._expert_location_metadata.num_layers,
self._expert_location_metadata.num_local_physical_experts,
),
dtype=torch.int,
device="cuda",
)
def on_deepep_dispatch_low_latency(
self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
):
# Most naive implementation, can optimize later
self._data[layer_idx, :] += local_physical_count_of_layer
def reset(self):
self._data[...] = 0
def collect(self) -> Dict:
# Can optimize if bottleneck
global_physical_count = _convert_local_to_global_physical_count(
self._data,
rank=self._rank,
num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
num_physical_experts=self._expert_location_metadata.num_physical_experts,
)
return dict(global_physical_count=global_physical_count)
def _convert_local_to_global_physical_count(
local_physical_count: torch.Tensor,
rank: int,
num_local_physical_experts: int,
num_physical_experts: int,
) -> torch.Tensor:
dtype = local_physical_count.dtype
device = local_physical_count.device
num_layers, _ = local_physical_count.shape
ans = torch.zeros((num_layers, num_physical_experts), dtype=dtype, device=device)
ans[
:, num_local_physical_experts * rank : num_local_physical_experts * (rank + 1)
] = local_physical_count
return ans
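# Illustration with made-up sizes: 1 layer, 2 local physical experts per rank,
# 4 physical experts overall. Rank 1's local counts land in global columns 2..3.
#
#     local = torch.tensor([[5, 7]])
#     _convert_local_to_global_physical_count(
#         local, rank=1, num_local_physical_experts=2, num_physical_experts=4
#     )
#     # -> tensor([[0, 0, 5, 7]])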
# --------------------------------------- Accumulator -----------------------------------------
_SINGLE_PASS_GATHERER_KEY_PRIMARY = "primary"
class _Accumulator(ABC):
@staticmethod
def init_new(
server_args: ServerArgs,
expert_location_metadata: "ExpertLocationMetadata",
rank: int,
) -> "_Accumulator":
return _Accumulator.get_class(server_args)(
server_args, expert_location_metadata, rank
)
@staticmethod
def get_class(server_args: ServerArgs) -> Type["_Accumulator"]:
return {
"stat": _StatAccumulator,
# TODO pr-chain: enable this later
# "per_pass": _DetailAccumulator,
# "per_token": _DetailAccumulator,
}[server_args.expert_distribution_recorder_mode]
def __init__(
self,
server_args: ServerArgs,
expert_location_metadata: "ExpertLocationMetadata",
rank: int,
):
self._server_args = server_args
self._expert_location_metadata = expert_location_metadata
self._rank = rank
def get_single_pass_gatherer_keys(self):
return [_SINGLE_PASS_GATHERER_KEY_PRIMARY]
def get_single_pass_gatherer_key(self, debug_name: Optional[str]):
return _SINGLE_PASS_GATHERER_KEY_PRIMARY
def append(
self,
forward_pass_id: int,
gatherer_key: str,
single_pass_data: Dict,
):
pass
def reset(self):
pass
def dump(self, output_mode: _OutputMode):
pass
class _StatAccumulator(_Accumulator):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._global_physical_count_of_buffered_step = _Buffer.init_new(
item_shape=(
self._expert_location_metadata.num_layers,
# Cannot use local_physical_count to support select_experts
self._expert_location_metadata.num_physical_experts,
),
buffer_size=self._server_args.expert_distribution_recorder_buffer_size,
dtype=torch.int32,
device=self._server_args.device,
)
def append(
self,
forward_pass_id: int,
gatherer_key: str,
single_pass_data: Dict,
):
super().append(forward_pass_id, gatherer_key, single_pass_data)
# Can optimize if overhead here is large
self._global_physical_count_of_buffered_step.append(
single_pass_data["global_physical_count"]
)
def reset(self):
super().reset()
self._global_physical_count_of_buffered_step.reset()
def dump(self, output_mode: _OutputMode):
logical_count_of_buffered_step = _convert_global_physical_count_to_logical_count(
self._global_physical_count_of_buffered_step.get_all(),
num_layers=self._expert_location_metadata.num_layers,
num_logical_experts=self._expert_location_metadata.num_logical_experts,
physical_to_logical_map=self._expert_location_metadata.physical_to_logical_map,
)
torch.distributed.all_reduce(
logical_count_of_buffered_step, op=torch.distributed.ReduceOp.SUM
)
output = dict(
rank=self._rank,
logical_count=logical_count_of_buffered_step,
)
if output_mode == "file":
if self._rank == 0:
_dump_to_file(f"expert_distribution_recorder_{time.time()}.pt", output)
elif output_mode == "object":
return output
else:
raise NotImplementedError
def _dump_to_file(name, data):
save_dir = Path(os.environ.get("SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR", "/tmp"))
path_output = save_dir / name
logger.info(f"Write expert distribution to {path_output}")
if not save_dir.exists():
save_dir.mkdir(parents=True, exist_ok=True)
torch.save(data, str(path_output))
class _Buffer:
@staticmethod
def init_new(item_shape: Tuple, buffer_size: int, dtype, device):
if buffer_size < 0:
return _InfiniteBuffer(item_shape, dtype=dtype, device=device)
else:
return _CircularBuffer(item_shape, buffer_size, dtype=dtype, device=device)
def append(self, value: torch.Tensor):
raise NotImplementedError
def get_all(self) -> torch.Tensor:
raise NotImplementedError
def reset(self):
raise NotImplementedError
class _CircularBuffer(_Buffer):
def __init__(self, item_shape: Tuple, buffer_size: int, dtype, device):
self._buffer = torch.zeros(
(buffer_size, *item_shape), dtype=dtype, device=device
)
self._curr_index = 0
def append(self, value: torch.Tensor):
self._buffer[self._curr_index] = value
self._curr_index = (self._curr_index + 1) % len(self._buffer)
def get_all(self) -> torch.Tensor:
return self._buffer
def reset(self):
self._buffer[...] = 0
class _InfiniteBuffer(_Buffer):
def __init__(self, item_shape: Tuple, dtype, device):
self._item_shape = item_shape
self._buffer = torch.zeros((128, *item_shape), dtype=dtype, device=device)
self._size = 0
def append(self, value: torch.Tensor):
curr_buffer_size = len(self._buffer)
dtype = self._buffer.dtype
device = self._buffer.device
if self._size == curr_buffer_size:
new_buffer = torch.zeros(
(2 * curr_buffer_size, *self._item_shape), dtype=dtype, device=device
)
new_buffer[:curr_buffer_size] = self._buffer
self._buffer = new_buffer
self._buffer[self._size] = value
self._size += 1
def get_all(self) -> torch.Tensor:
return self._buffer[: self._size]
def reset(self):
self._buffer[...] = 0
self._size = 0
def _convert_global_physical_count_to_logical_count(
# (whatever, num_layers, num_physical_experts)
global_physical_count: torch.Tensor,
num_layers: int,
num_logical_experts: int,
physical_to_logical_map: torch.Tensor,
):
dim_extra, _, _ = global_physical_count.shape
dtype = global_physical_count.dtype
device = global_physical_count.device
logical_count = torch.zeros(
(dim_extra, num_layers, num_logical_experts), dtype=dtype, device=device
)
logical_count.scatter_add_(
dim=2,
index=physical_to_logical_map.unsqueeze(0).expand(dim_extra, -1, -1),
src=global_physical_count,
)
return logical_count
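To make the physical-to-logical aggregation concrete, here is a small worked example (a hedged sketch: the shapes and mapping are made up, and it only exercises the helper defined above). Physical experts 0 and 2 are redundant copies of logical expert 0, so their counts are merged by the scatter_add:

physical_to_logical_map = torch.tensor([[0, 1, 0]])   # (num_layers=1, num_physical_experts=3)
global_physical_count = torch.tensor([[[4, 2, 6]]])   # (buffered_steps=1, 1, 3)

logical_count = _convert_global_physical_count_to_logical_count(
    global_physical_count,
    num_layers=1,
    num_logical_experts=2,
    physical_to_logical_map=physical_to_logical_map,
)
# logical_count -> tensor([[[10, 2]]]): the 4 + 6 tokens routed to the two physical
# copies of logical expert 0 are summed, logical expert 1 keeps its 2 tokens.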
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import torch
import torch.distributed
import torch.nn.functional as F
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader import get_model_architecture
from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__)
@dataclass
class ExpertLocationMetadata:
physical_to_logical_map: torch.Tensor # (layers, num_physical_experts)
logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X)
logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts)
# -------------------------------- properties ------------------------------------
@property
def num_layers(self) -> int:
return self.physical_to_logical_map.shape[0]
@property
def num_physical_experts(self) -> int:
return self.physical_to_logical_map.shape[1]
@property
def num_local_physical_experts(self) -> int:
ans, remainder = divmod(self.num_physical_experts, self.ep_size)
assert remainder == 0
return ans
@property
def num_logical_experts(self) -> int:
return self.logical_to_all_physical_map.shape[1]
@property
def ep_size(self):
# TODO change when EP size != world size
return torch.distributed.get_world_size()
def __post_init__(self):
num_layers_0, num_physical_experts_0 = self.physical_to_logical_map.shape
num_layers_1, num_logical_experts_0, num_physical_experts_1 = (
self.logical_to_all_physical_map.shape
)
num_layers_2, num_logical_experts_1 = (
self.logical_to_all_physical_map_num_valid.shape
)
# TODO pr-chain: enable this later
# assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3
# assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2
assert num_physical_experts_0 == num_physical_experts_1
# -------------------------------- construction ------------------------------------
@staticmethod
def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
"""Trivial location - logical expert i corresponds to physical expert i"""
common = ExpertLocationMetadata._init_common(server_args, model_config)
num_physical_experts = common["num_physical_experts"]
model_config_for_expert_location = common["model_config_for_expert_location"]
num_layers = model_config_for_expert_location.num_layers
num_logical_experts = model_config_for_expert_location.num_logical_experts
physical_to_logical_map = (
torch.arange(0, num_physical_experts).repeat(num_layers, 1)
% num_logical_experts
)
return ExpertLocationMetadata.init_by_mapping(
server_args,
model_config,
physical_to_logical_map=physical_to_logical_map,
)
@staticmethod
def init_by_mapping(
server_args: ServerArgs,
model_config: ModelConfig,
physical_to_logical_map,
):
if not isinstance(physical_to_logical_map, torch.Tensor):
physical_to_logical_map = torch.tensor(physical_to_logical_map)
physical_to_logical_map = physical_to_logical_map.to(server_args.device)
common = ExpertLocationMetadata._init_common(server_args, model_config)
model_config_for_expert_location = common["model_config_for_expert_location"]
logical_to_all_physical_map = _compute_logical_to_all_physical_map(
physical_to_logical_map,
num_logical_experts=model_config_for_expert_location.num_logical_experts,
)
return ExpertLocationMetadata._init_raw(
ep_size=common["ep_size"],
physical_to_logical_map=physical_to_logical_map,
logical_to_all_physical_map=logical_to_all_physical_map,
)
@staticmethod
def _init_common(server_args: ServerArgs, model_config: ModelConfig):
model_config_for_expert_location = (
ModelConfigForExpertLocation.from_model_config(model_config)
)
num_physical_experts = (
model_config_for_expert_location.num_logical_experts
# TODO pr-chain: enable this later
# + server_args.ep_num_redundant_experts
)
ep_size = server_args.ep_size
assert num_physical_experts % ep_size == 0
num_local_physical_experts = num_physical_experts // ep_size
return dict(
model_config_for_expert_location=model_config_for_expert_location,
num_physical_experts=num_physical_experts,
num_local_physical_experts=num_local_physical_experts,
ep_size=ep_size,
)
@staticmethod
def _init_raw(
ep_size: int,
physical_to_logical_map: torch.Tensor,
logical_to_all_physical_map: torch.Tensor,
):
_, num_physical_experts = physical_to_logical_map.shape
logical_to_all_physical_map_padded = F.pad(
logical_to_all_physical_map,
(0, num_physical_experts - logical_to_all_physical_map.shape[-1]),
value=-1,
)
logical_to_all_physical_map_num_valid = torch.count_nonzero(
logical_to_all_physical_map != -1, dim=-1
)
return ExpertLocationMetadata(
physical_to_logical_map=physical_to_logical_map,
logical_to_all_physical_map=logical_to_all_physical_map_padded,
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
)
_global_expert_location_metadata: Optional[ExpertLocationMetadata] = None
def get_global_expert_location_metadata():
return _global_expert_location_metadata
def set_global_expert_location_metadata(value):
global _global_expert_location_metadata
assert _global_expert_location_metadata is None
_global_expert_location_metadata = value
def _compute_logical_to_all_physical_map(
physical_to_logical_map: torch.Tensor, num_logical_experts: int
):
# This is rarely called, so we use for loops for maximum clarity
num_layers, num_physical_experts = physical_to_logical_map.shape
logical_to_all_physical_map = [
[[] for _ in range(num_logical_experts)] for _ in range(num_layers)
]
for layer_id in range(num_layers):
for physical_expert_id in range(num_physical_experts):
logical_expert_id = physical_to_logical_map[
layer_id, physical_expert_id
].item()
logical_to_all_physical_map[layer_id][logical_expert_id].append(
physical_expert_id
)
logical_to_all_physical_map = _pad_nested_array(
logical_to_all_physical_map, pad_value=-1
)
return torch.tensor(
logical_to_all_physical_map, device=physical_to_logical_map.device
)
def _pad_nested_array(arr, pad_value):
max_len = max(len(inner) for outer in arr for inner in outer)
padded = [
[inner + [pad_value] * (max_len - len(inner)) for inner in outer]
for outer in arr
]
return padded
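# Illustration with a made-up mapping: three physical slots per layer, where slots 0
# and 2 both hold logical expert 0 and slot 1 holds logical expert 1. Missing entries
# are padded with -1.
#
#     _compute_logical_to_all_physical_map(
#         torch.tensor([[0, 1, 0]]), num_logical_experts=2
#     )
#     # -> tensor([[[0, 2], [1, -1]]])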
@dataclass
class ModelConfigForExpertLocation:
num_layers: int
num_logical_experts: int
num_groups: Optional[int] = None
@staticmethod
def init_dummy():
return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)
@staticmethod
def from_model_config(model_config: ModelConfig):
model_class, _ = get_model_architecture(model_config)
if hasattr(model_class, "get_model_config_for_expert_location"):
return model_class.get_model_config_for_expert_location(
model_config.hf_config
)
else:
return ModelConfigForExpertLocation.init_dummy()
def compute_initial_expert_location_metadata(
server_args: ServerArgs, model_config: ModelConfig
) -> ExpertLocationMetadata:
data = server_args.init_expert_location
if data == "trivial":
logger.info("init_expert_location from trivial")
return ExpertLocationMetadata.init_trivial(server_args, model_config)
# TODO unify with the utils function
if data.endswith(".pt"):
data_dict = torch.load(data, weights_only=True)
elif data.endswith(".json"):
data_dict = json.loads(Path(data).read_text())
else:
data_dict = json.loads(data)
if "physical_to_logical_map" in data_dict:
logger.info(
"init_expert_location from init_by_mapping using ServerArgs.init_expert_location"
)
return ExpertLocationMetadata.init_by_mapping(
server_args, model_config, **data_dict
)
elif "logical_count" in data_dict:
# TODO pr-chain: enable this later
raise NotImplementedError
# logger.info(
# "init_expert_location from init_by_eplb using ServerArgs.init_expert_location"
# )
# return ExpertLocationMetadata.init_by_eplb(
# server_args, model_config, logical_count=data_dict["logical_count"]
# )
else:
raise NotImplementedError(
f"Unknown init_expert_location format ({list(data_dict.keys())=})"
)
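For reference, `--init-expert-location` accepts the literal `trivial`, a path ending in `.pt` or `.json`, or an inline JSON string; the JSON variant must carry a `physical_to_logical_map`. A hedged example (the mapping values are illustrative only, for a model with 4 logical experts and 2 MoE layers) could be built like this:

import json

init_expert_location = json.dumps(
    {
        "physical_to_logical_map": [
            [0, 1, 2, 3],  # layer 0: physical expert i holds logical expert i
            [0, 1, 2, 3],  # layer 1
        ]
    }
)
# Passed to the server as: --init-expert-location '<that JSON string>'
# (with redundant experts, enabled later in this PR chain, a logical id may repeat).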
...@@ -59,7 +59,10 @@ from sglang.srt.hf_transformers_utils import ( ...@@ -59,7 +59,10 @@ from sglang.srt.hf_transformers_utils import (
) )
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.managers.expert_distribution import (
    ExpertDistributionRecorder,
    get_global_expert_distribution_recorder,
)
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
AbortReq, AbortReq,
CloseSessionReqInput, CloseSessionReqInput,
...@@ -142,8 +145,6 @@ from sglang.srt.utils import ( ...@@ -142,8 +145,6 @@ from sglang.srt.utils import (
) )
from sglang.utils import TypeBasedDispatcher, get_exception_traceback from sglang.utils import TypeBasedDispatcher, get_exception_traceback
expert_distribution_recorder = ExpertDistributionRecorder()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Test retract decode for debugging purposes # Test retract decode for debugging purposes
...@@ -2162,11 +2163,11 @@ class Scheduler( ...@@ -2162,11 +2163,11 @@ class Scheduler(
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
    if recv_req == ExpertDistributionReq.START_RECORD:
        expert_distribution_recorder.start_record()
        get_global_expert_distribution_recorder().start_record()
    elif recv_req == ExpertDistributionReq.STOP_RECORD:
        expert_distribution_recorder.stop_record()
        get_global_expert_distribution_recorder().stop_record()
    elif recv_req == ExpertDistributionReq.DUMP_RECORD:
        expert_distribution_recorder.dump_record()
        get_global_expert_distribution_recorder().dump_record()
    else:
        raise ValueError("Unrecognized ExpertDistributionReq value")
    return ExpertDistributionReqOutput()
......
...@@ -52,6 +52,16 @@ from sglang.srt.layers.quantization.deep_gemm import ( ...@@ -52,6 +52,16 @@ from sglang.srt.layers.quantization.deep_gemm import (
from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.expert_distribution import (
ExpertDistributionRecorder,
get_global_expert_distribution_recorder,
set_global_expert_distribution_recorder,
)
from sglang.srt.managers.expert_location import (
compute_initial_expert_location_metadata,
get_global_expert_location_metadata,
set_global_expert_location_metadata,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import ( from sglang.srt.mem_cache.memory_pool import (
DoubleSparseTokenToKVPool, DoubleSparseTokenToKVPool,
...@@ -161,6 +171,8 @@ class ModelRunner: ...@@ -161,6 +171,8 @@ class ModelRunner:
self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
self.attention_chunk_size = model_config.attention_chunk_size self.attention_chunk_size = model_config.attention_chunk_size
self.forward_pass_id = 0
# Model-specific adjustment # Model-specific adjustment
self.model_specific_adjustment() self.model_specific_adjustment()
...@@ -219,6 +231,25 @@ class ModelRunner: ...@@ -219,6 +231,25 @@ class ModelRunner:
enable=self.server_args.enable_memory_saver enable=self.server_args.enable_memory_saver
) )
if not self.is_draft_worker:
set_global_expert_location_metadata(
compute_initial_expert_location_metadata(server_args, self.model_config)
)
if self.tp_rank == 0 and get_bool_env_var(
"SGLANG_LOG_EXPERT_LOCATION_METADATA"
):
logger.info(
f"Initial expert_location_metadata: {get_global_expert_location_metadata().debug_str()}"
)
set_global_expert_distribution_recorder(
ExpertDistributionRecorder.init_new(
server_args,
get_global_expert_location_metadata(),
rank=self.tp_rank,
)
)
# Load the model # Load the model
self.sampler = Sampler() self.sampler = Sampler()
self.load_model() self.load_model()
...@@ -1093,6 +1124,22 @@ class ModelRunner: ...@@ -1093,6 +1124,22 @@ class ModelRunner:
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
skip_attn_backend_init: bool = False, skip_attn_backend_init: bool = False,
pp_proxy_tensors: Optional[PPProxyTensors] = None, pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
self.forward_pass_id += 1
with get_global_expert_distribution_recorder().with_forward_pass(
self.forward_pass_id,
forward_batch,
):
return self._forward_raw(
forward_batch, skip_attn_backend_init, pp_proxy_tensors
)
def _forward_raw(
self,
forward_batch: ForwardBatch,
skip_attn_backend_init: bool,
pp_proxy_tensors: Optional[PPProxyTensors],
) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]: ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
can_run_cuda_graph = bool( can_run_cuda_graph = bool(
forward_batch.forward_mode.is_cuda_graph() forward_batch.forward_mode.is_cuda_graph()
......
...@@ -77,7 +77,11 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -77,7 +77,11 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.managers.expert_distribution import (
    ExpertDistributionRecorder,
    get_global_expert_distribution_recorder,
)
from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
...@@ -109,8 +113,6 @@ if _is_hip: ...@@ -109,8 +113,6 @@ if _is_hip:
decode_attention_fwd_grouped_rope, decode_attention_fwd_grouped_rope,
) )
expert_distribution_recorder = ExpertDistributionRecorder()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -302,6 +304,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -302,6 +304,7 @@ class DeepseekV2MoE(nn.Module):
def forward( def forward(
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
) -> torch.Tensor: ) -> torch.Tensor:
forward_mode = forward_batch.forward_mode
if (not self._enable_deepep_moe) or is_non_idle_and_non_empty( if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
forward_mode, hidden_states forward_mode, hidden_states
): ):
...@@ -1278,7 +1281,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1278,7 +1281,7 @@ class DeepseekV2DecoderLayer(nn.Module):
) )
# Fully Connected # Fully Connected
hidden_states = self.mlp(hidden_states)
hidden_states = self.mlp(hidden_states, forward_batch)
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
# Scatter # Scatter
...@@ -1422,11 +1425,11 @@ class DeepseekV2Model(nn.Module): ...@@ -1422,11 +1425,11 @@ class DeepseekV2Model(nn.Module):
residual = None residual = None
for i in range(len(self.layers)):
    expert_distribution_recorder.set_current_layer(i)
    with get_global_expert_distribution_recorder().with_current_layer(i):
        layer = self.layers[i]
        hidden_states, residual = layer(
            positions, hidden_states, forward_batch, residual, zero_allocator
        )
if not forward_batch.forward_mode.is_idle(): if not forward_batch.forward_mode.is_idle():
if residual is None: if residual is None:
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
...@@ -1872,6 +1875,14 @@ class DeepseekV2ForCausalLM(nn.Module): ...@@ -1872,6 +1875,14 @@ class DeepseekV2ForCausalLM(nn.Module):
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.synchronize() torch.cuda.synchronize()
@classmethod
def get_model_config_for_expert_location(cls, config):
return ModelConfigForExpertLocation(
num_layers=config.num_hidden_layers,
num_logical_experts=config.n_routed_experts,
num_groups=config.n_group,
)
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
pass pass
......
...@@ -59,14 +59,16 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -59,14 +59,16 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.managers.expert_distribution import (
    ExpertDistributionRecorder,
    get_global_expert_distribution_recorder,
)
from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, make_layers from sglang.srt.utils import add_prefix, make_layers
expert_distribution_recorder = ExpertDistributionRecorder()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -591,11 +593,11 @@ class Qwen2MoeModel(nn.Module): ...@@ -591,11 +593,11 @@ class Qwen2MoeModel(nn.Module):
residual = pp_proxy_tensors["residual"] residual = pp_proxy_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
    expert_distribution_recorder.set_current_layer(i)
    with get_global_expert_distribution_recorder().with_current_layer(i):
        layer = self.layers[i]
        hidden_states, residual = layer(
            positions, hidden_states, forward_batch, residual
        )
if not self.pp_group.is_last_rank: if not self.pp_group.is_last_rank:
return PPProxyTensors( return PPProxyTensors(
{ {
...@@ -752,5 +754,13 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -752,5 +754,13 @@ class Qwen2MoeForCausalLM(nn.Module):
else: else:
logger.warning(f"Parameter {name} not found in params_dict") logger.warning(f"Parameter {name} not found in params_dict")
@classmethod
def get_model_config_for_expert_location(cls, config):
return ModelConfigForExpertLocation(
num_layers=config.num_hidden_layers,
num_logical_experts=config.num_experts,
num_groups=None,
)
EntryClass = Qwen2MoeForCausalLM EntryClass = Qwen2MoeForCausalLM
...@@ -170,6 +170,11 @@ class ServerArgs: ...@@ -170,6 +170,11 @@ class ServerArgs:
enable_ep_moe: bool = False enable_ep_moe: bool = False
enable_deepep_moe: bool = False enable_deepep_moe: bool = False
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto" deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
init_expert_location: str = "trivial"
expert_distribution_recorder_mode: Optional[
Literal["stat", "per_pass", "per_token"]
] = None
expert_distribution_recorder_buffer_size: Optional[int] = None
deepep_config: Optional[str] = None deepep_config: Optional[str] = None
enable_torch_compile: bool = False enable_torch_compile: bool = False
torch_compile_max_bs: int = 32 torch_compile_max_bs: int = 32
...@@ -361,6 +366,15 @@ class ServerArgs: ...@@ -361,6 +366,15 @@ class ServerArgs:
"Pipeline parallelism is incompatible with overlap schedule." "Pipeline parallelism is incompatible with overlap schedule."
) )
if self.expert_distribution_recorder_buffer_size is None:
# TODO pr-chain: enable this later
# if (x := self.eplb_rebalance_num_iterations) is not None:
# self.expert_distribution_recorder_buffer_size = x
if False:
pass
elif self.expert_distribution_recorder_mode is not None:
self.expert_distribution_recorder_buffer_size = 1000
# Speculative Decoding # Speculative Decoding
if self.speculative_algorithm == "NEXTN": if self.speculative_algorithm == "NEXTN":
# NEXTN shares the same implementation of EAGLE # NEXTN shares the same implementation of EAGLE
...@@ -1257,6 +1271,24 @@ class ServerArgs: ...@@ -1257,6 +1271,24 @@ class ServerArgs:
default="auto", default="auto",
help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.", help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
) )
parser.add_argument(
"--init-expert-location",
type=str,
default=ServerArgs.init_expert_location,
help="Initial location of EP experts.",
)
parser.add_argument(
"--expert-distribution-recorder-mode",
type=str,
default=ServerArgs.expert_distribution_recorder_mode,
help="Mode of expert distribution recorder.",
)
parser.add_argument(
"--expert-distribution-recorder-buffer-size",
type=int,
default=ServerArgs.expert_distribution_recorder_buffer_size,
help="Circular buffer size of expert distribution recorder. Set to -1 to denote infinite buffer.",
)
parser.add_argument( parser.add_argument(
"--deepep-config", "--deepep-config",
type=str, type=str,
......
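Putting the new arguments together, a launch in the style of the notebook above might look like the following sketch (the model choice and buffer size are arbitrary; `stat` accumulates per-expert token counts into a circular buffer, and a buffer size of -1 keeps every recorded pass instead):

server_process, port = launch_server_cmd(
    "python3 -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0 "
    "--expert-distribution-recorder-mode stat "
    "--expert-distribution-recorder-buffer-size 300"
)
wait_for_server(f"http://localhost:{port}")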
...@@ -46,7 +46,19 @@ from importlib.util import find_spec ...@@ -46,7 +46,19 @@ from importlib.util import find_spec
from io import BytesIO from io import BytesIO
from multiprocessing.reduction import ForkingPickler from multiprocessing.reduction import ForkingPickler
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Optional,
Protocol,
Set,
Tuple,
TypeVar,
Union,
)
import numpy as np import numpy as np
import psutil import psutil
...@@ -2126,3 +2138,25 @@ def load_json_config(data: str): ...@@ -2126,3 +2138,25 @@ def load_json_config(data: str):
def dispose_tensor(x: torch.Tensor): def dispose_tensor(x: torch.Tensor):
x.set_(torch.empty((0,), device=x.device, dtype=x.dtype)) x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))
T = TypeVar("T")
class Withable(Generic[T]):
def __init__(self):
self._value: Optional[T] = None
@property
def value(self) -> T:
return self._value
@contextmanager
def with_value(self, new_value: T):
assert self._value is None
self._value = new_value
try:
yield
finally:
assert self._value is new_value
self._value = None
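`Withable` is the small helper the recorder uses to scope its "current layer", "current forward pass", and "debug name" state; a minimal usage sketch:

current_layer = Withable()

def hook():
    # Reads the scoped value; it is None outside any with-block.
    print(f"current layer = {current_layer.value}")

with current_layer.with_value(3):
    hook()  # prints: current layer = 3
assert current_layer.value is None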
import csv
import glob
import os
import tempfile
import unittest
from pathlib import Path

import requests
import torch

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
...@@ -16,108 +17,86 @@ from sglang.test.test_utils import (
class TestExpertDistribution(CustomTestCase):
    def setUp(self):
        # Clean up any existing expert distribution files before each test
        for f in glob.glob("expert_distribution_*.csv"):
            os.remove(f)

    def tearDown(self):
        # Clean up any expert distribution files after each test
        for f in glob.glob("expert_distribution_*.csv"):
            os.remove(f)

    def test_expert_distribution_record(self):
        """Test expert distribution record endpoints"""
        process = popen_launch_server(
            # The feature is only implemented in deepseek_v2.py
            "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
            DEFAULT_URL_FOR_TEST,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
            ],
        )

        try:
            # Start recording
            response = requests.post(
                f"{DEFAULT_URL_FOR_TEST}/start_expert_distribution_record"
            )
            self.assertEqual(response.status_code, 200)

            # Make some requests to generate expert distribution data
            response = requests.post(
                f"{DEFAULT_URL_FOR_TEST}/generate",
                json={
                    "text": "The capital of France is",
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 32,
                    },
                },
            )
            self.assertEqual(response.status_code, 200)

            # Stop recording
            response = requests.post(
                f"{DEFAULT_URL_FOR_TEST}/stop_expert_distribution_record"
            )
            self.assertEqual(response.status_code, 200)

            # Dump the recorded data
            response = requests.post(
                f"{DEFAULT_URL_FOR_TEST}/dump_expert_distribution_record"
            )
            self.assertEqual(response.status_code, 200)

            # Verify the dumped file exists and has correct format
            csv_files = glob.glob("expert_distribution_*.csv")
            self.assertEqual(
                len(csv_files),
                1,
                f"Expected exactly one expert distribution CSV file {csv_files=}",
            )

            # Check CSV file format
            with open(csv_files[0], "r") as f:
                csv_reader = csv.reader(f)

                # Check header
                header = next(csv_reader)
                self.assertEqual(
                    header,
                    ["layer_id", "expert_id", "count"],
                    "CSV header should be 'layer_id,expert_id,count'",
                )

                # Check data rows
                rows = list(csv_reader)
                self.assertGreater(len(rows), 0, "CSV file should contain data rows")

                for row in rows:
                    # Verify each row has 3 columns
                    self.assertEqual(
                        len(row),
                        3,
                        "Each row should have layer_id, expert_id and count",
                    )

                    # Verify data types
                    layer_id, expert_id, count = row
                    self.assertTrue(
                        layer_id.isdigit(),
                        f"layer_id should be an integer {row=} {rows=}",
                    )
                    self.assertTrue(
                        expert_id.isdigit(),
                        f"expert_id should be an integer {row=} {rows=}",
                    )
                    self.assertTrue(
                        count.isdigit(), f"count should be an integer {row=} {rows=}"
                    )
        finally:
            kill_process_tree(process.pid)


class TestExpertDistribution(CustomTestCase):
    def test_expert_distribution_record(self):
        # TODO: Add tests for DeepEP gatherer (currently our CI cannot run that)
        for info in [
            dict(model_path="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"),
            dict(model_path="Qwen/Qwen1.5-MoE-A2.7B"),
            dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", tp_size=2),
            # TODO enable in next PR
            # dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_pass"),
            # dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_token"),
        ]:
            with self.subTest(info=info):
                self._execute_core(**info)

    def _execute_core(self, model_path: str, mode: str = "stat", tp_size: int = 1):
        """Test expert distribution record endpoints"""
        with tempfile.TemporaryDirectory() as tmp_dir:
            os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"] = tmp_dir

            process = popen_launch_server(
                model_path,
                DEFAULT_URL_FOR_TEST,
                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                other_args=[
                    "--trust-remote-code",
                    "--tp-size",
                    str(tp_size),
                    "--expert-distribution-recorder-mode",
                    mode,
                    "--disable-cuda-graph",
                    "--disable-overlap-schedule",
                ],
            )

            try:
                # Start recording
                response = requests.post(
                    f"{DEFAULT_URL_FOR_TEST}/start_expert_distribution_record"
                )
                self.assertEqual(response.status_code, 200)

                # Make some requests to generate expert distribution data
                response = requests.post(
                    f"{DEFAULT_URL_FOR_TEST}/generate",
                    json={
                        "text": "The capital of France is",
                        "sampling_params": {
                            "temperature": 0,
                            "max_new_tokens": 32,
                        },
                    },
                )
                self.assertEqual(response.status_code, 200)

                # Stop recording
                response = requests.post(
                    f"{DEFAULT_URL_FOR_TEST}/stop_expert_distribution_record"
                )
                self.assertEqual(response.status_code, 200)

                # Dump the recorded data
                response = requests.post(
                    f"{DEFAULT_URL_FOR_TEST}/dump_expert_distribution_record"
                )
                self.assertEqual(response.status_code, 200)

                # Check data rows
                data = torch.load(
                    list(Path(tmp_dir).glob("*.pt"))[0], weights_only=True
                )
                print(f"{data=}")

                if mode in ["per_pass", "per_token"]:
                    self.assertGreater(len(data), 0, "Should contain data rows")
                else:
                    logical_count = data["logical_count"]
                    print(f"{logical_count.sum()=} {logical_count=}")
                    self.assertTrue(logical_count.sum() > 0)
            finally:
                kill_process_tree(process.pid)


if __name__ == "__main__":
......