"docs/source/vscode:/vscode.git/clone" did not exist on "d6eecf90a1fc258de3c494209ea89141c2f4bfbe"
Unverified Commit eb934bdf authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Fix test_expert_distribution failure (#4752)

parent e45ae444
...@@ -346,7 +346,7 @@ async def stop_profile_async(): ...@@ -346,7 +346,7 @@ async def stop_profile_async():
@app.api_route("/start_expert_distribution_record", methods=["GET", "POST"]) @app.api_route("/start_expert_distribution_record", methods=["GET", "POST"])
async def start_expert_distribution_record_async(): async def start_expert_distribution_record_async():
"""Start recording the expert distribution. Clear the previous record if any.""" """Start recording the expert distribution. Clear the previous record if any."""
_global_state.tokenizer_manager.start_expert_distribution_record() await _global_state.tokenizer_manager.start_expert_distribution_record()
return Response( return Response(
content="Start recording the expert distribution.\n", content="Start recording the expert distribution.\n",
status_code=200, status_code=200,
...@@ -356,7 +356,7 @@ async def start_expert_distribution_record_async(): ...@@ -356,7 +356,7 @@ async def start_expert_distribution_record_async():
@app.api_route("/stop_expert_distribution_record", methods=["GET", "POST"]) @app.api_route("/stop_expert_distribution_record", methods=["GET", "POST"])
async def stop_expert_distribution_record_async(): async def stop_expert_distribution_record_async():
"""Stop recording the expert distribution.""" """Stop recording the expert distribution."""
_global_state.tokenizer_manager.stop_expert_distribution_record() await _global_state.tokenizer_manager.stop_expert_distribution_record()
return Response( return Response(
content="Stop recording the expert distribution.\n", content="Stop recording the expert distribution.\n",
status_code=200, status_code=200,
...@@ -366,7 +366,7 @@ async def stop_expert_distribution_record_async(): ...@@ -366,7 +366,7 @@ async def stop_expert_distribution_record_async():
@app.api_route("/dump_expert_distribution_record", methods=["GET", "POST"]) @app.api_route("/dump_expert_distribution_record", methods=["GET", "POST"])
async def dump_expert_distribution_record_async(): async def dump_expert_distribution_record_async():
"""Dump expert distribution record.""" """Dump expert distribution record."""
_global_state.tokenizer_manager.dump_expert_distribution_record() await _global_state.tokenizer_manager.dump_expert_distribution_record()
return Response( return Response(
content="Dump expert distribution record.\n", content="Dump expert distribution record.\n",
status_code=200, status_code=200,
......
...@@ -664,6 +664,11 @@ class ExpertDistributionReq(Enum): ...@@ -664,6 +664,11 @@ class ExpertDistributionReq(Enum):
DUMP_RECORD = 3 DUMP_RECORD = 3
@dataclass
class ExpertDistributionReqOutput:
pass
@dataclass @dataclass
class ProfileReq: class ProfileReq:
type: ProfileReqType type: ProfileReqType
......
...@@ -57,6 +57,7 @@ from sglang.srt.managers.io_struct import ( ...@@ -57,6 +57,7 @@ from sglang.srt.managers.io_struct import (
AbortReq, AbortReq,
CloseSessionReqInput, CloseSessionReqInput,
ExpertDistributionReq, ExpertDistributionReq,
ExpertDistributionReqOutput,
FlushCacheReq, FlushCacheReq,
GetInternalStateReq, GetInternalStateReq,
GetInternalStateReqOutput, GetInternalStateReqOutput,
...@@ -1905,6 +1906,7 @@ class Scheduler( ...@@ -1905,6 +1906,7 @@ class Scheduler(
expert_distribution_recorder.dump_record() expert_distribution_recorder.dump_record()
else: else:
raise ValueError("Unrecognized ExpertDistributionReq value") raise ValueError("Unrecognized ExpertDistributionReq value")
return ExpertDistributionReqOutput()
def open_session(self, recv_req: OpenSessionReqInput): def open_session(self, recv_req: OpenSessionReqInput):
# handle error # handle error
......
...@@ -61,6 +61,7 @@ from sglang.srt.managers.io_struct import ( ...@@ -61,6 +61,7 @@ from sglang.srt.managers.io_struct import (
ConfigureLoggingReq, ConfigureLoggingReq,
EmbeddingReqInput, EmbeddingReqInput,
ExpertDistributionReq, ExpertDistributionReq,
ExpertDistributionReqOutput,
FlushCacheReq, FlushCacheReq,
GenerateReqInput, GenerateReqInput,
GetInternalStateReq, GetInternalStateReq,
...@@ -264,6 +265,9 @@ class TokenizerManager: ...@@ -264,6 +265,9 @@ class TokenizerManager:
self.get_internal_state_communicator = _Communicator( self.get_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size self.send_to_scheduler, server_args.dp_size
) )
self.expert_distribution_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self._result_dispatcher = TypeBasedDispatcher( self._result_dispatcher = TypeBasedDispatcher(
[ [
...@@ -313,6 +317,10 @@ class TokenizerManager: ...@@ -313,6 +317,10 @@ class TokenizerManager:
GetInternalStateReqOutput, GetInternalStateReqOutput,
self.get_internal_state_communicator.handle_recv, self.get_internal_state_communicator.handle_recv,
), ),
(
ExpertDistributionReqOutput,
self.expert_distribution_communicator.handle_recv,
),
(HealthCheckOutput, lambda x: None), (HealthCheckOutput, lambda x: None),
] ]
) )
...@@ -639,17 +647,14 @@ class TokenizerManager: ...@@ -639,17 +647,14 @@ class TokenizerManager:
req = ProfileReq(type=ProfileReqType.STOP_PROFILE) req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
self.send_to_scheduler.send_pyobj(req) self.send_to_scheduler.send_pyobj(req)
def start_expert_distribution_record(self): async def start_expert_distribution_record(self):
req = ExpertDistributionReq.START_RECORD await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
self.send_to_scheduler.send_pyobj(req)
def stop_expert_distribution_record(self): async def stop_expert_distribution_record(self):
req = ExpertDistributionReq.STOP_RECORD await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
self.send_to_scheduler.send_pyobj(req)
def dump_expert_distribution_record(self): async def dump_expert_distribution_record(self):
req = ExpertDistributionReq.DUMP_RECORD await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
self.send_to_scheduler.send_pyobj(req)
async def update_weights_from_disk( async def update_weights_from_disk(
self, self,
......
...@@ -28,9 +28,13 @@ class TestExpertDistribution(unittest.TestCase): ...@@ -28,9 +28,13 @@ class TestExpertDistribution(unittest.TestCase):
def test_expert_distribution_record(self): def test_expert_distribution_record(self):
"""Test expert distribution record endpoints""" """Test expert distribution record endpoints"""
process = popen_launch_server( process = popen_launch_server(
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST, # The feature is only implemented in deepseek_v2.py
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
],
) )
try: try:
...@@ -68,7 +72,9 @@ class TestExpertDistribution(unittest.TestCase): ...@@ -68,7 +72,9 @@ class TestExpertDistribution(unittest.TestCase):
# Verify the dumped file exists and has correct format # Verify the dumped file exists and has correct format
csv_files = glob.glob("expert_distribution_*.csv") csv_files = glob.glob("expert_distribution_*.csv")
self.assertEqual( self.assertEqual(
len(csv_files), 1, "Expected exactly one expert distribution CSV file" len(csv_files),
1,
f"Expected exactly one expert distribution CSV file {csv_files=}",
) )
# Check CSV file format # Check CSV file format
...@@ -97,11 +103,17 @@ class TestExpertDistribution(unittest.TestCase): ...@@ -97,11 +103,17 @@ class TestExpertDistribution(unittest.TestCase):
# Verify data types # Verify data types
layer_id, expert_id, count = row layer_id, expert_id, count = row
self.assertTrue(layer_id.isdigit(), "layer_id should be an integer")
self.assertTrue( self.assertTrue(
expert_id.isdigit(), "expert_id should be an integer" layer_id.isdigit(),
f"layer_id should be an integer {row=} {rows=}",
)
self.assertTrue(
expert_id.isdigit(),
f"expert_id should be an integer {row=} {rows=}",
)
self.assertTrue(
count.isdigit(), f"count should be an integer {row=} {rows=}"
) )
self.assertTrue(count.isdigit(), "count should be an integer")
finally: finally:
kill_process_tree(process.pid) kill_process_tree(process.pid)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment