Unverified commit cfa80974, authored by AllentDan, committed by GitHub

FIFO pipe strategy for api_server (#795)

* FIFO pipe for api_server

* asyncio sleep 0

* remove unwanted import

* rename symbols

* speed up benchmarks by disabling preprocessing for string input

* replace Queue with set

* comment
parent b8354dae
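In outline, the change replaces the old per-session scheme (`session_id % instance_num` picking a fixed generator, guarded by an `available` list) with a free pool: all generator instances start in a set, a request pops whichever instance is free, records it in `id2generator`, and returns it to the set when the request finishes or fails. The sketch below uses illustrative names, not the exact lmdeploy API, to show that acquire/release pattern:

    import asyncio

    class GeneratorPool:
        """Illustrative free pool of engine instances (hypothetical factory)."""

        def __init__(self, instance_num, create_instance):
            # All instances start free; requests take whichever is available.
            self.gens_set = {create_instance() for _ in range(instance_num)}
            self.id2generator = {}

        async def acquire(self, session_id: int):
            # Yield to the event loop until some instance is free.
            while not self.gens_set:
                await asyncio.sleep(0)
            gen = self.gens_set.pop()
            self.id2generator[str(session_id)] = gen
            return gen

        def release(self, session_id: int):
            # Return the session's instance to the pool exactly once.
            gen = self.id2generator.get(str(session_id))
            if gen is not None and gen not in self.gens_set:
                self.gens_set.add(gen)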
@@ -31,54 +31,59 @@ class AsyncEngine:
             tp=tp,
             **kwargs)
         self.tokenizer = self.tm_model.tokenizer
-        self.generators = [
-            self.tm_model.create_instance() for i in range(instance_num)
-        ]
         self.instance_num = instance_num
         self.model = self.tm_model.model
-        self.available = [True] * instance_num
-        self.starts = [None] * instance_num
-        self.steps = {}
+        self.id2step = {}
+        self.id2generator = {}
         self.loop = asyncio.get_event_loop()
+        self.special_gen = self.tm_model.create_instance()
+        self.gens_set = set()
+        for i in range(instance_num):
+            self.gens_set.add(self.tm_model.create_instance())

     def stop_session(self, session_id: int):
         """Stop a session by a session_id."""
-        instance_id = session_id % self.instance_num
-        input_ids = self.tokenizer.encode('')
-        for outputs in self.generators[instance_id].stream_infer(
-                session_id,
-                input_ids,
-                request_output_len=0,
-                sequence_start=False,
-                sequence_end=False,
-                stop=True):
+        input_ids = [self.tm_model.eos_id]
+        stop_generator = self.id2generator.get(str(session_id),
+                                               self.special_gen)
+        for outputs in stop_generator.stream_infer(session_id,
+                                                   input_ids,
+                                                   request_output_len=0,
+                                                   sequence_start=False,
+                                                   sequence_end=False,
+                                                   stop=True):
             pass
-        self.available[instance_id] = True
+        if str(session_id) in self.id2generator and self.id2generator[str(
+                session_id)] not in self.gens_set:
+            self.gens_set.add(self.id2generator[str(session_id)])

     def end_session(self, session_id: int):
         """Clear a session by a session_id."""
-        instance_id = session_id % self.instance_num
-        input_ids = self.tokenizer.encode('')
-        for outputs in self.generators[instance_id].stream_infer(
-                session_id,
-                input_ids,
-                request_output_len=0,
-                sequence_start=False,
-                sequence_end=True,
-                stop=True):
+        input_ids = [self.tm_model.eos_id]
+        end_generator = self.id2generator.get(str(session_id),
+                                              self.special_gen)
+        for outputs in end_generator.stream_infer(session_id,
+                                                  input_ids,
+                                                  request_output_len=0,
+                                                  sequence_start=False,
+                                                  sequence_end=True,
+                                                  stop=True):
             pass
-        self.steps[str(session_id)] = 0
-        self.available[instance_id] = True
+        self.id2step[str(session_id)] = 0
+        if str(session_id) in self.id2generator and self.id2generator[str(
+                session_id)] not in self.gens_set:
+            self.gens_set.add(self.id2generator[str(session_id)])

     @contextmanager
-    def safe_run(self, instance_id: int, session_id: Optional[int] = None):
+    def safe_run(self, session_id: Optional[int] = None):
         """A context manager to make sure server's safe running."""
-        self.available[instance_id] = False
         try:
             yield
         except (Exception, asyncio.CancelledError) as e:  # noqa
             self.stop_session(session_id)
-        self.available[instance_id] = True
+        if str(session_id) in self.id2generator and self.id2generator[str(
+                session_id)] not in self.gens_set:
+            self.gens_set.add(self.id2generator[str(session_id)])

     async def get_embeddings(self, prompt, do_prerpocess=False):
         if do_prerpocess:
@@ -86,12 +91,13 @@ class AsyncEngine:
             input_ids = self.tokenizer.encode(prompt)
         return input_ids

-    async def get_generator(self, instance_id: int, stop: bool = False):
+    async def get_generator(self, stop: bool, session_id: int):
         """Only return the model instance if it is available."""
-        if not stop:
-            while self.available[instance_id] is False:
-                await asyncio.sleep(0.1)
-        return self.generators[instance_id]
+        if stop:
+            return self.id2generator.get(str(session_id), self.special_gen)
+        while self.gens_set == set():
+            await asyncio.sleep(0)
+        return self.gens_set.pop()

     def batch_infer(self,
                     prompts: List[str],
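The wait loop above spins on `await asyncio.sleep(0)` rather than the old `asyncio.sleep(0.1)`: sleeping zero seconds still suspends the coroutine for one event-loop cycle, so a freed instance is picked up on the very next cycle instead of up to 100 ms later. A standalone demonstration of the idea (not lmdeploy code):

    import asyncio

    async def consumer(pool: set):
        while not pool:             # nothing free yet
            await asyncio.sleep(0)  # cooperative yield, no fixed delay
        return pool.pop()

    async def producer(pool: set):
        await asyncio.sleep(0.01)   # simulate a request finishing
        pool.add('generator-0')     # release the instance back to the pool

    async def main():
        pool = set()
        got, _ = await asyncio.gather(consumer(pool), producer(pool))
        print(got)  # generator-0, picked up right after release

    asyncio.run(main())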
@@ -189,27 +195,27 @@ class AsyncEngine:
             ignore_eos (bool): indicator for ignoring eos
             do_preprocess (bool): whether pre-process the messages.
         """
-        instance_id = session_id % self.instance_num
-        if str(session_id) not in self.steps:
-            self.steps[str(session_id)] = 0
+        if str(session_id) not in self.id2step:
+            self.id2step[str(session_id)] = 0
         if step != 0:
-            self.steps[str(session_id)] = step
+            self.id2step[str(session_id)] = step
         seed = random.getrandbits(64)
         prompt = messages
         if do_preprocess:
             prompt = self.model.messages2prompt(prompt, sequence_start)
         input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start)
         finish_reason = 'stop' if stop else None
-        if self.steps[str(session_id)] + len(
+        if self.id2step[str(session_id)] + len(
                 input_ids) + request_output_len >= self.tm_model.session_len:
             finish_reason = 'length'
-            yield GenOut('', self.steps[str(session_id)], len(input_ids), 0,
+            yield GenOut('', self.id2step[str(session_id)], len(input_ids), 0,
                          finish_reason)
             if sequence_end is True and sequence_start is False:
                 self.end_session(session_id)
         else:
-            generator = await self.get_generator(instance_id, stop)
-            with self.safe_run(instance_id, session_id):
+            generator = await self.get_generator(stop, session_id)
+            self.id2generator[str(session_id)] = generator
+            with self.safe_run(session_id):
                 response_size = 0
                 async for outputs in generator.async_stream_infer(
                         session_id=session_id,
@@ -218,7 +224,7 @@ class AsyncEngine:
                         request_output_len=request_output_len,
                         sequence_start=(sequence_start),
                         sequence_end=sequence_end,
-                        step=self.steps[str(session_id)],
+                        step=self.id2step[str(session_id)],
                         stop=stop,
                         top_k=top_k,
                         top_p=top_p,
@@ -237,16 +243,16 @@ class AsyncEngine:
                         continue
                     # response, history token len,
                     # input token len, gen token len
-                    yield GenOut(response, self.steps[str(session_id)],
+                    yield GenOut(response, self.id2step[str(session_id)],
                                  len(input_ids), tokens, finish_reason)
                     response_size = tokens

                 # `response_size` might be not updated since
                 # ` if response.endswith('�')`
                 if response_size != tokens:
-                    yield GenOut(response, self.steps[str(session_id)],
+                    yield GenOut(response, self.id2step[str(session_id)],
                                  len(input_ids), tokens, finish_reason)
                 # update step
-                self.steps[str(session_id)] += len(input_ids) + tokens
+                self.id2step[str(session_id)] += len(input_ids) + tokens
                 if sequence_end or stop:
-                    self.steps[str(session_id)] = 0
+                    self.id2step[str(session_id)] = 0
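A small worked sketch (values assumed) of the step bookkeeping above: `id2step` caches how many tokens a session has already consumed, and a request is short-circuited with `finish_reason='length'` when the cached history plus the new input plus the requested output would overflow the model's session window.

    session_len = 4096              # assumed model session window
    id2step = {'1': 3900}           # session 1 already holds 3900 tokens
    input_len, request_output_len = 150, 100

    if id2step['1'] + input_len + request_output_len >= session_len:
        finish_reason = 'length'    # 3900 + 150 + 100 = 4150 >= 4096
    else:
        finish_reason = None
        # after generation, the step advances by input plus generated tokens
        id2step['1'] += input_len + request_output_len
    print(finish_reason)  # length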
@@ -136,7 +136,10 @@ async def chat_completions_v1(request: ChatCompletionRequest,
         top_p=request.top_p,
         temperature=request.temperature,
         repetition_penalty=request.repetition_penalty,
-        ignore_eos=request.ignore_eos)
+        ignore_eos=request.ignore_eos,
+        do_preprocess=not isinstance(request.messages,
+                                     str),  # text completion for string input
+    )

     def create_stream_response_json(
             index: int,
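The `do_preprocess=not isinstance(request.messages, str)` argument is what the "speed up benchmarks" bullet refers to: when `messages` arrives as a plain string, the chat-template preprocessing is skipped and the request behaves as raw text completion. A hedged client sketch, where the URL, port, and payload shape are assumptions inferred from the handler signature rather than a documented contract:

    import requests

    # Assumed local server address; adjust to your deployment.
    resp = requests.post(
        'http://localhost:23333/v1/chat/completions',
        json={
            'model': 'llama',
            # A plain string instead of a role/content list means
            # do_preprocess=False on the server: raw text completion.
            'messages': 'Once upon a time',
        })
    print(resp.json())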
@@ -424,7 +427,7 @@ async def chat_interactive_v1(request: GenerateRequest,
         request.session_id = random.randint(10087, 23333)

     async_engine = VariableInterface.async_engine
-    sequence_start = async_engine.steps.get(str(request.session_id), 0) == 0
+    sequence_start = async_engine.id2step.get(str(request.session_id), 0) == 0
     sequence_end = not request.interactive_mode

     generation = async_engine.generate(
......
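Finally, chat_interactive_v1 now derives `sequence_start` from `id2step`: a session whose cached step is 0 (new, or previously ended) starts a fresh sequence, and `interactive_mode` keeps it open across calls. A client-side sketch, where the route and port are assumptions based on the handler name and the `GenerateRequest` fields visible in the diff:

    import requests

    for prompt in ['Hi, there!', 'And what did I just say?']:
        resp = requests.post(
            'http://localhost:23333/generate',
            json={
                'prompt': prompt,
                'session_id': 1,
                # True keeps the sequence open, so id2step stays non-zero
                # and the next call continues the same conversation.
                'interactive_mode': True,
            })
        print(resp.json())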