"vscode:/vscode.git/clone" did not exist on "7cc2c95ad90f18f4e15005b49a3636d6ea3135b4"
Unverified commit eb934bdf, authored by fzyzcjy and committed by GitHub
Browse files

Fix test_expert_distribution failure (#4752)

parent e45ae444
......@@ -346,7 +346,7 @@ async def stop_profile_async():
@app.api_route("/start_expert_distribution_record", methods=["GET", "POST"])
async def start_expert_distribution_record_async():
"""Start recording the expert distribution. Clear the previous record if any."""
_global_state.tokenizer_manager.start_expert_distribution_record()
await _global_state.tokenizer_manager.start_expert_distribution_record()
return Response(
content="Start recording the expert distribution.\n",
status_code=200,
......@@ -356,7 +356,7 @@ async def start_expert_distribution_record_async():
@app.api_route("/stop_expert_distribution_record", methods=["GET", "POST"])
async def stop_expert_distribution_record_async():
"""Stop recording the expert distribution."""
_global_state.tokenizer_manager.stop_expert_distribution_record()
await _global_state.tokenizer_manager.stop_expert_distribution_record()
return Response(
content="Stop recording the expert distribution.\n",
status_code=200,
......@@ -366,7 +366,7 @@ async def stop_expert_distribution_record_async():
@app.api_route("/dump_expert_distribution_record", methods=["GET", "POST"])
async def dump_expert_distribution_record_async():
"""Dump expert distribution record."""
_global_state.tokenizer_manager.dump_expert_distribution_record()
await _global_state.tokenizer_manager.dump_expert_distribution_record()
return Response(
content="Dump expert distribution record.\n",
status_code=200,
......
......@@ -664,6 +664,11 @@ class ExpertDistributionReq(Enum):
DUMP_RECORD = 3
@dataclass
class ExpertDistributionReqOutput:
pass
@dataclass
class ProfileReq:
type: ProfileReqType
......
......@@ -57,6 +57,7 @@ from sglang.srt.managers.io_struct import (
AbortReq,
CloseSessionReqInput,
ExpertDistributionReq,
ExpertDistributionReqOutput,
FlushCacheReq,
GetInternalStateReq,
GetInternalStateReqOutput,
......@@ -1905,6 +1906,7 @@ class Scheduler(
expert_distribution_recorder.dump_record()
else:
raise ValueError("Unrecognized ExpertDistributionReq value")
return ExpertDistributionReqOutput()
def open_session(self, recv_req: OpenSessionReqInput):
# handle error
......
......@@ -61,6 +61,7 @@ from sglang.srt.managers.io_struct import (
ConfigureLoggingReq,
EmbeddingReqInput,
ExpertDistributionReq,
ExpertDistributionReqOutput,
FlushCacheReq,
GenerateReqInput,
GetInternalStateReq,
......@@ -264,6 +265,9 @@ class TokenizerManager:
self.get_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.expert_distribution_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self._result_dispatcher = TypeBasedDispatcher(
[
......@@ -313,6 +317,10 @@ class TokenizerManager:
GetInternalStateReqOutput,
self.get_internal_state_communicator.handle_recv,
),
(
ExpertDistributionReqOutput,
self.expert_distribution_communicator.handle_recv,
),
(HealthCheckOutput, lambda x: None),
]
)
......@@ -639,17 +647,14 @@ class TokenizerManager:
req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
self.send_to_scheduler.send_pyobj(req)
def start_expert_distribution_record(self):
req = ExpertDistributionReq.START_RECORD
self.send_to_scheduler.send_pyobj(req)
async def start_expert_distribution_record(self):
await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
def stop_expert_distribution_record(self):
req = ExpertDistributionReq.STOP_RECORD
self.send_to_scheduler.send_pyobj(req)
async def stop_expert_distribution_record(self):
await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
def dump_expert_distribution_record(self):
req = ExpertDistributionReq.DUMP_RECORD
self.send_to_scheduler.send_pyobj(req)
async def dump_expert_distribution_record(self):
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
async def update_weights_from_disk(
self,
......
......@@ -28,9 +28,13 @@ class TestExpertDistribution(unittest.TestCase):
def test_expert_distribution_record(self):
"""Test expert distribution record endpoints"""
process = popen_launch_server(
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST,
# The feature is only implemented in deepseek_v2.py
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
DEFAULT_URL_FOR_TEST,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
],
)
try:
......@@ -68,7 +72,9 @@ class TestExpertDistribution(unittest.TestCase):
# Verify the dumped file exists and has correct format
csv_files = glob.glob("expert_distribution_*.csv")
self.assertEqual(
len(csv_files), 1, "Expected exactly one expert distribution CSV file"
len(csv_files),
1,
f"Expected exactly one expert distribution CSV file {csv_files=}",
)
# Check CSV file format
......@@ -97,11 +103,17 @@ class TestExpertDistribution(unittest.TestCase):
# Verify data types
layer_id, expert_id, count = row
self.assertTrue(layer_id.isdigit(), "layer_id should be an integer")
self.assertTrue(
expert_id.isdigit(), "expert_id should be an integer"
layer_id.isdigit(),
f"layer_id should be an integer {row=} {rows=}",
)
self.assertTrue(
expert_id.isdigit(),
f"expert_id should be an integer {row=} {rows=}",
)
self.assertTrue(
count.isdigit(), f"count should be an integer {row=} {rows=}"
)
self.assertTrue(count.isdigit(), "count should be an integer")
finally:
kill_process_tree(process.pid)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment