Unverified Commit 01d2838c authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Fix stop_profile does not wait for finishing (#4741)

parent e3b8a722
......@@ -321,7 +321,8 @@ class Engine(EngineBase):
loop.run_until_complete(self.tokenizer_manager.start_profile())
def stop_profile(self):
self.tokenizer_manager.stop_profile()
loop = asyncio.get_event_loop()
loop.run_until_complete(self.tokenizer_manager.stop_profile())
def get_server_info(self):
loop = asyncio.get_event_loop()
......
......@@ -355,7 +355,7 @@ async def start_profile_async(obj: Optional[ProfileReqInput] = None):
@app.api_route("/stop_profile", methods=["GET", "POST"])
async def stop_profile_async():
"""Stop profiling."""
_global_state.tokenizer_manager.stop_profile()
await _global_state.tokenizer_manager.stop_profile()
return Response(
content="Stop profiling. This will take some time.\n",
status_code=200,
......
......@@ -1512,7 +1512,7 @@ class Scheduler(
self.profiler_target_forward_ct
and self.profiler_target_forward_ct <= self.forward_ct
):
self.stop_profile()
self.send_to_tokenizer.send_pyobj(self.stop_profile())
if self.forward_sleep_time is not None:
logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
......@@ -2114,7 +2114,10 @@ class Scheduler(
def stop_profile(self) -> None:
if self.profiler_activities is None:
return
return ProfileReqOutput(
success=False,
message="Profiling is not in progress. Call /start_profile first.",
)
logger.info("Stop profiling...")
if self.torch_profiler is not None:
......@@ -2145,10 +2148,7 @@ class Scheduler(
self.torch_profiler_output_dir = None
self.profiler_activities = None
if self.profiler_target_forward_ct:
self.send_to_tokenizer.send_pyobj(
ProfileReqOutput(success=True, message="Succeeded.")
)
return ProfileReqOutput(success=True, message="Succeeded")
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
if recv_req == ExpertDistributionReq.START_RECORD:
......
......@@ -295,7 +295,7 @@ class TokenizerManager:
self.flush_cache_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.start_profile_communicator = _Communicator(
self.profile_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
......@@ -360,7 +360,7 @@ class TokenizerManager:
),
(
ProfileReqOutput,
self.start_profile_communicator.handle_recv,
self.profile_communicator.handle_recv,
),
(
GetInternalStateReqOutput,
......@@ -801,7 +801,14 @@ class TokenizerManager:
record_shapes=record_shapes,
profile_id=str(time.time()),
)
result = (await self.start_profile_communicator(req))[0]
return await self._execute_profile(req)
async def stop_profile(self):
req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
return await self._execute_profile(req)
async def _execute_profile(self, req: ProfileReq):
result = (await self.profile_communicator(req))[0]
if not result.success:
raise RuntimeError(result.message)
return result
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment