Unverified commit 199bb01d authored by yuhsaun-t, committed by GitHub

Add endpoints to dump selected expert ids (#4435)


Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
parent 6b7038ba
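
For quick reference, a minimal usage sketch of the new endpoints (the base URL and prompt are illustrative; any running SGLang server serving a MoE model works):

import requests

base = "http://localhost:30000"  # illustrative; use your server's address

# Begin recording expert selections (clears any previous record).
requests.post(f"{base}/start_expert_distribution_record")

# Send some traffic so experts actually get selected.
requests.post(f"{base}/generate", json={"text": "What is the capital of France?"})

# Stop recording, then dump the per-layer counts to a CSV on the server side.
requests.post(f"{base}/stop_expert_distribution_record")
requests.post(f"{base}/dump_expert_distribution_record")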
...@@ -17,6 +17,9 @@
"- `/update_weights`\n", "- `/update_weights`\n",
"- `/encode`(embedding model)\n", "- `/encode`(embedding model)\n",
"- `/classify`(reward model)\n", "- `/classify`(reward model)\n",
"- `/start_expert_distribution_record`\n",
"- `/stop_expert_distribution_record`\n",
"- `/dump_expert_distribution_record`\n",
"\n", "\n",
"We mainly use `requests` to test these APIs in the following examples. You can also use `curl`." "We mainly use `requests` to test these APIs in the following examples. You can also use `curl`."
] ]
...@@ -362,6 +365,67 @@
"terminate_process(reward_process)" "terminate_process(reward_process)"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Capture expert selection distribution in MoE models\n",
"\n",
"SGLang Runtime supports recording the number of times an expert is selected in a MoE model run for each expert in the model. This is useful when analyzing the throughput of the model and plan for optimization."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"expert_record_server_process, port = launch_server_cmd(\n",
" \"python -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(f\"http://localhost:{port}/start_expert_distribution_record\")\n",
"print_highlight(response)\n",
"\n",
"url = f\"http://localhost:{port}/generate\"\n",
"data = {\"text\": \"What is the capital of France?\"}\n",
"\n",
"response = requests.post(url, json=data)\n",
"print_highlight(response.json())\n",
"\n",
"response = requests.post(f\"http://localhost:{port}/stop_expert_distribution_record\")\n",
"print_highlight(response)\n",
"\n",
"response = requests.post(f\"http://localhost:{port}/dump_expert_distribution_record\")\n",
"print_highlight(response)\n",
"\n",
"import glob\n",
"\n",
"output_file = glob.glob(\"expert_distribution_*.csv\")[0]\n",
"with open(output_file, \"r\") as f:\n",
" print_highlight(\"Content of dumped record:\")\n",
" for line in f:\n",
" print_highlight(line.strip())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(expert_record_server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
...
...@@ -343,6 +343,36 @@ async def stop_profile_async():
    )
@app.api_route("/start_expert_distribution_record", methods=["GET", "POST"])
async def start_expert_distribution_record_async():
"""Start recording the expert distribution. Clear the previous record if any."""
_global_state.tokenizer_manager.start_expert_distribution_record()
return Response(
content="Start recording the expert distribution.\n",
status_code=200,
)
@app.api_route("/stop_expert_distribution_record", methods=["GET", "POST"])
async def stop_expert_distribution_record_async():
"""Stop recording the expert distribution."""
_global_state.tokenizer_manager.stop_expert_distribution_record()
return Response(
content="Stop recording the expert distribution.\n",
status_code=200,
)
@app.api_route("/dump_expert_distribution_record", methods=["GET", "POST"])
async def dump_expert_distribution_record_async():
"""Dump expert distribution record."""
_global_state.tokenizer_manager.dump_expert_distribution_record()
return Response(
content="Dump expert distribution record.\n",
status_code=200,
)
@app.post("/update_weights_from_disk") @app.post("/update_weights_from_disk")
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request): async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
"""Update the weights from disk inplace without re-launching the server.""" """Update the weights from disk inplace without re-launching the server."""
......
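
The three routes above are registered with methods=["GET", "POST"], so either verb works; a quick sketch (the port is illustrative):

import requests

# GET is accepted as well as POST for these recording endpoints.
requests.get("http://localhost:30000/start_expert_distribution_record")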
...@@ -21,6 +21,10 @@ from sglang.srt.utils import get_compiler_backend, is_cuda

_is_cuda = is_cuda()

from sglang.srt.managers.utils import ExpertDistributionRecorder

expert_distribution_recorder = ExpertDistributionRecorder()


def fused_topk_native(
    hidden_states: torch.Tensor,
...@@ -223,4 +227,6 @@ def select_experts(
        renormalize=renormalize,
    )

    expert_distribution_recorder.record_new_token(topk_ids)

    return topk_weights, topk_ids
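
`record_new_token` is invoked on every routing step inside `select_experts`, so each processed token contributes one tuple of k expert ids at the current layer. A small sketch of the recorded data (the tensor values are made up for illustration):

import torch

# topk_ids has shape [num_tokens, k]: row i holds the k expert ids
# selected for token i at the current MoE layer. For example, with k = 2:
topk_ids = torch.tensor([[0, 5], [3, 5], [0, 7]])

# dump_record() later tallies occurrences per expert, so the rows above
# would contribute counts {0: 2, 3: 1, 5: 2, 7: 1} for this layer.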
...@@ -658,6 +658,12 @@ class ProfileReqType(Enum):
    STOP_PROFILE = 2


class ExpertDistributionReq(Enum):
    START_RECORD = 1
    STOP_RECORD = 2
    DUMP_RECORD = 3


@dataclass
class ProfileReq:
    type: ProfileReqType
...
...@@ -56,6 +56,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
    CloseSessionReqInput,
    ExpertDistributionReq,
    FlushCacheReq,
    GetInternalStateReq,
    GetInternalStateReqOutput,
...@@ -104,7 +105,7 @@ from sglang.srt.managers.scheduler_output_processor_mixin import (
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.managers.utils import ExpertDistributionRecorder, validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
...@@ -128,6 +129,8 @@ from sglang.srt.utils import (
)
from sglang.utils import TypeBasedDispatcher, get_exception_traceback

expert_distribution_recorder = ExpertDistributionRecorder()

logger = logging.getLogger(__name__)
...@@ -403,6 +406,7 @@ class Scheduler(
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
            ]
        )
...@@ -1892,6 +1896,16 @@ class Scheduler(
            ProfileReqOutput(success=True, message="Succeeded.")
        )

    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        if recv_req == ExpertDistributionReq.START_RECORD:
            expert_distribution_recorder.start_record()
        elif recv_req == ExpertDistributionReq.STOP_RECORD:
            expert_distribution_recorder.stop_record()
        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
            expert_distribution_recorder.dump_record()
        else:
            raise ValueError("Unrecognized ExpertDistributionReq value")

    def open_session(self, recv_req: OpenSessionReqInput):
        # handle error
        session_id = recv_req.session_id
...
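
The scheduler receives the bare `ExpertDistributionReq` enum member over ZMQ and routes it by type. For readers unfamiliar with `TypeBasedDispatcher`, the idea is roughly the following (a simplified sketch of the pattern, not the actual `sglang.utils` implementation):

from typing import Any, Callable, List, Tuple, Type

class SimpleTypeDispatcher:
    """Map an incoming object to a handler based on the object's type."""

    def __init__(self, mapping: List[Tuple[Type, Callable]]):
        self._mapping = mapping

    def __call__(self, obj: Any):
        for ty, handler in self._mapping:
            if isinstance(obj, ty):
                return handler(obj)
        raise ValueError(f"No handler registered for {type(obj)}")

Since an Enum member is an instance of its Enum class, sending ExpertDistributionReq.START_RECORD routes to expert_distribution_handle.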
...@@ -60,6 +60,7 @@ from sglang.srt.managers.io_struct import (
    CloseSessionReqInput,
    ConfigureLoggingReq,
    EmbeddingReqInput,
    ExpertDistributionReq,
    FlushCacheReq,
    GenerateReqInput,
    GetInternalStateReq,
...@@ -638,6 +639,18 @@ class TokenizerManager:
        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
        self.send_to_scheduler.send_pyobj(req)

    def start_expert_distribution_record(self):
        req = ExpertDistributionReq.START_RECORD
        self.send_to_scheduler.send_pyobj(req)

    def stop_expert_distribution_record(self):
        req = ExpertDistributionReq.STOP_RECORD
        self.send_to_scheduler.send_pyobj(req)

    def dump_expert_distribution_record(self):
        req = ExpertDistributionReq.DUMP_RECORD
        self.send_to_scheduler.send_pyobj(req)

    async def update_weights_from_disk(
        self,
        obj: UpdateWeightFromDiskReqInput,
...
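
Note that these three methods are fire-and-forget: `send_pyobj` queues the request for the scheduler, and the HTTP response returns before the CSV is actually written. A client that needs the file can poll for it; a hedged sketch (the helper name and timeout are illustrative):

import glob
import time

def wait_for_dump(pattern="expert_distribution_*.csv", timeout=10.0):
    # The dump endpoint returns before the scheduler writes the file,
    # so poll the server's working directory until the CSV appears.
    deadline = time.time() + timeout
    while time.time() < deadline:
        files = glob.glob(pattern)
        if files:
            return files[0]
        time.sleep(0.1)
    raise TimeoutError("Expert distribution CSV was not written in time")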
import json
import logging
import time
from collections import defaultdict
from http import HTTPStatus
from typing import Dict, List, Optional, Tuple

import torch

from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
...@@ -42,3 +47,75 @@ def validate_input_length(
        return error_msg

    return None

# global expert distribution recording
class ExpertDistributionRecorder:
    # This class is a singleton: every call to ExpertDistributionRecorder()
    # returns the same instance.
    def __new__(cls):
        if not hasattr(cls, "instance"):
            cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        # Guard against re-initialization: because __new__ always returns the
        # same instance, a second ExpertDistributionRecorder() call would
        # otherwise re-run __init__ and wipe any record collected so far.
        if hasattr(self, "_expert_distribution_record"):
            return
        # The dict maps a layer id to that layer's records; each list holds
        # one tuple per recorded token; each tuple holds the token's top-k
        # expert ids, so its length is topk's k value.
        self._expert_distribution_record: Dict[int, List[Tuple[int, ...]]] = (
            defaultdict(list)
        )
        self._record = False
        self._current_layer_id = "UNKNOWN"

    def set_current_layer(self, layer_idx):
        self._current_layer_id = layer_idx

    def record_new_token(self, topk_ids):
        if not self._record:
            return
        topk_ids_cpu = topk_ids.to("cpu", non_blocking=True)
        # Wait for the asynchronous device-to-host copy to finish before
        # reading the CPU tensor.
        torch.cuda.synchronize()
        for i in topk_ids_cpu.numpy().tolist():
            self._expert_distribution_record[self._current_layer_id].append(tuple(i))

    def reset(self):
        """Reset the expert distribution recorder."""
        logger.info("Resetting expert distribution record...")
        self._record = False
        self._expert_distribution_record.clear()
        self._current_layer_id = "UNKNOWN"

    def start_record(self):
        """Start recording the expert distribution. Reset the recorder and set the recording flag to True."""
        if self._record:
            logger.warning(
                "SGLang server is already recording expert ids. Did you forget to "
                "dump the expert ids recorded so far by sending requests to the "
                "`/stop_expert_distribution_record` and "
                "`/dump_expert_distribution_record` endpoints?"
            )
        self.reset()
        self._record = True

    def stop_record(self):
        """Stop recording the expert distribution. Set the recording flag to False."""
        if not self._record:
            logger.warning(
                "SGLang server has not been recording expert ids. Did you forget to "
                "start recording by sending a request to the "
                "`/start_expert_distribution_record` endpoint?"
            )
        self._record = False

    def dump_record(self):
        """Dump the expert distribution record to a file. Reset the recorder after dumping."""
        results = {}
        for layer_idx, layer_record in self._expert_distribution_record.items():
            results[layer_idx] = defaultdict(int)
            for token_record in layer_record:
                for expert_idx in token_record:
                    results[layer_idx][expert_idx] += 1
        with open(
            f"expert_distribution_rank{torch.distributed.get_rank()}_timestamp{time.time()}.csv",
            "w",
        ) as fd:
            fd.write("layer_id,expert_id,count\n")
            for layer_idx, layer_results in results.items():
                for expert_idx, count in layer_results.items():
                    fd.write(f"{layer_idx},{expert_idx},{count}\n")
        self.reset()
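
Once dumped, the CSV is straightforward to post-process. A sketch of loading it back and measuring per-layer load imbalance (the helper names are illustrative, not part of this commit):

import csv
from collections import defaultdict

def load_expert_counts(path):
    # Parse the dumped CSV back into {layer_id: {expert_id: count}}.
    counts = defaultdict(dict)
    with open(path) as f:
        for row in csv.DictReader(f):
            counts[int(row["layer_id"])][int(row["expert_id"])] = int(row["count"])
    return counts

def imbalance(layer_counts):
    # Ratio of the hottest expert's count to the mean count; 1.0 means the
    # selected experts were hit perfectly evenly.
    values = list(layer_counts.values())
    return max(values) / (sum(values) / len(values))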
...@@ -68,6 +68,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.managers.utils import ExpertDistributionRecorder
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip
...@@ -80,6 +81,8 @@ if _is_cuda:
else:
    from vllm import _custom_ops as ops

expert_distribution_recorder = ExpertDistributionRecorder()


class DeepseekV2MLP(nn.Module):
    def __init__(
...@@ -1160,6 +1163,7 @@ class DeepseekV2Model(nn.Module):
        residual = None
        for i in range(len(self.layers)):
            expert_distribution_recorder.set_current_layer(i)
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions, hidden_states, forward_batch, residual
...
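
Only DeepseekV2 is instrumented in this commit, but the pattern generalizes: report the layer index before each layer runs so that `record_new_token` attributes expert ids to the right layer. A sketch for a hypothetical MoE model's forward loop:

# Inside a hypothetical MoE model's forward():
for i, layer in enumerate(self.layers):
    # Tell the recorder which layer the upcoming top-k selections belong to.
    expert_distribution_recorder.set_current_layer(i)
    hidden_states, residual = layer(positions, hidden_states, forward_batch, residual)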
...@@ -30,6 +30,7 @@ suites = {
    TestFile("test_ebnf_constrained.py"),
    TestFile("test_fp8_kernel.py", 2),
    TestFile("test_embedding_openai_server.py", 36),
    TestFile("test_expert_distribution.py", 31),
    TestFile("test_gguf.py", 78),
    TestFile("test_gptqmodel_dynamic.py", 72),
    TestFile("test_hidden_states.py", 55),
...
import csv
import glob
import os
import unittest

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestExpertDistribution(unittest.TestCase):
    def setUp(self):
        # Clean up any existing expert distribution files before each test
        for f in glob.glob("expert_distribution_*.csv"):
            os.remove(f)

    def tearDown(self):
        # Clean up any expert distribution files after each test
        for f in glob.glob("expert_distribution_*.csv"):
            os.remove(f)

    def test_expert_distribution_record(self):
        """Test expert distribution record endpoints"""
        process = popen_launch_server(
            DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST,
            DEFAULT_URL_FOR_TEST,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        )

        try:
            # Start recording
            response = requests.post(
                f"{DEFAULT_URL_FOR_TEST}/start_expert_distribution_record"
            )
            self.assertEqual(response.status_code, 200)

            # Make some requests to generate expert distribution data
            response = requests.post(
                f"{DEFAULT_URL_FOR_TEST}/generate",
                json={
                    "text": "The capital of France is",
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 32,
                    },
                },
            )
            self.assertEqual(response.status_code, 200)

            # Stop recording
            response = requests.post(
                f"{DEFAULT_URL_FOR_TEST}/stop_expert_distribution_record"
            )
            self.assertEqual(response.status_code, 200)

            # Dump the recorded data
            response = requests.post(
                f"{DEFAULT_URL_FOR_TEST}/dump_expert_distribution_record"
            )
            self.assertEqual(response.status_code, 200)

            # Verify the dumped file exists and has correct format
            csv_files = glob.glob("expert_distribution_*.csv")
            self.assertEqual(
                len(csv_files), 1, "Expected exactly one expert distribution CSV file"
            )

            # Check CSV file format
            with open(csv_files[0], "r") as f:
                csv_reader = csv.reader(f)

                # Check header
                header = next(csv_reader)
                self.assertEqual(
                    header,
                    ["layer_id", "expert_id", "count"],
                    "CSV header should be 'layer_id,expert_id,count'",
                )

                # Check data rows
                rows = list(csv_reader)
                self.assertGreater(len(rows), 0, "CSV file should contain data rows")

                for row in rows:
                    # Verify each row has 3 columns
                    self.assertEqual(
                        len(row),
                        3,
                        "Each row should have layer_id, expert_id and count",
                    )

                    # Verify data types
                    layer_id, expert_id, count = row
                    self.assertTrue(layer_id.isdigit(), "layer_id should be an integer")
                    self.assertTrue(
                        expert_id.isdigit(), "expert_id should be an integer"
                    )
                    self.assertTrue(count.isdigit(), "count should be an integer")
        finally:
            kill_process_tree(process.pid)


if __name__ == "__main__":
    unittest.main()