Unverified commit af5a3edb, authored by AllentDan and committed by GitHub
Browse files

fix api_server `stop` and `end_session` (#835)

parent 558029b6
...@@ -57,7 +57,6 @@ class AsyncEngine: ...@@ -57,7 +57,6 @@ class AsyncEngine:
self.id2step = {} self.id2step = {}
self.id2generator = {} self.id2generator = {}
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self.special_gen = self.tm_model.create_instance()
self.gens_set = set() self.gens_set = set()
for i in range(instance_num): for i in range(instance_num):
self.gens_set.add(self.tm_model.create_instance()) self.gens_set.add(self.tm_model.create_instance())
...@@ -101,8 +100,7 @@ class AsyncEngine: ...@@ -101,8 +100,7 @@ class AsyncEngine:
def stop_session(self, session_id: int): def stop_session(self, session_id: int):
"""Stop a session by a session_id.""" """Stop a session by a session_id."""
input_ids = [self.tm_model.eos_id] input_ids = [self.tm_model.eos_id]
stop_generator = self.id2generator.get(str(session_id), stop_generator = self.tm_model.create_instance()
self.special_gen)
for outputs in stop_generator.stream_infer(session_id, for outputs in stop_generator.stream_infer(session_id,
input_ids, input_ids,
request_output_len=0, request_output_len=0,
...@@ -117,8 +115,7 @@ class AsyncEngine: ...@@ -117,8 +115,7 @@ class AsyncEngine:
def end_session(self, session_id: int): def end_session(self, session_id: int):
"""Clear a session by a session_id.""" """Clear a session by a session_id."""
input_ids = [self.tm_model.eos_id] input_ids = [self.tm_model.eos_id]
end_generator = self.id2generator.get(str(session_id), end_generator = self.tm_model.create_instance()
self.special_gen)
for outputs in end_generator.stream_infer(session_id, for outputs in end_generator.stream_infer(session_id,
input_ids, input_ids,
request_output_len=0, request_output_len=0,
...@@ -151,10 +148,12 @@ class AsyncEngine: ...@@ -151,10 +148,12 @@ class AsyncEngine:
async def get_generator(self, stop: bool, session_id: int): async def get_generator(self, stop: bool, session_id: int):
"""Only return the model instance if it is available.""" """Only return the model instance if it is available."""
if stop: if stop:
return self.id2generator.get(str(session_id), self.special_gen) return self.tm_model.create_instance()
while self.gens_set == set(): while self.gens_set == set():
await asyncio.sleep(0) await asyncio.sleep(0)
return self.gens_set.pop() generator = self.gens_set.pop()
self.id2generator[str(session_id)] = generator
return generator
def batch_infer(self, def batch_infer(self,
prompts: Union[List[str], str], prompts: Union[List[str], str],
...@@ -274,7 +273,6 @@ class AsyncEngine: ...@@ -274,7 +273,6 @@ class AsyncEngine:
self.end_session(session_id) self.end_session(session_id)
else: else:
generator = await self.get_generator(stop, session_id) generator = await self.get_generator(stop, session_id)
self.id2generator[str(session_id)] = generator
with self.safe_run(session_id): with self.safe_run(session_id):
response_size = 0 response_size = 0
async for outputs in generator.async_stream_infer( async for outputs in generator.async_stream_infer(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment