Unverified Commit 199bb01d authored by yuhsaun-t's avatar yuhsaun-t Committed by GitHub
Browse files

Add endpoints to dump selected expert ids (#4435)


Co-authored-by: default avatarCheng Wan <54331508+ch-wan@users.noreply.github.com>
parent 6b7038ba
......@@ -17,6 +17,9 @@
"- `/update_weights`\n",
"- `/encode`(embedding model)\n",
"- `/classify`(reward model)\n",
"- `/start_expert_distribution_record`\n",
"- `/stop_expert_distribution_record`\n",
"- `/dump_expert_distribution_record`\n",
"\n",
"We mainly use `requests` to test these APIs in the following examples. You can also use `curl`."
]
......@@ -362,6 +365,67 @@
"terminate_process(reward_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Capture expert selection distribution in MoE models\n",
"\n",
"SGLang Runtime supports recording the number of times an expert is selected in a MoE model run for each expert in the model. This is useful when analyzing the throughput of the model and plan for optimization."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"expert_record_server_process, port = launch_server_cmd(\n",
" \"python -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(f\"http://localhost:{port}/start_expert_distribution_record\")\n",
"print_highlight(response)\n",
"\n",
"url = f\"http://localhost:{port}/generate\"\n",
"data = {\"text\": \"What is the capital of France?\"}\n",
"\n",
"response = requests.post(url, json=data)\n",
"print_highlight(response.json())\n",
"\n",
"response = requests.post(f\"http://localhost:{port}/stop_expert_distribution_record\")\n",
"print_highlight(response)\n",
"\n",
"response = requests.post(f\"http://localhost:{port}/dump_expert_distribution_record\")\n",
"print_highlight(response)\n",
"\n",
"import glob\n",
"\n",
"output_file = glob.glob(\"expert_distribution_*.csv\")[0]\n",
"with open(output_file, \"r\") as f:\n",
" print_highlight(\"Content of dumped record:\")\n",
" for line in f:\n",
" print_highlight(line.strip())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(expert_record_server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
......
......@@ -343,6 +343,36 @@ async def stop_profile_async():
)
@app.api_route("/start_expert_distribution_record", methods=["GET", "POST"])
async def start_expert_distribution_record_async():
"""Start recording the expert distribution. Clear the previous record if any."""
_global_state.tokenizer_manager.start_expert_distribution_record()
return Response(
content="Start recording the expert distribution.\n",
status_code=200,
)
@app.api_route("/stop_expert_distribution_record", methods=["GET", "POST"])
async def stop_expert_distribution_record_async():
"""Stop recording the expert distribution."""
_global_state.tokenizer_manager.stop_expert_distribution_record()
return Response(
content="Stop recording the expert distribution.\n",
status_code=200,
)
@app.api_route("/dump_expert_distribution_record", methods=["GET", "POST"])
async def dump_expert_distribution_record_async():
"""Dump expert distribution record."""
_global_state.tokenizer_manager.dump_expert_distribution_record()
return Response(
content="Dump expert distribution record.\n",
status_code=200,
)
@app.post("/update_weights_from_disk")
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
"""Update the weights from disk inplace without re-launching the server."""
......
......@@ -21,6 +21,10 @@ from sglang.srt.utils import get_compiler_backend, is_cuda
_is_cuda = is_cuda()
from sglang.srt.managers.utils import ExpertDistributionRecorder
expert_distribution_recorder = ExpertDistributionRecorder()
def fused_topk_native(
hidden_states: torch.Tensor,
......@@ -223,4 +227,6 @@ def select_experts(
renormalize=renormalize,
)
expert_distribution_recorder.record_new_token(topk_ids)
return topk_weights, topk_ids
......@@ -658,6 +658,12 @@ class ProfileReqType(Enum):
STOP_PROFILE = 2
class ExpertDistributionReq(Enum):
START_RECORD = 1
STOP_RECORD = 2
DUMP_RECORD = 3
@dataclass
class ProfileReq:
type: ProfileReqType
......
......@@ -56,6 +56,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
AbortReq,
CloseSessionReqInput,
ExpertDistributionReq,
FlushCacheReq,
GetInternalStateReq,
GetInternalStateReqOutput,
......@@ -104,7 +105,7 @@ from sglang.srt.managers.scheduler_output_processor_mixin import (
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.managers.utils import ExpertDistributionRecorder, validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
......@@ -128,6 +129,8 @@ from sglang.srt.utils import (
)
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
expert_distribution_recorder = ExpertDistributionRecorder()
logger = logging.getLogger(__name__)
# Test retract decode for debugging purposes
......@@ -403,6 +406,7 @@ class Scheduler(
(GetInternalStateReq, self.get_internal_state),
(SetInternalStateReq, self.set_internal_state),
(RpcReqInput, self.handle_rpc_request),
(ExpertDistributionReq, self.expert_distribution_handle),
]
)
......@@ -1892,6 +1896,16 @@ class Scheduler(
ProfileReqOutput(success=True, message="Succeeded.")
)
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
if recv_req == ExpertDistributionReq.START_RECORD:
expert_distribution_recorder.start_record()
elif recv_req == ExpertDistributionReq.STOP_RECORD:
expert_distribution_recorder.stop_record()
elif recv_req == ExpertDistributionReq.DUMP_RECORD:
expert_distribution_recorder.dump_record()
else:
raise ValueError("Unrecognized ExpertDistributionReq value")
def open_session(self, recv_req: OpenSessionReqInput):
# handle error
session_id = recv_req.session_id
......
......@@ -60,6 +60,7 @@ from sglang.srt.managers.io_struct import (
CloseSessionReqInput,
ConfigureLoggingReq,
EmbeddingReqInput,
ExpertDistributionReq,
FlushCacheReq,
GenerateReqInput,
GetInternalStateReq,
......@@ -638,6 +639,18 @@ class TokenizerManager:
req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
self.send_to_scheduler.send_pyobj(req)
def start_expert_distribution_record(self):
req = ExpertDistributionReq.START_RECORD
self.send_to_scheduler.send_pyobj(req)
def stop_expert_distribution_record(self):
req = ExpertDistributionReq.STOP_RECORD
self.send_to_scheduler.send_pyobj(req)
def dump_expert_distribution_record(self):
req = ExpertDistributionReq.DUMP_RECORD
self.send_to_scheduler.send_pyobj(req)
async def update_weights_from_disk(
self,
obj: UpdateWeightFromDiskReqInput,
......
import json
import logging
import time
from collections import defaultdict
from http import HTTPStatus
from typing import Optional
from typing import Dict, List, Optional, Tuple
import torch
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
......@@ -42,3 +47,75 @@ def validate_input_length(
return error_msg
return None
# global expert distribution recording
class ExpertDistributionRecorder:
# This class is a singleton class
def __new__(cls):
if not hasattr(cls, "instance"):
cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls)
return cls.instance
def __init__(self):
# the length of the dictionary is the number of layers
# the length of the list is the number of tokens
# the length of the tuple is topk's k value
self._expert_distribution_record: Dict[int, List[Tuple[int]]] = defaultdict(
list
)
self._record = False
self._current_layer_id = "UNKNOWN"
def set_current_layer(self, layer_idx):
self._current_layer_id = layer_idx
def record_new_token(self, topk_ids):
if not self._record:
return
topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
torch.cuda.synchronize()
for i in topk_ids_list:
self._expert_distribution_record[self._current_layer_id].append(tuple(i))
def reset(self):
"""Reset the expert distribution recorder."""
logger.info("Resetting expert distribution record...")
self._record = False
self._expert_distribution_record.clear()
self._current_layer_id = "UNKNOWN"
def start_record(self):
"""Start recording the expert distribution. Reset the recorder and set the recording flag to True."""
if self._record == True:
logger.warning(
"SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
)
self.reset()
self._record = True
def stop_record(self):
"""Stop recording the expert distribution. Set the recording flag to False."""
if self._record == False:
logger.warning(
"SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?"
)
self._record = False
def dump_record(self):
"""Dump the expert distribution record to a file. Reset the recorder after dumping."""
results = {}
for layer_idx, layer_record in self._expert_distribution_record.items():
results[layer_idx] = defaultdict(int)
for token_record in layer_record:
for expert_idx in token_record:
results[layer_idx][expert_idx] += 1
with open(
f"expert_distribution_rank{torch.distributed.get_rank()}_timestamp{time.time()}.csv",
"w",
) as fd:
fd.write("layer_id,expert_id,count\n")
for layer_idx, layer_results in results.items():
for expert_idx, count in layer_results.items():
fd.write(f"{layer_idx},{expert_idx},{count}\n")
self.reset()
......@@ -68,6 +68,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.managers.utils import ExpertDistributionRecorder
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip
......@@ -80,6 +81,8 @@ if _is_cuda:
else:
from vllm import _custom_ops as ops
expert_distribution_recorder = ExpertDistributionRecorder()
class DeepseekV2MLP(nn.Module):
def __init__(
......@@ -1160,6 +1163,7 @@ class DeepseekV2Model(nn.Module):
residual = None
for i in range(len(self.layers)):
expert_distribution_recorder.set_current_layer(i)
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, forward_batch, residual
......
......@@ -30,6 +30,7 @@ suites = {
TestFile("test_ebnf_constrained.py"),
TestFile("test_fp8_kernel.py", 2),
TestFile("test_embedding_openai_server.py", 36),
TestFile("test_expert_distribution.py", 31),
TestFile("test_gguf.py", 78),
TestFile("test_gptqmodel_dynamic.py", 72),
TestFile("test_hidden_states.py", 55),
......
import csv
import glob
import os
import unittest
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
class TestExpertDistribution(unittest.TestCase):
def setUp(self):
# Clean up any existing expert distribution files before each test
for f in glob.glob("expert_distribution_*.csv"):
os.remove(f)
def tearDown(self):
# Clean up any expert distribution files after each test
for f in glob.glob("expert_distribution_*.csv"):
os.remove(f)
def test_expert_distribution_record(self):
"""Test expert distribution record endpoints"""
process = popen_launch_server(
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
)
try:
# Start recording
response = requests.post(
f"{DEFAULT_URL_FOR_TEST}/start_expert_distribution_record"
)
self.assertEqual(response.status_code, 200)
# Make some requests to generate expert distribution data
response = requests.post(
f"{DEFAULT_URL_FOR_TEST}/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 32,
},
},
)
self.assertEqual(response.status_code, 200)
# Stop recording
response = requests.post(
f"{DEFAULT_URL_FOR_TEST}/stop_expert_distribution_record"
)
self.assertEqual(response.status_code, 200)
# Dump the recorded data
response = requests.post(
f"{DEFAULT_URL_FOR_TEST}/dump_expert_distribution_record"
)
self.assertEqual(response.status_code, 200)
# Verify the dumped file exists and has correct format
csv_files = glob.glob("expert_distribution_*.csv")
self.assertEqual(
len(csv_files), 1, "Expected exactly one expert distribution CSV file"
)
# Check CSV file format
with open(csv_files[0], "r") as f:
csv_reader = csv.reader(f)
# Check header
header = next(csv_reader)
self.assertEqual(
header,
["layer_id", "expert_id", "count"],
"CSV header should be 'layer_id,expert_id,count'",
)
# Check data rows
rows = list(csv_reader)
self.assertGreater(len(rows), 0, "CSV file should contain data rows")
for row in rows:
# Verify each row has 3 columns
self.assertEqual(
len(row),
3,
"Each row should have layer_id, expert_id and count",
)
# Verify data types
layer_id, expert_id, count = row
self.assertTrue(layer_id.isdigit(), "layer_id should be an integer")
self.assertTrue(
expert_id.isdigit(), "expert_id should be an integer"
)
self.assertTrue(count.isdigit(), "count should be an integer")
finally:
kill_process_tree(process.pid)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment