"git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "edf394f1d47293b6e480bd415250057f267d874b"
Unverified Commit af46f299 authored by Zilin Zhu's avatar Zilin Zhu Committed by GitHub
Browse files

[RL] add pause and continue generation for async rl training (#7419)

parent 16a6b1d8
...@@ -712,6 +712,26 @@ async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Re ...@@ -712,6 +712,26 @@ async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Re
return ORJSONResponse(content=response_data, status_code=200) return ORJSONResponse(content=response_data, status_code=200)
@app.post("/pause_generation")
async def pause_generation(request: Request):
"""Pause generation."""
await _global_state.tokenizer_manager.pause_generation()
return ORJSONResponse(
content={"message": "Generation paused successfully.", "status": "ok"},
status_code=200,
)
@app.post("/continue_generation")
async def continue_generation(request: Request):
"""Continue generation."""
await _global_state.tokenizer_manager.continue_generation()
return ORJSONResponse(
content={"message": "Generation continued successfully.", "status": "ok"},
status_code=200,
)
##### OpenAI-compatible API endpoints #####
......
...@@ -203,6 +203,8 @@ class TokenizerManager: ...@@ -203,6 +203,8 @@ class TokenizerManager:
self.is_image_gen = self.model_config.is_image_gen self.is_image_gen = self.model_config.is_image_gen
self.context_len = self.model_config.context_len self.context_len = self.model_config.context_len
self.image_token_id = self.model_config.image_token_id self.image_token_id = self.model_config.image_token_id
self._updating = False
self._cond = asyncio.Condition()
if self.model_config.is_multimodal: if self.model_config.is_multimodal:
import_processors() import_processors()
...@@ -421,6 +423,9 @@ class TokenizerManager: ...@@ -421,6 +423,9 @@ class TokenizerManager:
request: Optional[fastapi.Request] = None, request: Optional[fastapi.Request] = None,
): ):
created_time = time.time() created_time = time.time()
async with self._cond:
await self._cond.wait_for(lambda: not self._updating)
self.auto_create_handle_loop() self.auto_create_handle_loop()
obj.normalize_batch_and_arguments() obj.normalize_batch_and_arguments()
...@@ -902,6 +907,16 @@ class TokenizerManager: ...@@ -902,6 +907,16 @@ class TokenizerManager:
self.auto_create_handle_loop() self.auto_create_handle_loop()
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD) await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
async def pause_generation(self):
    """Pause generation for async RL weight updates.

    Sets the `_updating` flag under the condition lock and aborts every
    in-flight request. New `generate_request` calls wait on `_cond` until
    `continue_generation()` clears the flag.
    """
    cond = self._cond
    async with cond:
        # Raise the flag first so any waiter that wakes sees the paused
        # state, then drop whatever is currently running.
        self._updating = True
        self.abort_request(abort_all=True)
async def continue_generation(self):
    """Resume generation after a `pause_generation()` call.

    Clears the `_updating` flag under the condition lock and wakes all
    coroutines blocked in `generate_request`'s `wait_for`.
    """
    cond = self._cond
    async with cond:
        self._updating = False
        # notify_all, not notify: any number of requests may be queued
        # behind the pause.
        cond.notify_all()
async def update_weights_from_disk( async def update_weights_from_disk(
self, self,
obj: UpdateWeightFromDiskReqInput, obj: UpdateWeightFromDiskReqInput,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment