Unverified Commit cfa80974 authored by AllentDan, committed by GitHub

FIFO pipe strategy for api_server (#795)

* FIFO pipe for api_server

* asyncio sleep 0

* remove unwanted import

* rename symbols

* speed up benchmark by disabling preprocess for string input

* replace Queue with set

* comment
parent b8354dae
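
Note: the core change below drops the per-session modulo mapping (session_id % instance_num) and its availability flags in favor of a free set of model instances that requests pop from and return to, so waiting requests are served first-come-first-served. A minimal sketch of that pooling pattern, with create_instance standing in for tm_model.create_instance() (names here are illustrative, not the code in this diff; the set itself is unordered, so instances are handed out first-available):

import asyncio


class InstancePool:
    """Sketch of the pool this commit introduces: idle instances sit
    in a set, a request pops one, and it is added back when the
    request finishes or fails."""

    def __init__(self, create_instance, instance_num: int):
        self.gens_set = {create_instance() for _ in range(instance_num)}

    async def acquire(self):
        # Cooperatively spin until an instance is free; sleep(0)
        # yields to the event loop without adding fixed latency.
        while not self.gens_set:
            await asyncio.sleep(0)
        return self.gens_set.pop()

    def release(self, gen):
        self.gens_set.add(gen)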
@@ -31,54 +31,59 @@ class AsyncEngine:
             tp=tp,
             **kwargs)
         self.tokenizer = self.tm_model.tokenizer
-        self.generators = [
-            self.tm_model.create_instance() for i in range(instance_num)
-        ]
         self.instance_num = instance_num
         self.model = self.tm_model.model
-        self.available = [True] * instance_num
-        self.starts = [None] * instance_num
-        self.steps = {}
+        self.id2step = {}
+        self.id2generator = {}
+        self.loop = asyncio.get_event_loop()
+        self.special_gen = self.tm_model.create_instance()
+        self.gens_set = set()
+        for i in range(instance_num):
+            self.gens_set.add(self.tm_model.create_instance())
 
     def stop_session(self, session_id: int):
         """Stop a session by a session_id."""
-        instance_id = session_id % self.instance_num
-        input_ids = self.tokenizer.encode('')
-        for outputs in self.generators[instance_id].stream_infer(
-                session_id,
-                input_ids,
-                request_output_len=0,
-                sequence_start=False,
-                sequence_end=False,
-                stop=True):
+        input_ids = [self.tm_model.eos_id]
+        stop_generator = self.id2generator.get(str(session_id),
+                                               self.special_gen)
+        for outputs in stop_generator.stream_infer(session_id,
+                                                   input_ids,
+                                                   request_output_len=0,
+                                                   sequence_start=False,
+                                                   sequence_end=False,
+                                                   stop=True):
             pass
-        self.available[instance_id] = True
+        if str(session_id) in self.id2generator and self.id2generator[str(
+                session_id)] not in self.gens_set:
+            self.gens_set.add(self.id2generator[str(session_id)])
 
     def end_session(self, session_id: int):
         """Clear a session by a session_id."""
-        instance_id = session_id % self.instance_num
-        input_ids = self.tokenizer.encode('')
-        for outputs in self.generators[instance_id].stream_infer(
-                session_id,
-                input_ids,
-                request_output_len=0,
-                sequence_start=False,
-                sequence_end=True,
-                stop=True):
+        input_ids = [self.tm_model.eos_id]
+        end_generator = self.id2generator.get(str(session_id),
+                                              self.special_gen)
+        for outputs in end_generator.stream_infer(session_id,
+                                                  input_ids,
+                                                  request_output_len=0,
+                                                  sequence_start=False,
+                                                  sequence_end=True,
+                                                  stop=True):
             pass
-        self.steps[str(session_id)] = 0
-        self.available[instance_id] = True
+        self.id2step[str(session_id)] = 0
+        if str(session_id) in self.id2generator and self.id2generator[str(
+                session_id)] not in self.gens_set:
+            self.gens_set.add(self.id2generator[str(session_id)])
 
     @contextmanager
-    def safe_run(self, instance_id: int, session_id: Optional[int] = None):
+    def safe_run(self, session_id: Optional[int] = None):
        """A context manager to make sure server's safe running."""
-        self.available[instance_id] = False
         try:
             yield
         except (Exception, asyncio.CancelledError) as e:  # noqa
             self.stop_session(session_id)
-        self.available[instance_id] = True
+        if str(session_id) in self.id2generator and self.id2generator[str(
+                session_id)] not in self.gens_set:
+            self.gens_set.add(self.id2generator[str(session_id)])
 
     async def get_embeddings(self, prompt, do_prerpocess=False):
         if do_prerpocess:
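
The reworked safe_run above no longer flips an availability flag; on any exception, including the asyncio.CancelledError raised when a client disconnects, it stops the session and puts that session's generator back into the free set. A condensed, standalone sketch of the recovery path (free functions instead of the actual class; the real code also calls stop_session before releasing):

import asyncio
from contextlib import contextmanager


@contextmanager
def safe_run(gens_set, id2generator, session_id):
    try:
        yield
    except (Exception, asyncio.CancelledError):
        # On failure or cancellation, return this session's generator
        # to the free set so other requests can reuse it.
        gen = id2generator.get(str(session_id))
        if gen is not None and gen not in gens_set:
            gens_set.add(gen)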
@@ -86,12 +91,13 @@ class AsyncEngine:
             input_ids = self.tokenizer.encode(prompt)
         return input_ids
 
-    async def get_generator(self, instance_id: int, stop: bool = False):
+    async def get_generator(self, stop: bool, session_id: int):
         """Only return the model instance if it is available."""
-        if not stop:
-            while self.available[instance_id] is False:
-                await asyncio.sleep(0.1)
-        return self.generators[instance_id]
+        if stop:
+            return self.id2generator.get(str(session_id), self.special_gen)
+        while self.gens_set == set():
+            await asyncio.sleep(0)
+        return self.gens_set.pop()
 
     def batch_infer(self,
                     prompts: List[str],
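
get_generator now polls with asyncio.sleep(0) instead of sleep(0.1) (the "asyncio sleep 0" item in the commit message): a zero-second sleep suspends the coroutine for exactly one event-loop pass, so a finishing request can return an instance to the set without the fixed 100 ms polling latency. A self-contained illustration of that hand-off (toy names, not the server code):

import asyncio


async def waiter(pool):
    # Spin until another task puts an instance into the set;
    # sleep(0) yields control so that task can actually run.
    while not pool:
        await asyncio.sleep(0)
    return pool.pop()


async def finisher(pool):
    await asyncio.sleep(0.01)  # simulate a request completing
    pool.add('generator-0')


async def main():
    pool = set()
    got, _ = await asyncio.gather(waiter(pool), finisher(pool))
    print(got)  # generator-0


asyncio.run(main())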
@@ -189,27 +195,27 @@ class AsyncEngine:
             ignore_eos (bool): indicator for ignoring eos
             do_preprocess (bool): whether pre-process the messages.
         """
-        instance_id = session_id % self.instance_num
-        if str(session_id) not in self.steps:
-            self.steps[str(session_id)] = 0
+        if str(session_id) not in self.id2step:
+            self.id2step[str(session_id)] = 0
         if step != 0:
-            self.steps[str(session_id)] = step
+            self.id2step[str(session_id)] = step
         seed = random.getrandbits(64)
         prompt = messages
         if do_preprocess:
             prompt = self.model.messages2prompt(prompt, sequence_start)
         input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start)
         finish_reason = 'stop' if stop else None
-        if self.steps[str(session_id)] + len(
+        if self.id2step[str(session_id)] + len(
                 input_ids) + request_output_len >= self.tm_model.session_len:
             finish_reason = 'length'
-            yield GenOut('', self.steps[str(session_id)], len(input_ids), 0,
+            yield GenOut('', self.id2step[str(session_id)], len(input_ids), 0,
                          finish_reason)
             if sequence_end is True and sequence_start is False:
                 self.end_session(session_id)
         else:
-            generator = await self.get_generator(instance_id, stop)
-            with self.safe_run(instance_id, session_id):
+            generator = await self.get_generator(stop, session_id)
+            self.id2generator[str(session_id)] = generator
+            with self.safe_run(session_id):
                 response_size = 0
                 async for outputs in generator.async_stream_infer(
                         session_id=session_id,
@@ -218,7 +224,7 @@ class AsyncEngine:
                         request_output_len=request_output_len,
                         sequence_start=(sequence_start),
                         sequence_end=sequence_end,
-                        step=self.steps[str(session_id)],
+                        step=self.id2step[str(session_id)],
                         stop=stop,
                         top_k=top_k,
                         top_p=top_p,
@@ -237,16 +243,16 @@ class AsyncEngine:
                         continue
                     # response, history token len,
                     # input token len, gen token len
-                    yield GenOut(response, self.steps[str(session_id)],
+                    yield GenOut(response, self.id2step[str(session_id)],
                                  len(input_ids), tokens, finish_reason)
                     response_size = tokens
 
                 # `response_size` might not be updated since
                 # `if response.endswith('�')`
                 if response_size != tokens:
-                    yield GenOut(response, self.steps[str(session_id)],
+                    yield GenOut(response, self.id2step[str(session_id)],
                                  len(input_ids), tokens, finish_reason)
                 # update step
-                self.steps[str(session_id)] += len(input_ids) + tokens
+                self.id2step[str(session_id)] += len(input_ids) + tokens
                 if sequence_end or stop:
-                    self.steps[str(session_id)] = 0
+                    self.id2step[str(session_id)] = 0
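
The steps dict is renamed to id2step throughout; per session it records the history token length: each turn adds prompt plus generated tokens, and sequence_end or stop resets it to zero, which is also what lets chat_interactive_v1 below detect a fresh session. A toy trace of that bookkeeping (illustrative, not the server code):

id2step = {}
session_id = '42'
id2step.setdefault(session_id, 0)
for input_len, gen_len in [(16, 32), (8, 24)]:  # two turns
    id2step[session_id] += input_len + gen_len
print(id2step[session_id])  # 80 tokens of accumulated history
id2step[session_id] = 0     # on sequence_end or stop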
@@ -136,7 +136,10 @@ async def chat_completions_v1(request: ChatCompletionRequest,
         top_p=request.top_p,
         temperature=request.temperature,
         repetition_penalty=request.repetition_penalty,
-        ignore_eos=request.ignore_eos)
+        ignore_eos=request.ignore_eos,
+        do_preprocess=not isinstance(request.messages,
+                                     str),  # text completion for string input
+    )
 
     def create_stream_response_json(
             index: int,
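
chat_completions_v1 now skips the chat-template preprocessing when request.messages is a plain string, treating it as a raw text completion; this is the "speed up benchmark by disabling preprocess for string input" item in the commit message. The toggle reduces to (hypothetical helper name, for illustration only):

def needs_preprocess(messages):
    # A plain string is sent to the model as-is (text completion);
    # a list of chat messages still goes through messages2prompt().
    return not isinstance(messages, str)

assert needs_preprocess('Hello') is False
assert needs_preprocess([{'role': 'user', 'content': 'Hi'}]) is True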
@@ -424,7 +427,7 @@ async def chat_interactive_v1(request: GenerateRequest,
         request.session_id = random.randint(10087, 23333)
 
     async_engine = VariableInterface.async_engine
-    sequence_start = async_engine.steps.get(str(request.session_id), 0) == 0
+    sequence_start = async_engine.id2step.get(str(request.session_id), 0) == 0
     sequence_end = not request.interactive_mode
 
     generation = async_engine.generate(