Unverified Commit f0653886 authored by fzyzcjy, committed by GitHub

Expert distribution recording without overhead for EPLB (#4957)

parent b1465557
......@@ -390,7 +390,7 @@
"outputs": [],
"source": [
"expert_record_server_process, port = launch_server_cmd(\n",
" \"python3 -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0\"\n",
" \"python3 -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0 --expert-distribution-recorder-mode stat\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
......@@ -415,19 +415,7 @@
"print_highlight(response)\n",
"\n",
"response = requests.post(f\"http://localhost:{port}/dump_expert_distribution_record\")\n",
"print_highlight(response)\n",
"\n",
"import glob\n",
"\n",
"output_file = glob.glob(\"expert_distribution_*.csv\")[0]\n",
"with open(output_file, \"r\") as f:\n",
" print_highlight(\"\\n| Layer ID | Expert ID | Count |\")\n",
" print_highlight(\"|----------|-----------|--------|\")\n",
" next(f)\n",
" for i, line in enumerate(f):\n",
" if i < 9:\n",
" layer_id, expert_id, count = line.strip().split(\",\")\n",
" print_highlight(f\"| {layer_id:8} | {expert_id:9} | {count:6} |\")"
"print_highlight(response)"
]
},
{
......
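For reference, a minimal client-side sketch of the recording workflow exercised by the notebook above, assuming a server launched with --expert-distribution-recorder-mode stat and reachable on a placeholder port:
import glob
import requests
import torch
base_url = "http://localhost:30000"  # placeholder port
requests.post(f"{base_url}/start_expert_distribution_record")
requests.post(
    f"{base_url}/generate",
    json={"text": "The capital of France is", "sampling_params": {"max_new_tokens": 32}},
)
requests.post(f"{base_url}/stop_expert_distribution_record")
requests.post(f"{base_url}/dump_expert_distribution_record")
# In "stat" mode, rank 0 writes expert_distribution_recorder_<timestamp>.pt under
# $SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR (default: /tmp).
output_path = sorted(glob.glob("/tmp/expert_distribution_recorder_*.pt"))[-1]
data = torch.load(output_path, weights_only=True)
print(data["logical_count"].sum())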
import logging
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.managers.expert_distribution import (
get_global_expert_distribution_recorder,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import DeepEPMode, load_json_config
......@@ -326,6 +329,13 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
config=_DeepEPConfig.get_instance().normal_dispatch_config,
)
get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
num_recv_tokens_per_expert_list,
num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
num_tokens_per_expert=num_tokens_per_expert,
)
return (
recv_x,
recv_topk_idx,
......@@ -489,6 +499,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
):
hook() if self.return_recv_hook else event.current_stream_wait()
get_global_expert_distribution_recorder().on_deepep_dispatch_low_latency(
masked_m
)
reorder_topk_ids = seg_indptr = None
return (
......
......@@ -18,7 +18,10 @@ from typing import Callable, Optional
import torch
import torch.nn.functional as F
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.managers.expert_distribution import (
ExpertDistributionRecorder,
get_global_expert_distribution_recorder,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
......@@ -31,8 +34,6 @@ if _is_cuda:
if _is_cuda or _is_hip:
from sgl_kernel import topk_softmax
expert_distribution_recorder = ExpertDistributionRecorder()
def fused_topk_native(
hidden_states: torch.Tensor,
......@@ -353,6 +354,6 @@ def select_experts(
renormalize=renormalize,
)
expert_distribution_recorder.record_new_token(topk_ids)
get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)
return topk_weights, topk_ids
import json
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import logging
import os
import time
from collections import defaultdict
from typing import Dict, List, Tuple
from abc import ABC
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, List, Literal, Optional, Tuple, Type
import torch
import torch.distributed
from sglang.srt.managers.expert_location import ExpertLocationMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import Withable
logger = logging.getLogger(__name__)
# --------------------------------------- Entrypoint -----------------------------------------
_OutputMode = Literal["file", "object"]
class ExpertDistributionRecorder(ABC):
"""Global expert distribution recording"""
@staticmethod
def init_new(
server_args: ServerArgs,
expert_location_metadata: "ExpertLocationMetadata",
rank: int,
):
if server_args.expert_distribution_recorder_mode is not None:
return _ExpertDistributionRecorderReal(
server_args, expert_location_metadata, rank
)
else:
return _ExpertDistributionRecorderNoop()
@contextmanager
def with_current_layer(self, layer_idx):
yield
@contextmanager
def with_debug_name(self, debug_name):
yield
@contextmanager
def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
yield
def on_select_experts(self, topk_ids: torch.Tensor):
pass
def on_deepep_dispatch_normal(
self,
local_physical_count_of_layer: List[int],
num_tokens_per_rank,
num_tokens_per_rdma_rank,
num_tokens_per_expert,
):
pass
def on_deepep_dispatch_low_latency(
self, local_physical_count_of_layer: torch.Tensor
):
pass
def start_record(self):
self._on_not_implemented()
def stop_record(self):
self._on_not_implemented()
def dump_record(self, output_mode: _OutputMode = "file"):
self._on_not_implemented()
def _on_not_implemented(self):
raise Exception(
"Please set ServerArgs.expert_distribution_recorder_mode to use ExpertDistributionRecorder."
)
class _ExpertDistributionRecorderNoop(ExpertDistributionRecorder):
pass
# global expert distribution recording
class ExpertDistributionRecorder:
# This class is a singleton class
def __new__(cls):
if not hasattr(cls, "instance"):
cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls)
return cls.instance
def __init__(self):
# the length of the dictionary is the number of layers
# the length of the list is the number of tokens
# the length of the tuple is topk's k value
self._expert_distribution_record: Dict[int, List[Tuple[int]]] = defaultdict(
list
class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
def __init__(
self,
server_args: ServerArgs,
expert_location_metadata: "ExpertLocationMetadata",
rank: int,
):
self._server_args = server_args
self._expert_location_metadata = expert_location_metadata
self._recording = False
self._current_forward_pass_id = Withable()
self._current_layer_idx = Withable()
self._current_debug_name = Withable()
self._accumulator = _Accumulator.init_new(
server_args, expert_location_metadata, rank
)
self._record = False
self._current_layer_id = "UNKNOWN"
self._single_pass_gatherers = {
k: _SinglePassGatherer.init_new(server_args, expert_location_metadata, rank)
for k in self._accumulator.get_single_pass_gatherer_keys()
}
def with_current_layer(self, layer_idx):
return self._current_layer_idx.with_value(layer_idx)
def set_current_layer(self, layer_idx):
self._current_layer_id = layer_idx
def with_debug_name(self, debug_name):
return self._current_debug_name.with_value(debug_name)
def record_new_token(self, topk_ids):
if not self._record:
@contextmanager
def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
with self._current_forward_pass_id.with_value(forward_pass_id):
self._on_forward_pass_start(forward_batch)
try:
yield
finally:
self._on_forward_pass_end(forward_pass_id)
def _on_forward_pass_start(self, forward_batch: ForwardBatch):
if not self._recording:
return
topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
torch.cuda.synchronize()
for i in topk_ids_list:
self._expert_distribution_record[self._current_layer_id].append(tuple(i))
for gatherer_key, gatherer in self._single_pass_gatherers.items():
gatherer.reset()
gatherer.on_forward_pass_start(forward_batch)
def reset(self):
def _on_forward_pass_end(self, forward_pass_id: int):
if not self._recording:
return
for gatherer_key, gatherer in self._single_pass_gatherers.items():
single_pass_data = gatherer.collect()
self._accumulator.append(forward_pass_id, gatherer_key, single_pass_data)
def on_select_experts(self, topk_ids: torch.Tensor):
self._on_hook("on_select_experts", topk_ids=topk_ids)
def on_deepep_dispatch_normal(
self,
local_physical_count_of_layer: List[int],
num_tokens_per_rank,
num_tokens_per_rdma_rank,
num_tokens_per_expert,
):
self._on_hook(
"on_deepep_dispatch_normal",
local_physical_count_of_layer=local_physical_count_of_layer,
num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
num_tokens_per_expert=num_tokens_per_expert,
)
def on_deepep_dispatch_low_latency(
self, local_physical_count_of_layer: torch.Tensor
):
self._on_hook(
"on_deepep_dispatch_low_latency",
local_physical_count_of_layer=local_physical_count_of_layer,
)
def _on_hook(self, hook_name: str, **kwargs):
if not (self._recording or torch.cuda.is_current_stream_capturing()):
return
gatherer = self._single_pass_gatherers[
self._accumulator.get_single_pass_gatherer_key(
self._current_debug_name.value
)
]
getattr(gatherer, hook_name)(layer_idx=self._current_layer_idx.value, **kwargs)
def _reset(self):
"""Reset the expert distribution recorder."""
logger.info("Resetting expert distribution record...")
self._record = False
self._expert_distribution_record.clear()
self._current_layer_id = "UNKNOWN"
logger.info("Resetting ExpertDistributionRecorder...")
assert (
self._current_layer_idx.value is None
), f"{self._current_layer_idx.value=}"
for gatherer in self._single_pass_gatherers.values():
gatherer.reset()
self._accumulator.reset()
def start_record(self):
"""Start recording the expert distribution. Reset the recorder and set the recording flag to True."""
if self._record == True:
"""Start recording the expert distribution."""
if self._recording:
logger.warning(
"SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
)
self.reset()
self._record = True
self._reset()
self._recording = True
def stop_record(self):
"""Stop recording the expert distribution. Set the recording flag to False."""
if self._record == False:
"""Stop recording the expert distribution."""
if not self._recording:
logger.warning(
"SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?"
)
self._record = False
def dump_record(self):
"""Dump the expert distribution record to a file. Reset the recorder after dumping."""
results = {}
for layer_idx, layer_record in self._expert_distribution_record.items():
results[layer_idx] = defaultdict(int)
for token_record in layer_record:
for expert_idx in token_record:
results[layer_idx][expert_idx] += 1
with open(
f"expert_distribution_rank{torch.distributed.get_rank()}_timestamp{time.time()}.csv",
"w",
) as fd:
fd.write("layer_id,expert_id,count\n")
for layer_idx, layer_results in results.items():
for expert_idx, count in layer_results.items():
fd.write(f"{layer_idx},{expert_idx},{count}\n")
self.reset()
self._recording = False
def dump_record(self, output_mode: _OutputMode = "file"):
"""Dump the expert distribution record and reset the recorder after dumping."""
output = self._accumulator.dump(output_mode=output_mode)
self._reset()
return output
_global_expert_distribution_recorder: Optional[ExpertDistributionRecorder] = (
_ExpertDistributionRecorderNoop()
)
def get_global_expert_distribution_recorder():
return _global_expert_distribution_recorder
def set_global_expert_distribution_recorder(value):
global _global_expert_distribution_recorder
_global_expert_distribution_recorder = value
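Until a server wires in a real recorder via set_global_expert_distribution_recorder, the global instance is the no-op recorder, so the hooks sprinkled through the model code cost nothing. A small sketch of that behaviour (assuming an sglang installation):
from sglang.srt.managers.expert_distribution import get_global_expert_distribution_recorder
recorder = get_global_expert_distribution_recorder()  # _ExpertDistributionRecorderNoop by default
with recorder.with_forward_pass(forward_pass_id=1, forward_batch=None):
    with recorder.with_current_layer(0):
        recorder.on_select_experts(topk_ids=None)  # silently ignored by the no-op recorder
try:
    recorder.dump_record()
except Exception as exc:
    print(exc)  # asks you to set ServerArgs.expert_distribution_recorder_mode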
# --------------------------------------- SinglePassGatherer -----------------------------------------
class _SinglePassGatherer(ABC):
@staticmethod
def init_new(
server_args: ServerArgs,
expert_location_metadata: "ExpertLocationMetadata",
rank: int,
) -> "_SinglePassGatherer":
if server_args.expert_distribution_recorder_mode == "per_token":
return _DetailSinglePassGatherer(
server_args, expert_location_metadata, rank
)
if server_args.enable_deepep_moe:
if server_args.deepep_mode == "normal":
return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
elif server_args.deepep_mode == "low_latency":
return _DeepepLowLatencySinglePassGatherer(
expert_location_metadata, rank
)
else:
raise NotImplementedError
return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int):
self._expert_location_metadata = expert_location_metadata
self._rank = rank
def on_forward_pass_start(self, forward_batch: ForwardBatch):
pass
def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
pass
def on_deepep_dispatch_normal(
self,
layer_idx: int,
local_physical_count_of_layer: List[int],
num_tokens_per_rank,
num_tokens_per_rdma_rank,
num_tokens_per_expert,
):
pass
def on_deepep_dispatch_low_latency(
self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
):
pass
def reset(self):
raise NotImplementedError
def collect(self) -> Dict:
raise NotImplementedError
class _LayerBasedSinglePassGatherer(_SinglePassGatherer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._objects_of_layer = {}
def _on_layer_data(self, layer_idx: int, objects: List[int]):
assert 0 <= layer_idx < self._expert_location_metadata.num_layers
if layer_idx in self._objects_of_layer:
self._objects_of_layer[layer_idx] = _list_sum(
self._objects_of_layer[layer_idx], objects
)
else:
self._objects_of_layer[layer_idx] = objects
def reset(self):
self._objects_of_layer.clear()
def _collect_objects(self, pad_len: int) -> torch.Tensor:
data = [
self._objects_of_layer.get(layer_index) or ([0] * pad_len)
for layer_index in range(self._expert_location_metadata.num_layers)
]
return torch.tensor(data)
def _list_sum(a: List, b: List) -> List:
return [x + y for x, y in zip(a, b, strict=True)]
class _SelectExpertsSinglePassGatherer(_LayerBasedSinglePassGatherer):
# pretty slow, but we will use the DeepEP Gatherer in production
def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
torch.cuda.synchronize()
global_physical_count = [
0
] * self._expert_location_metadata.num_physical_experts
for token_record in topk_ids_list:
for global_physical_expert_idx in token_record:
global_physical_count[global_physical_expert_idx] += 1
self._on_layer_data(layer_idx, global_physical_count)
def collect(self) -> Dict:
global_physical_count = super()._collect_objects(
pad_len=self._expert_location_metadata.num_physical_experts
)
return dict(global_physical_count=global_physical_count)
class _DeepepNormalSinglePassGatherer(_LayerBasedSinglePassGatherer):
def on_deepep_dispatch_normal(
self,
layer_idx: int,
local_physical_count_of_layer: List[int],
num_tokens_per_rank,
num_tokens_per_rdma_rank,
num_tokens_per_expert,
):
assert isinstance(local_physical_count_of_layer, list)
self._on_layer_data(layer_idx, local_physical_count_of_layer)
def collect(self) -> Dict:
local_physical_count = super()._collect_objects(
pad_len=self._expert_location_metadata.num_local_physical_experts
)
global_physical_count = _convert_local_to_global_physical_count(
local_physical_count,
rank=self._rank,
num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
num_physical_experts=self._expert_location_metadata.num_physical_experts,
)
return dict(global_physical_count=global_physical_count)
class _DeepepLowLatencySinglePassGatherer(_SinglePassGatherer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._data = torch.zeros(
(
self._expert_location_metadata.num_layers,
self._expert_location_metadata.num_local_physical_experts,
),
dtype=torch.int,
device="cuda",
)
def on_deepep_dispatch_low_latency(
self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
):
# Most naive implementation, can optimize later
self._data[layer_idx, :] += local_physical_count_of_layer
def reset(self):
self._data[...] = 0
def collect(self) -> Dict:
# Can optimize if bottleneck
global_physical_count = _convert_local_to_global_physical_count(
self._data,
rank=self._rank,
num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
num_physical_experts=self._expert_location_metadata.num_physical_experts,
)
return dict(global_physical_count=global_physical_count)
def _convert_local_to_global_physical_count(
local_physical_count: torch.Tensor,
rank: int,
num_local_physical_experts: int,
num_physical_experts: int,
) -> torch.Tensor:
dtype = local_physical_count.dtype
device = local_physical_count.device
num_layers, _ = local_physical_count.shape
ans = torch.zeros((num_layers, num_physical_experts), dtype=dtype, device=device)
ans[
:, num_local_physical_experts * rank : num_local_physical_experts * (rank + 1)
] = local_physical_count
return ans
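A tiny worked example of the helper above: each rank only observes counts for its own slice of physical experts, and the helper places that slice into the global layout (illustration only; the function is module-private):
import torch
from sglang.srt.managers.expert_distribution import _convert_local_to_global_physical_count
local = torch.tensor([[3, 1], [0, 5]])  # (num_layers=2, num_local_physical_experts=2)
out = _convert_local_to_global_physical_count(
    local, rank=1, num_local_physical_experts=2, num_physical_experts=6
)
print(out)
# tensor([[0, 0, 3, 1, 0, 0],
#         [0, 0, 0, 5, 0, 0]])  -- rank 1 owns global physical experts 2 and 3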
# --------------------------------------- Accumulator -----------------------------------------
_SINGLE_PASS_GATHERER_KEY_PRIMARY = "primary"
class _Accumulator(ABC):
@staticmethod
def init_new(
server_args: ServerArgs,
expert_location_metadata: "ExpertLocationMetadata",
rank: int,
) -> "_Accumulator":
return _Accumulator.get_class(server_args)(
server_args, expert_location_metadata, rank
)
@staticmethod
def get_class(server_args: ServerArgs) -> Type["_Accumulator"]:
return {
"stat": _StatAccumulator,
# TODO pr-chain: enable this later
# "per_pass": _DetailAccumulator,
# "per_token": _DetailAccumulator,
}[server_args.expert_distribution_recorder_mode]
def __init__(
self,
server_args: ServerArgs,
expert_location_metadata: "ExpertLocationMetadata",
rank: int,
):
self._server_args = server_args
self._expert_location_metadata = expert_location_metadata
self._rank = rank
def get_single_pass_gatherer_keys(self):
return [_SINGLE_PASS_GATHERER_KEY_PRIMARY]
def get_single_pass_gatherer_key(self, debug_name: Optional[str]):
return _SINGLE_PASS_GATHERER_KEY_PRIMARY
def append(
self,
forward_pass_id: int,
gatherer_key: str,
single_pass_data: Dict,
):
pass
def reset(self):
pass
def dump(self, output_mode: _OutputMode):
pass
class _StatAccumulator(_Accumulator):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._global_physical_count_of_buffered_step = _Buffer.init_new(
item_shape=(
self._expert_location_metadata.num_layers,
# Cannot use local_physical_count to support select_experts
self._expert_location_metadata.num_physical_experts,
),
buffer_size=self._server_args.expert_distribution_recorder_buffer_size,
dtype=torch.int32,
device=self._server_args.device,
)
def append(
self,
forward_pass_id: int,
gatherer_key: str,
single_pass_data: Dict,
):
super().append(forward_pass_id, gatherer_key, single_pass_data)
# Can optimize if overhead here is large
self._global_physical_count_of_buffered_step.append(
single_pass_data["global_physical_count"]
)
def reset(self):
super().reset()
self._global_physical_count_of_buffered_step.reset()
def dump(self, output_mode: _OutputMode):
logical_count_of_buffered_step = _convert_global_physical_count_to_logical_count(
self._global_physical_count_of_buffered_step.get_all(),
num_layers=self._expert_location_metadata.num_layers,
num_logical_experts=self._expert_location_metadata.num_logical_experts,
physical_to_logical_map=self._expert_location_metadata.physical_to_logical_map,
)
torch.distributed.all_reduce(
logical_count_of_buffered_step, op=torch.distributed.ReduceOp.SUM
)
output = dict(
rank=self._rank,
logical_count=logical_count_of_buffered_step,
)
if output_mode == "file":
if self._rank == 0:
_dump_to_file(f"expert_distribution_recorder_{time.time()}.pt", output)
elif output_mode == "object":
return output
else:
raise NotImplementedError
def _dump_to_file(name, data):
save_dir = Path(os.environ.get("SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR", "/tmp"))
path_output = save_dir / name
logger.info(f"Write expert distribution to {path_output}")
if not save_dir.exists():
save_dir.mkdir(parents=True, exist_ok=True)
torch.save(data, str(path_output))
class _Buffer:
@staticmethod
def init_new(item_shape: Tuple, buffer_size: int, dtype, device):
if buffer_size < 0:
return _InfiniteBuffer(item_shape, dtype=dtype, device=device)
else:
return _CircularBuffer(item_shape, buffer_size, dtype=dtype, device=device)
def append(self, value: torch.Tensor):
raise NotImplementedError
def get_all(self) -> torch.Tensor:
raise NotImplementedError
def reset(self):
raise NotImplementedError
class _CircularBuffer(_Buffer):
def __init__(self, item_shape: Tuple, buffer_size: int, dtype, device):
self._buffer = torch.zeros(
(buffer_size, *item_shape), dtype=dtype, device=device
)
self._curr_index = 0
def append(self, value: torch.Tensor):
self._buffer[self._curr_index] = value
self._curr_index = (self._curr_index + 1) % len(self._buffer)
def get_all(self) -> torch.Tensor:
return self._buffer
def reset(self):
self._buffer[...] = 0
class _InfiniteBuffer(_Buffer):
def __init__(self, item_shape: Tuple, dtype, device):
self._item_shape = item_shape
self._buffer = torch.zeros((128, *item_shape), dtype=dtype, device=device)
self._size = 0
def append(self, value: torch.Tensor):
curr_buffer_size = len(self._buffer)
dtype = self._buffer.dtype
device = self._buffer.device
if self._size == curr_buffer_size:
new_buffer = torch.zeros(
(2 * curr_buffer_size, *self._item_shape), dtype=dtype, device=device
)
new_buffer[:curr_buffer_size] = self._buffer
self._buffer = new_buffer
self._buffer[self._size] = value
self._size += 1
def get_all(self) -> torch.Tensor:
return self._buffer[: self._size]
def reset(self):
self._buffer[...] = 0
self._size = 0
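A quick sketch of the two buffer flavours selected by --expert-distribution-recorder-buffer-size: a non-negative size gives a circular buffer that overwrites the oldest steps, while -1 gives an unbounded buffer that doubles its capacity as needed (illustration only; the classes are module-private):
import torch
from sglang.srt.managers.expert_distribution import _Buffer
circular = _Buffer.init_new(item_shape=(2,), buffer_size=3, dtype=torch.int32, device="cpu")
for step in range(5):
    circular.append(torch.full((2,), step, dtype=torch.int32))
print(circular.get_all())  # steps 3 and 4 have overwritten the oldest slots
infinite = _Buffer.init_new(item_shape=(2,), buffer_size=-1, dtype=torch.int32, device="cpu")
for step in range(200):
    infinite.append(torch.full((2,), step, dtype=torch.int32))
print(infinite.get_all().shape)  # torch.Size([200, 2]); storage grew by doubling from 128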
def _convert_global_physical_count_to_logical_count(
# (whatever, num_layers, num_physical_experts)
global_physical_count: torch.Tensor,
num_layers: int,
num_logical_experts: int,
physical_to_logical_map: torch.Tensor,
):
dim_extra, _, _ = global_physical_count.shape
dtype = global_physical_count.dtype
device = global_physical_count.device
logical_count = torch.zeros(
(dim_extra, num_layers, num_logical_experts), dtype=dtype, device=device
)
logical_count.scatter_add_(
dim=2,
index=physical_to_logical_map.unsqueeze(0).expand(dim_extra, -1, -1),
src=global_physical_count,
)
return logical_count
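A small worked example of the physical-to-logical aggregation above: counts of physical replicas of the same logical expert are summed via scatter_add_ (illustration only; the function is module-private):
import torch
from sglang.srt.managers.expert_distribution import _convert_global_physical_count_to_logical_count
# One buffered step, one layer, 4 physical experts mapped onto 2 logical experts
# (physical 0 and 2 are replicas of logical 0; physical 1 and 3 of logical 1).
global_physical_count = torch.tensor([[[5, 1, 2, 0]]])  # (1, num_layers=1, num_physical_experts=4)
physical_to_logical_map = torch.tensor([[0, 1, 0, 1]])  # (num_layers=1, num_physical_experts=4)
logical = _convert_global_physical_count_to_logical_count(
    global_physical_count,
    num_layers=1,
    num_logical_experts=2,
    physical_to_logical_map=physical_to_logical_map,
)
print(logical)  # tensor([[[7, 1]]]) -- replica counts folded into their logical experts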
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import torch
import torch.distributed
import torch.nn.functional as F
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader import get_model_architecture
from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__)
@dataclass
class ExpertLocationMetadata:
physical_to_logical_map: torch.Tensor # (layers, num_physical_experts)
logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X)
logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts)
# -------------------------------- properties ------------------------------------
@property
def num_layers(self) -> int:
return self.physical_to_logical_map.shape[0]
@property
def num_physical_experts(self) -> int:
return self.physical_to_logical_map.shape[1]
@property
def num_local_physical_experts(self) -> int:
ans, remainder = divmod(self.num_physical_experts, self.ep_size)
assert remainder == 0
return ans
@property
def num_logical_experts(self) -> int:
return self.logical_to_all_physical_map.shape[1]
@property
def ep_size(self):
# TODO change when EP size != world size
return torch.distributed.get_world_size()
def __post_init__(self):
num_layers_0, num_physical_experts_0 = self.physical_to_logical_map.shape
num_layers_1, num_logical_experts_0, num_physical_experts_1 = (
self.logical_to_all_physical_map.shape
)
num_layers_2, num_logical_experts_1 = (
self.logical_to_all_physical_map_num_valid.shape
)
# TODO pr-chain: enable this later
# assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3
# assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2
assert num_physical_experts_0 == num_physical_experts_1
# -------------------------------- construction ------------------------------------
@staticmethod
def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
"""Trivial location - logical expert i corresponds to physical expert i"""
common = ExpertLocationMetadata._init_common(server_args, model_config)
num_physical_experts = common["num_physical_experts"]
model_config_for_expert_location = common["model_config_for_expert_location"]
num_layers = model_config_for_expert_location.num_layers
num_logical_experts = model_config_for_expert_location.num_logical_experts
physical_to_logical_map = (
torch.arange(0, num_physical_experts).repeat(num_layers, 1)
% num_logical_experts
)
return ExpertLocationMetadata.init_by_mapping(
server_args,
model_config,
physical_to_logical_map=physical_to_logical_map,
)
@staticmethod
def init_by_mapping(
server_args: ServerArgs,
model_config: ModelConfig,
physical_to_logical_map,
):
if not isinstance(physical_to_logical_map, torch.Tensor):
physical_to_logical_map = torch.tensor(physical_to_logical_map)
physical_to_logical_map = physical_to_logical_map.to(server_args.device)
common = ExpertLocationMetadata._init_common(server_args, model_config)
model_config_for_expert_location = common["model_config_for_expert_location"]
logical_to_all_physical_map = _compute_logical_to_all_physical_map(
physical_to_logical_map,
num_logical_experts=model_config_for_expert_location.num_logical_experts,
)
return ExpertLocationMetadata._init_raw(
ep_size=common["ep_size"],
physical_to_logical_map=physical_to_logical_map,
logical_to_all_physical_map=logical_to_all_physical_map,
)
@staticmethod
def _init_common(server_args: ServerArgs, model_config: ModelConfig):
model_config_for_expert_location = (
ModelConfigForExpertLocation.from_model_config(model_config)
)
num_physical_experts = (
model_config_for_expert_location.num_logical_experts
# TODO pr-chain: enable this later
# + server_args.ep_num_redundant_experts
)
ep_size = server_args.ep_size
assert num_physical_experts % ep_size == 0
num_local_physical_experts = num_physical_experts // ep_size
return dict(
model_config_for_expert_location=model_config_for_expert_location,
num_physical_experts=num_physical_experts,
num_local_physical_experts=num_local_physical_experts,
ep_size=ep_size,
)
@staticmethod
def _init_raw(
ep_size: int,
physical_to_logical_map: torch.Tensor,
logical_to_all_physical_map: torch.Tensor,
):
_, num_physical_experts = physical_to_logical_map.shape
logical_to_all_physical_map_padded = F.pad(
logical_to_all_physical_map,
(0, num_physical_experts - logical_to_all_physical_map.shape[-1]),
value=-1,
)
logical_to_all_physical_map_num_valid = torch.count_nonzero(
logical_to_all_physical_map != -1, dim=-1
)
return ExpertLocationMetadata(
physical_to_logical_map=physical_to_logical_map,
logical_to_all_physical_map=logical_to_all_physical_map_padded,
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
)
_global_expert_location_metadata: Optional[ExpertLocationMetadata] = None
def get_global_expert_location_metadata():
return _global_expert_location_metadata
def set_global_expert_location_metadata(value):
global _global_expert_location_metadata
assert _global_expert_location_metadata is None
_global_expert_location_metadata = value
def _compute_logical_to_all_physical_map(
physical_to_logical_map: torch.Tensor, num_logical_experts: int
):
# This is rarely called, so we use for loops for maximum clarity
num_layers, num_physical_experts = physical_to_logical_map.shape
logical_to_all_physical_map = [
[[] for _ in range(num_logical_experts)] for _ in range(num_layers)
]
for layer_id in range(num_layers):
for physical_expert_id in range(num_physical_experts):
logical_expert_id = physical_to_logical_map[
layer_id, physical_expert_id
].item()
logical_to_all_physical_map[layer_id][logical_expert_id].append(
physical_expert_id
)
logical_to_all_physical_map = _pad_nested_array(
logical_to_all_physical_map, pad_value=-1
)
return torch.tensor(
logical_to_all_physical_map, device=physical_to_logical_map.device
)
def _pad_nested_array(arr, pad_value):
max_len = max(len(inner) for outer in arr for inner in outer)
padded = [
[inner + [pad_value] * (max_len - len(inner)) for inner in outer]
for outer in arr
]
return padded
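A worked example of the mapping inversion above, for one layer with two logical experts each replicated in two physical slots (illustration only; the helper is module-private):
import torch
from sglang.srt.managers.expert_location import _compute_logical_to_all_physical_map
physical_to_logical_map = torch.tensor([[0, 1, 0, 1]])  # (num_layers=1, num_physical_experts=4)
print(_compute_logical_to_all_physical_map(physical_to_logical_map, num_logical_experts=2))
# tensor([[[0, 2],
#          [1, 3]]])  -- logical expert 0 lives in physical slots 0 and 2, logical 1 in slots 1 and 3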
@dataclass
class ModelConfigForExpertLocation:
num_layers: int
num_logical_experts: int
num_groups: Optional[int] = None
@staticmethod
def init_dummy():
return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)
@staticmethod
def from_model_config(model_config: ModelConfig):
model_class, _ = get_model_architecture(model_config)
if hasattr(model_class, "get_model_config_for_expert_location"):
return model_class.get_model_config_for_expert_location(
model_config.hf_config
)
else:
return ModelConfigForExpertLocation.init_dummy()
def compute_initial_expert_location_metadata(
server_args: ServerArgs, model_config: ModelConfig
) -> ExpertLocationMetadata:
data = server_args.init_expert_location
if data == "trivial":
logger.info("init_expert_location from trivial")
return ExpertLocationMetadata.init_trivial(server_args, model_config)
# TODO unify with the utils function
if data.endswith(".pt"):
data_dict = torch.load(data, weights_only=True)
elif data.endswith(".json"):
data_dict = json.loads(Path(data).read_text())
else:
data_dict = json.loads(data)
if "physical_to_logical_map" in data_dict:
logger.info(
"init_expert_location from init_by_mapping using ServerArgs.init_expert_location"
)
return ExpertLocationMetadata.init_by_mapping(
server_args, model_config, **data_dict
)
elif "logical_count" in data_dict:
# TODO pr-chain: enable this later
raise NotImplementedError
# logger.info(
# "init_expert_location from init_by_eplb using ServerArgs.init_expert_location"
# )
# return ExpertLocationMetadata.init_by_eplb(
# server_args, model_config, logical_count=data_dict["logical_count"]
# )
else:
raise NotImplementedError(
f"Unknown init_expert_location format ({list(data_dict.keys())=})"
)
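As a concrete illustration of the --init-expert-location formats accepted by the parsing above (only the physical_to_logical_map form is wired up in this commit; the logical_count form is deferred to a later PR):
import json
# "trivial" (the default), a .pt/.json file path, or an inline JSON string are all accepted.
init_expert_location = json.dumps(
    {
        # (num_layers, num_physical_experts); each value is a logical expert id.
        "physical_to_logical_map": [[0, 1, 2, 3], [3, 2, 1, 0]],
    }
)
# This string can then be passed as --init-expert-location on the launch command line.
print(init_expert_location)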
......@@ -59,7 +59,10 @@ from sglang.srt.hf_transformers_utils import (
)
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.managers.expert_distribution import (
ExpertDistributionRecorder,
get_global_expert_distribution_recorder,
)
from sglang.srt.managers.io_struct import (
AbortReq,
CloseSessionReqInput,
......@@ -142,8 +145,6 @@ from sglang.srt.utils import (
)
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
expert_distribution_recorder = ExpertDistributionRecorder()
logger = logging.getLogger(__name__)
# Test retract decode for debugging purposes
......@@ -2162,11 +2163,11 @@ class Scheduler(
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
if recv_req == ExpertDistributionReq.START_RECORD:
expert_distribution_recorder.start_record()
get_global_expert_distribution_recorder().start_record()
elif recv_req == ExpertDistributionReq.STOP_RECORD:
expert_distribution_recorder.stop_record()
get_global_expert_distribution_recorder().stop_record()
elif recv_req == ExpertDistributionReq.DUMP_RECORD:
expert_distribution_recorder.dump_record()
get_global_expert_distribution_recorder().dump_record()
else:
raise ValueError("Unrecognized ExpertDistributionReq value")
return ExpertDistributionReqOutput()
......
......@@ -52,6 +52,16 @@ from sglang.srt.layers.quantization.deep_gemm import (
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.expert_distribution import (
ExpertDistributionRecorder,
get_global_expert_distribution_recorder,
set_global_expert_distribution_recorder,
)
from sglang.srt.managers.expert_location import (
compute_initial_expert_location_metadata,
get_global_expert_location_metadata,
set_global_expert_location_metadata,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
DoubleSparseTokenToKVPool,
......@@ -161,6 +171,8 @@ class ModelRunner:
self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
self.attention_chunk_size = model_config.attention_chunk_size
self.forward_pass_id = 0
# Model-specific adjustment
self.model_specific_adjustment()
......@@ -219,6 +231,25 @@ class ModelRunner:
enable=self.server_args.enable_memory_saver
)
if not self.is_draft_worker:
set_global_expert_location_metadata(
compute_initial_expert_location_metadata(server_args, self.model_config)
)
if self.tp_rank == 0 and get_bool_env_var(
"SGLANG_LOG_EXPERT_LOCATION_METADATA"
):
logger.info(
f"Initial expert_location_metadata: {get_global_expert_location_metadata().debug_str()}"
)
set_global_expert_distribution_recorder(
ExpertDistributionRecorder.init_new(
server_args,
get_global_expert_location_metadata(),
rank=self.tp_rank,
)
)
# Load the model
self.sampler = Sampler()
self.load_model()
......@@ -1093,6 +1124,22 @@ class ModelRunner:
forward_batch: ForwardBatch,
skip_attn_backend_init: bool = False,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
self.forward_pass_id += 1
with get_global_expert_distribution_recorder().with_forward_pass(
self.forward_pass_id,
forward_batch,
):
return self._forward_raw(
forward_batch, skip_attn_backend_init, pp_proxy_tensors
)
def _forward_raw(
self,
forward_batch: ForwardBatch,
skip_attn_backend_init: bool,
pp_proxy_tensors: Optional[PPProxyTensors],
) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
can_run_cuda_graph = bool(
forward_batch.forward_mode.is_cuda_graph()
......
......@@ -77,7 +77,11 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.managers.expert_distribution import (
ExpertDistributionRecorder,
get_global_expert_distribution_recorder,
)
from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_loader.weight_utils import default_weight_loader
......@@ -109,8 +113,6 @@ if _is_hip:
decode_attention_fwd_grouped_rope,
)
expert_distribution_recorder = ExpertDistributionRecorder()
logger = logging.getLogger(__name__)
......@@ -302,6 +304,7 @@ class DeepseekV2MoE(nn.Module):
def forward(
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
) -> torch.Tensor:
forward_mode = forward_batch.forward_mode
if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
forward_mode, hidden_states
):
......@@ -1278,7 +1281,7 @@ class DeepseekV2DecoderLayer(nn.Module):
)
# Fully Connected
hidden_states = self.mlp(hidden_states)
hidden_states = self.mlp(hidden_states, forward_batch)
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
# Scatter
......@@ -1422,11 +1425,11 @@ class DeepseekV2Model(nn.Module):
residual = None
for i in range(len(self.layers)):
expert_distribution_recorder.set_current_layer(i)
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, forward_batch, residual, zero_allocator
)
with get_global_expert_distribution_recorder().with_current_layer(i):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, forward_batch, residual, zero_allocator
)
if not forward_batch.forward_mode.is_idle():
if residual is None:
hidden_states = self.norm(hidden_states)
......@@ -1872,6 +1875,14 @@ class DeepseekV2ForCausalLM(nn.Module):
torch.cuda.empty_cache()
torch.cuda.synchronize()
@classmethod
def get_model_config_for_expert_location(cls, config):
return ModelConfigForExpertLocation(
num_layers=config.num_hidden_layers,
num_logical_experts=config.n_routed_experts,
num_groups=config.n_group,
)
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
pass
......
......@@ -59,14 +59,16 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.managers.expert_distribution import (
ExpertDistributionRecorder,
get_global_expert_distribution_recorder,
)
from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, make_layers
expert_distribution_recorder = ExpertDistributionRecorder()
logger = logging.getLogger(__name__)
......@@ -591,11 +593,11 @@ class Qwen2MoeModel(nn.Module):
residual = pp_proxy_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
expert_distribution_recorder.set_current_layer(i)
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, forward_batch, residual
)
with get_global_expert_distribution_recorder().with_current_layer(i):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, forward_batch, residual
)
if not self.pp_group.is_last_rank:
return PPProxyTensors(
{
......@@ -752,5 +754,13 @@ class Qwen2MoeForCausalLM(nn.Module):
else:
logger.warning(f"Parameter {name} not found in params_dict")
@classmethod
def get_model_config_for_expert_location(cls, config):
return ModelConfigForExpertLocation(
num_layers=config.num_hidden_layers,
num_logical_experts=config.num_experts,
num_groups=None,
)
EntryClass = Qwen2MoeForCausalLM
......@@ -170,6 +170,11 @@ class ServerArgs:
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
init_expert_location: str = "trivial"
expert_distribution_recorder_mode: Optional[
Literal["stat", "per_pass", "per_token"]
] = None
expert_distribution_recorder_buffer_size: Optional[int] = None
deepep_config: Optional[str] = None
enable_torch_compile: bool = False
torch_compile_max_bs: int = 32
......@@ -361,6 +366,15 @@ class ServerArgs:
"Pipeline parallelism is incompatible with overlap schedule."
)
if self.expert_distribution_recorder_buffer_size is None:
# TODO pr-chain: enable this later
# if (x := self.eplb_rebalance_num_iterations) is not None:
# self.expert_distribution_recorder_buffer_size = x
if False:
pass
elif self.expert_distribution_recorder_mode is not None:
self.expert_distribution_recorder_buffer_size = 1000
# Speculative Decoding
if self.speculative_algorithm == "NEXTN":
# NEXTN shares the same implementation of EAGLE
......@@ -1257,6 +1271,24 @@ class ServerArgs:
default="auto",
help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
)
parser.add_argument(
"--init-expert-location",
type=str,
default=ServerArgs.init_expert_location,
help="Initial location of EP experts.",
)
parser.add_argument(
"--expert-distribution-recorder-mode",
type=str,
default=ServerArgs.expert_distribution_recorder_mode,
help="Mode of expert distribution recorder.",
)
parser.add_argument(
"--expert-distribution-recorder-buffer-size",
type=int,
default=ServerArgs.expert_distribution_recorder_buffer_size,
help="Circular buffer size of expert distribution recorder. Set to -1 to denote infinite buffer.",
)
parser.add_argument(
"--deepep-config",
type=str,
......
......@@ -46,7 +46,19 @@ from importlib.util import find_spec
from io import BytesIO
from multiprocessing.reduction import ForkingPickler
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Optional,
Protocol,
Set,
Tuple,
TypeVar,
Union,
)
import numpy as np
import psutil
......@@ -2126,3 +2138,25 @@ def load_json_config(data: str):
def dispose_tensor(x: torch.Tensor):
x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))
T = TypeVar("T")
class Withable(Generic[T]):
def __init__(self):
self._value: Optional[T] = None
@property
def value(self) -> T:
return self._value
@contextmanager
def with_value(self, new_value: T):
assert self._value is None
self._value = new_value
try:
yield
finally:
assert self._value is new_value
self._value = None
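A quick illustration of the Withable helper added to sglang.srt.utils: it holds a value only for the duration of a with block and asserts that values are neither nested nor leaked.
from sglang.srt.utils import Withable
current_layer: Withable[int] = Withable()
print(current_layer.value)  # None outside any with block
with current_layer.with_value(3):
    print(current_layer.value)  # 3 while inside the block
print(current_layer.value)  # back to None; entering with_value while a value is set would assert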
import csv
import glob
import os
import tempfile
import unittest
from pathlib import Path
import requests
import torch
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
......@@ -16,108 +17,86 @@ from sglang.test.test_utils import (
class TestExpertDistribution(CustomTestCase):
def setUp(self):
# Clean up any existing expert distribution files before each test
for f in glob.glob("expert_distribution_*.csv"):
os.remove(f)
def tearDown(self):
# Clean up any expert distribution files after each test
for f in glob.glob("expert_distribution_*.csv"):
os.remove(f)
def test_expert_distribution_record(self):
# TODO: Add tests for DeepEP gatherer (currently our CI cannot run that)
for info in [
dict(model_path="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"),
dict(model_path="Qwen/Qwen1.5-MoE-A2.7B"),
dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", tp_size=2),
# TODO enable in next PR
# dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_pass"),
# dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_token"),
]:
with self.subTest(info=info):
self._execute_core(**info)
def _execute_core(self, model_path: str, mode: str = "stat", tp_size: int = 1):
"""Test expert distribution record endpoints"""
process = popen_launch_server(
# The feature is only implemented in deepseek_v2.py
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
DEFAULT_URL_FOR_TEST,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
],
)
try:
# Start recording
response = requests.post(
f"{DEFAULT_URL_FOR_TEST}/start_expert_distribution_record"
with tempfile.TemporaryDirectory() as tmp_dir:
os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"] = tmp_dir
process = popen_launch_server(
model_path,
DEFAULT_URL_FOR_TEST,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp-size",
str(tp_size),
"--expert-distribution-recorder-mode",
mode,
"--disable-cuda-graph",
"--disable-overlap-schedule",
],
)
self.assertEqual(response.status_code, 200)
# Make some requests to generate expert distribution data
response = requests.post(
f"{DEFAULT_URL_FOR_TEST}/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 32,
try:
# Start recording
response = requests.post(
f"{DEFAULT_URL_FOR_TEST}/start_expert_distribution_record"
)
self.assertEqual(response.status_code, 200)
# Make some requests to generate expert distribution data
response = requests.post(
f"{DEFAULT_URL_FOR_TEST}/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 32,
},
},
},
)
self.assertEqual(response.status_code, 200)
# Stop recording
response = requests.post(
f"{DEFAULT_URL_FOR_TEST}/stop_expert_distribution_record"
)
self.assertEqual(response.status_code, 200)
# Dump the recorded data
response = requests.post(
f"{DEFAULT_URL_FOR_TEST}/dump_expert_distribution_record"
)
self.assertEqual(response.status_code, 200)
# Verify the dumped file exists and has correct format
csv_files = glob.glob("expert_distribution_*.csv")
self.assertEqual(
len(csv_files),
1,
f"Expected exactly one expert distribution CSV file {csv_files=}",
)
)
self.assertEqual(response.status_code, 200)
# Check CSV file format
with open(csv_files[0], "r") as f:
csv_reader = csv.reader(f)
# Stop recording
response = requests.post(
f"{DEFAULT_URL_FOR_TEST}/stop_expert_distribution_record"
)
self.assertEqual(response.status_code, 200)
# Check header
header = next(csv_reader)
self.assertEqual(
header,
["layer_id", "expert_id", "count"],
"CSV header should be 'layer_id,expert_id,count'",
# Dump the recorded data
response = requests.post(
f"{DEFAULT_URL_FOR_TEST}/dump_expert_distribution_record"
)
self.assertEqual(response.status_code, 200)
# Check data rows
rows = list(csv_reader)
self.assertGreater(len(rows), 0, "CSV file should contain data rows")
for row in rows:
# Verify each row has 3 columns
self.assertEqual(
len(row),
3,
"Each row should have layer_id, expert_id and count",
)
data = torch.load(
list(Path(tmp_dir).glob("*.pt"))[0], weights_only=True
)
print(f"{data=}")
# Verify data types
layer_id, expert_id, count = row
self.assertTrue(
layer_id.isdigit(),
f"layer_id should be an integer {row=} {rows=}",
)
self.assertTrue(
expert_id.isdigit(),
f"expert_id should be an integer {row=} {rows=}",
)
self.assertTrue(
count.isdigit(), f"count should be an integer {row=} {rows=}"
)
if mode in ["per_pass", "per_token"]:
self.assertGreater(len(data), 0, "Should contain data rows")
else:
logical_count = data["logical_count"]
print(f"{logical_count.sum()=} {logical_count=}")
self.assertTrue(logical_count.sum() > 0)
finally:
kill_process_tree(process.pid)
finally:
kill_process_tree(process.pid)
if __name__ == "__main__":
......