Unverified commit 01d2838c, authored by fzyzcjy, committed by GitHub
Browse files

Fix stop_profile not waiting for profiling to finish (#4741)

parent e3b8a722
...@@ -321,7 +321,8 @@ class Engine(EngineBase): ...@@ -321,7 +321,8 @@ class Engine(EngineBase):
loop.run_until_complete(self.tokenizer_manager.start_profile()) loop.run_until_complete(self.tokenizer_manager.start_profile())
def stop_profile(self): def stop_profile(self):
self.tokenizer_manager.stop_profile() loop = asyncio.get_event_loop()
loop.run_until_complete(self.tokenizer_manager.stop_profile())
def get_server_info(self): def get_server_info(self):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
......
...@@ -355,7 +355,7 @@ async def start_profile_async(obj: Optional[ProfileReqInput] = None): ...@@ -355,7 +355,7 @@ async def start_profile_async(obj: Optional[ProfileReqInput] = None):
@app.api_route("/stop_profile", methods=["GET", "POST"]) @app.api_route("/stop_profile", methods=["GET", "POST"])
async def stop_profile_async(): async def stop_profile_async():
"""Stop profiling.""" """Stop profiling."""
_global_state.tokenizer_manager.stop_profile() await _global_state.tokenizer_manager.stop_profile()
return Response( return Response(
content="Stop profiling. This will take some time.\n", content="Stop profiling. This will take some time.\n",
status_code=200, status_code=200,
......
...@@ -1512,7 +1512,7 @@ class Scheduler( ...@@ -1512,7 +1512,7 @@ class Scheduler(
self.profiler_target_forward_ct self.profiler_target_forward_ct
and self.profiler_target_forward_ct <= self.forward_ct and self.profiler_target_forward_ct <= self.forward_ct
): ):
self.stop_profile() self.send_to_tokenizer.send_pyobj(self.stop_profile())
if self.forward_sleep_time is not None: if self.forward_sleep_time is not None:
logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s") logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
...@@ -2114,7 +2114,10 @@ class Scheduler( ...@@ -2114,7 +2114,10 @@ class Scheduler(
def stop_profile(self) -> None: def stop_profile(self) -> None:
if self.profiler_activities is None: if self.profiler_activities is None:
return return ProfileReqOutput(
success=False,
message="Profiling is not in progress. Call /start_profile first.",
)
logger.info("Stop profiling...") logger.info("Stop profiling...")
if self.torch_profiler is not None: if self.torch_profiler is not None:
...@@ -2145,10 +2148,7 @@ class Scheduler( ...@@ -2145,10 +2148,7 @@ class Scheduler(
self.torch_profiler_output_dir = None self.torch_profiler_output_dir = None
self.profiler_activities = None self.profiler_activities = None
if self.profiler_target_forward_ct: return ProfileReqOutput(success=True, message="Succeeded")
self.send_to_tokenizer.send_pyobj(
ProfileReqOutput(success=True, message="Succeeded.")
)
def expert_distribution_handle(self, recv_req: ExpertDistributionReq): def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
if recv_req == ExpertDistributionReq.START_RECORD: if recv_req == ExpertDistributionReq.START_RECORD:
......
...@@ -295,7 +295,7 @@ class TokenizerManager: ...@@ -295,7 +295,7 @@ class TokenizerManager:
self.flush_cache_communicator = _Communicator( self.flush_cache_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size self.send_to_scheduler, server_args.dp_size
) )
self.start_profile_communicator = _Communicator( self.profile_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size self.send_to_scheduler, server_args.dp_size
) )
self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1) self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
...@@ -360,7 +360,7 @@ class TokenizerManager: ...@@ -360,7 +360,7 @@ class TokenizerManager:
), ),
( (
ProfileReqOutput, ProfileReqOutput,
self.start_profile_communicator.handle_recv, self.profile_communicator.handle_recv,
), ),
( (
GetInternalStateReqOutput, GetInternalStateReqOutput,
...@@ -801,7 +801,14 @@ class TokenizerManager: ...@@ -801,7 +801,14 @@ class TokenizerManager:
record_shapes=record_shapes, record_shapes=record_shapes,
profile_id=str(time.time()), profile_id=str(time.time()),
) )
result = (await self.start_profile_communicator(req))[0] return await self._execute_profile(req)
async def stop_profile(self):
req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
return await self._execute_profile(req)
async def _execute_profile(self, req: ProfileReq):
result = (await self.profile_communicator(req))[0]
if not result.success: if not result.success:
raise RuntimeError(result.message) raise RuntimeError(result.message)
return result return result
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment