Unverified commit 759e1ddf authored by AllentDan, committed by GitHub

make IPv6 compatible, safe run for coroutine interrupting (#487)

* make IPv6 compatible, safe run for coroutine interrupting

* instance_id -> session_id and fix api_client.py

* update doc

* remove useless faq

* safe ip mapping

* update app.py

* remove print

* update doc
parent fbd9770a
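
The "IPv6 compatible" part of this change comes from the new ip2id helper added to lmdeploy.serve.openai.api_server in the hunks below, which derives a numeric session id from either an IPv4 or an IPv6 client address instead of simply stripping dots. A rough usage sketch with arbitrary example addresses (the results follow directly from the helper's own logic shown later in this diff):

    from lmdeploy.serve.openai.api_server import ip2id

    # IPv4: drop the dots and keep at most the last 8 digits
    print(ip2id('192.168.1.100'))           # -> 21681100
    # IPv6: drop the colons and parse the last 8 hex characters
    print(ip2id('2001:db8::ff00:42:8329'))  # -> int('00428329', 16)
    # Anything else falls back to session id 0 with a warning
    print(ip2id('localhost'))               # -> 0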
@@ -14,7 +14,7 @@ from lmdeploy.utils import get_logger
 def get_streaming_response(prompt: str,
                            api_url: str,
-                           instance_id: int,
+                           session_id: int,
                            request_output_len: int,
                            stream: bool = True,
                            sequence_start: bool = True,
@@ -24,7 +24,7 @@ def get_streaming_response(prompt: str,
     pload = {
         'prompt': prompt,
         'stream': stream,
-        'instance_id': instance_id,
+        'session_id': session_id,
         'request_output_len': request_output_len,
         'sequence_start': sequence_start,
         'sequence_end': sequence_end,
@@ -36,7 +36,7 @@ def get_streaming_response(prompt: str,
                              stream=stream)
     for chunk in response.iter_lines(chunk_size=8192,
                                      decode_unicode=False,
-                                     delimiter=b'\0'):
+                                     delimiter=b'\n'):
         if chunk:
             data = json.loads(chunk.decode('utf-8'))
             output = data['text']
...
@@ -22,7 +22,7 @@ from typing import Iterable, List
 def get_streaming_response(prompt: str,
                            api_url: str,
-                           instance_id: int,
+                           session_id: int,
                            request_output_len: int,
                            stream: bool = True,
                            sequence_start: bool = True,
@@ -32,7 +32,7 @@ def get_streaming_response(prompt: str,
     pload = {
         'prompt': prompt,
         'stream': stream,
-        'instance_id': instance_id,
+        'session_id': session_id,
         'request_output_len': request_output_len,
         'sequence_start': sequence_start,
         'sequence_end': sequence_end,
@@ -41,7 +41,7 @@ def get_streaming_response(prompt: str,
     response = requests.post(
         api_url, headers=headers, json=pload, stream=stream)
     for chunk in response.iter_lines(
-            chunk_size=8192, decode_unicode=False, delimiter=b'\0'):
+            chunk_size=8192, decode_unicode=False, delimiter=b'\n'):
         if chunk:
             data = json.loads(chunk.decode('utf-8'))
             output = data['text']
@@ -91,7 +91,7 @@ curl http://{server_ip}:{server_port}/generate \
   -H "Content-Type: application/json" \
   -d '{
     "prompt": "Hello! How are you?",
-    "instance_id": 1,
+    "session_id": 1,
     "sequence_start": true,
     "sequence_end": true
   }'
@@ -146,11 +146,10 @@ python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True
 2. When OOM appears on the server side, please reduce the number of `instance_num` when launching the service.
-3. When the request with the same `instance_id` to `generate` got a empty return value and a negative `tokens`, please consider setting `sequence_start=false` for the second question and the same for the afterwards.
+3. When a request with the same `session_id` to `generate` gets an empty return value and a negative `tokens`, please consider setting `sequence_start=false` for the second question and for the ones after it.
 4. Requests were previously being handled sequentially rather than concurrently. To resolve this issue,
-   - kindly provide unique instance_id values when calling the `generate` API or else your requests may be associated with client IP addresses
-   - additionally, setting `stream=true` enables processing multiple requests simultaneously
+   - kindly provide unique session_id values when calling the `generate` API, or else your requests may be associated with client IP addresses
 5. Both the `generate` api and `v1/chat/completions` support engaging in multiple rounds of conversation, where the input `prompt` or `messages` consists of either a single string or an entire chat history. These inputs are interpreted in multi-turn dialogue mode. However, if you want to turn this mode off and manage the chat history on the client side, please set the parameter `sequence_end: true` when calling `generate`, or specify `renew_session: true` when using `v1/chat/completions`.
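
As a concrete companion to items 3 and 4 above, here is a minimal client sketch of a two-turn conversation against the `generate` route. It reuses the request fields and newline-delimited streaming format shown earlier in this document; the address is a placeholder for wherever the api_server is running.

    import json

    import requests

    api_url = 'http://0.0.0.0:23333/generate'  # placeholder, use your {server_ip}:{server_port}

    def ask(prompt, session_id, first_turn):
        # One session_id per client keeps requests concurrent (item 4); only the
        # first turn of a session sets sequence_start=true (item 3).
        pload = {
            'prompt': prompt,
            'stream': True,
            'session_id': session_id,
            'request_output_len': 512,
            'sequence_start': first_turn,
            'sequence_end': False,
        }
        response = requests.post(api_url, json=pload, stream=True)
        for chunk in response.iter_lines(chunk_size=8192,
                                         decode_unicode=False,
                                         delimiter=b'\n'):
            if chunk:
                data = json.loads(chunk.decode('utf-8'))
                print(data['text'], end='', flush=True)
        print()

    ask('Hello! How are you?', session_id=1, first_turn=True)
    ask('Tell me more about yourself.', session_id=1, first_turn=False)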
@@ -24,7 +24,7 @@ from typing import Iterable, List
 def get_streaming_response(prompt: str,
                            api_url: str,
-                           instance_id: int,
+                           session_id: int,
                            request_output_len: int,
                            stream: bool = True,
                            sequence_start: bool = True,
@@ -34,7 +34,7 @@ def get_streaming_response(prompt: str,
     pload = {
         'prompt': prompt,
         'stream': stream,
-        'instance_id': instance_id,
+        'session_id': session_id,
         'request_output_len': request_output_len,
         'sequence_start': sequence_start,
         'sequence_end': sequence_end,
@@ -43,7 +43,7 @@ def get_streaming_response(prompt: str,
     response = requests.post(
         api_url, headers=headers, json=pload, stream=stream)
     for chunk in response.iter_lines(
-            chunk_size=8192, decode_unicode=False, delimiter=b'\0'):
+            chunk_size=8192, decode_unicode=False, delimiter=b'\n'):
         if chunk:
             data = json.loads(chunk.decode('utf-8'))
             output = data['text']
@@ -93,7 +93,7 @@ curl http://{server_ip}:{server_port}/generate \
   -H "Content-Type: application/json" \
   -d '{
     "prompt": "Hello! How are you?",
-    "instance_id": 1,
+    "session_id": 1,
     "sequence_start": true,
     "sequence_end": true
   }'
@@ -148,12 +148,11 @@ python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True
 2. When the server side runs out of GPU memory (OOM), reduce the `instance_num` used when launching the service.
-3. If requests with the same `instance_id` to `generate` return an empty string and a negative `tokens`, it is probably because `sequence_start=false` was not set for the second question.
+3. If requests with the same `session_id` to `generate` return an empty string and a negative `tokens`, it is probably because `sequence_start=false` was not set for the second question.
 4. If requests seem to be handled one by one rather than concurrently, please set the following parameters:
-   - Pass different instance_id values to the `generate` api. Otherwise, the session id will be automatically bound to a number derived from the client's ip address.
-   - Set `stream=true` so that other requests can be processed while the model is running a forward pass.
+   - Pass different session_id values to the `generate` api. Otherwise, the session id will be automatically bound to a number derived from the client's ip address.
 5. Both the `generate` api and `v1/chat/completions` support multi-turn conversation. The `messages` or `prompt` parameter can be either a simple string for a single user question or a whole chat history.
 Both apis enable multi-turn dialogue by default. If you want to turn this feature off and manage the chat history on the client side, please pass `sequence_end: true` to `generate`, or set
...
@@ -47,13 +47,30 @@ class AsyncEngine:
         self.starts = [None] * instance_num
         self.steps = {}

+    def stop_session(self, session_id: int):
+        instance_id = session_id % self.instance_num
+        input_ids = self.tokenizer.encode('')
+        for outputs in self.generators[instance_id].stream_infer(
+                session_id,
+                input_ids,
+                request_output_len=0,
+                sequence_start=False,
+                sequence_end=False,
+                stop=True):
+            pass
+        self.available[instance_id] = True
+
     @contextmanager
-    def safe_run(self, instance_id: int, stop: bool = False):
+    def safe_run(self, instance_id: int, session_id: Optional[int] = None):
         self.available[instance_id] = False
-        yield
+        try:
+            yield
+        except (Exception, asyncio.CancelledError) as e:  # noqa
+            self.stop_session(session_id)
         self.available[instance_id] = True

-    async def get_embeddings(self, prompt):
-        prompt = self.model.get_prompt(prompt)
+    async def get_embeddings(self, prompt, do_prerpocess=False):
+        if do_prerpocess:
+            prompt = self.model.get_prompt(prompt)
         input_ids = self.tokenizer.encode(prompt)
         return input_ids
@@ -68,7 +85,7 @@ class AsyncEngine:
     async def generate(
             self,
             messages,
-            instance_id,
+            session_id,
             stream_response=True,
             sequence_start=True,
             sequence_end=False,
@@ -85,7 +102,7 @@ class AsyncEngine:
         Args:
             messages (str | List): chat history or prompt
-            instance_id (int): actually request host ip
+            session_id (int): the session id
             stream_response (bool): whether return responses streamingly
             request_output_len (int): output token nums
             sequence_start (bool): indicator for starting a sequence
@@ -102,8 +119,7 @@ class AsyncEngine:
                 1.0 means no penalty
             ignore_eos (bool): indicator for ignoring eos
         """
-        session_id = instance_id
-        instance_id %= self.instance_num
+        instance_id = session_id % self.instance_num
         if str(session_id) not in self.steps:
             self.steps[str(session_id)] = 0
         if step != 0:
@@ -119,7 +135,7 @@ class AsyncEngine:
                          finish_reason)
         else:
             generator = await self.get_generator(instance_id, stop)
-            with self.safe_run(instance_id):
+            with self.safe_run(instance_id, session_id):
                 response_size = 0
                 async for outputs in generator.async_stream_infer(
                         session_id=session_id,
@@ -188,14 +204,14 @@ class AsyncEngine:
         instance_id %= self.instance_num
         sequence_start = False
         generator = await self.get_generator(instance_id)
-        self.available[instance_id] = False
         if renew_session:  # renew a session
             empty_input_ids = self.tokenizer.encode('')
             for outputs in generator.stream_infer(session_id=session_id,
                                                   input_ids=[empty_input_ids],
                                                   request_output_len=0,
                                                   sequence_start=False,
-                                                  sequence_end=True):
+                                                  sequence_end=True,
+                                                  stop=True):
                 pass
             self.steps[str(session_id)] = 0
         if str(session_id) not in self.steps:
@@ -212,6 +228,7 @@ class AsyncEngine:
             yield GenOut('', self.steps[str(session_id)], len(input_ids), 0,
                          finish_reason)
         else:
+            with self.safe_run(instance_id, session_id):
                 response_size = 0
                 async for outputs in generator.async_stream_infer(
                         session_id=session_id,
@@ -232,11 +249,10 @@ class AsyncEngine:
                     # decode res
                     response = self.tokenizer.decode(res.tolist(),
                                                      offset=response_size)
-                    # response, history token len, input token len, gen token len
+                    # response, history len, input len, generation len
                     yield GenOut(response, self.steps[str(session_id)],
                                  len(input_ids), tokens, finish_reason)
                     response_size = tokens
             # update step
             self.steps[str(session_id)] += len(input_ids) + tokens
-        self.available[instance_id] = True
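
The stop_session/safe_run pair above is the "safe run for coroutine interrupting" named in the commit title: if the coroutine driving generate is cancelled (for instance when an HTTP client disconnects), safe_run now catches asyncio.CancelledError, stops the turbomind session, and marks the engine instance as available again. A minimal sketch of that behaviour, assuming an already constructed AsyncEngine instance is passed in (engine construction is omitted here):

    import asyncio

    async def demo_cancel(engine, session_id=1):
        # Start a streaming generation as a background task.
        async def consume():
            async for out in engine.generate('Hello! How are you?',
                                             session_id,
                                             stream_response=True,
                                             sequence_start=True,
                                             sequence_end=False):
                print(out.response, end='', flush=True)

        task = asyncio.create_task(consume())
        await asyncio.sleep(0.5)   # let a few tokens arrive
        task.cancel()              # interrupt the coroutine mid-stream
        try:
            await task
        except asyncio.CancelledError:
            # safe_run has already called stop_session(session_id),
            # so the engine instance is free for the next request.
            pass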
@@ -12,6 +12,7 @@ from lmdeploy.serve.async_engine import AsyncEngine
 from lmdeploy.serve.gradio.css import CSS
 from lmdeploy.serve.openai.api_client import (get_model_list,
                                               get_streaming_response)
+from lmdeploy.serve.openai.api_server import ip2id
 from lmdeploy.serve.turbomind.chatbot import Chatbot

 THEME = gr.themes.Soft(
@@ -37,7 +38,7 @@ def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
     instruction = state_chatbot[-1][0]
     session_id = threading.current_thread().ident
     if request is not None:
-        session_id = int(request.kwargs['client']['host'].replace('.', ''))
+        session_id = ip2id(request.kwargs['client']['host'])
     bot_response = llama_chatbot.stream_infer(
         session_id, instruction, f'{session_id}-{len(state_chatbot)}')
@@ -166,7 +167,7 @@ def chat_stream_restful(
     """
     session_id = threading.current_thread().ident
     if request is not None:
-        session_id = int(request.kwargs['client']['host'].replace('.', ''))
+        session_id = ip2id(request.kwargs['client']['host'])
     bot_summarized_response = ''
     state_chatbot = state_chatbot + [(instruction, None)]
@@ -176,7 +177,7 @@ def chat_stream_restful(
     for response, tokens, finish_reason in get_streaming_response(
             instruction,
             f'{InterFace.restful_api_url}/generate',
-            instance_id=session_id,
+            session_id=session_id,
             request_output_len=512,
             sequence_start=(len(state_chatbot) == 1),
             sequence_end=False):
@@ -212,12 +213,12 @@ def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
     session_id = threading.current_thread().ident
     if request is not None:
-        session_id = int(request.kwargs['client']['host'].replace('.', ''))
+        session_id = ip2id(request.kwargs['client']['host'])
     # end the session
     for response, tokens, finish_reason in get_streaming_response(
             '',
             f'{InterFace.restful_api_url}/generate',
-            instance_id=session_id,
+            session_id=session_id,
             request_output_len=0,
             sequence_start=False,
             sequence_end=True):
@@ -241,11 +242,11 @@ def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button,
     """
     session_id = threading.current_thread().ident
     if request is not None:
-        session_id = int(request.kwargs['client']['host'].replace('.', ''))
+        session_id = ip2id(request.kwargs['client']['host'])
     # end the session
     for out in get_streaming_response('',
                                       f'{InterFace.restful_api_url}/generate',
-                                      instance_id=session_id,
+                                      session_id=session_id,
                                       request_output_len=0,
                                       sequence_start=False,
                                       sequence_end=False,
@@ -259,7 +260,7 @@ def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button,
         messages.append(dict(role='assistant', content=qa[1]))
     for out in get_streaming_response(messages,
                                       f'{InterFace.restful_api_url}/generate',
-                                      instance_id=session_id,
+                                      session_id=session_id,
                                       request_output_len=0,
                                       sequence_start=True,
                                       sequence_end=False):
@@ -346,7 +347,7 @@ async def chat_stream_local(
     """
     session_id = threading.current_thread().ident
     if request is not None:
-        session_id = int(request.kwargs['client']['host'].replace('.', ''))
+        session_id = ip2id(request.kwargs['client']['host'])
     bot_summarized_response = ''
     state_chatbot = state_chatbot + [(instruction, None)]
@@ -391,7 +392,7 @@ async def reset_local_func(instruction_txtbox: gr.Textbox,
     session_id = threading.current_thread().ident
     if request is not None:
-        session_id = int(request.kwargs['client']['host'].replace('.', ''))
+        session_id = ip2id(request.kwargs['client']['host'])
     # end the session
     async for out in InterFace.async_engine.generate('',
                                                      session_id,
@@ -419,7 +420,7 @@ async def cancel_local_func(state_chatbot: gr.State, cancel_btn: gr.Button,
     """
     session_id = threading.current_thread().ident
     if request is not None:
-        session_id = int(request.kwargs['client']['host'].replace('.', ''))
+        session_id = ip2id(request.kwargs['client']['host'])
     # end the session
     async for out in InterFace.async_engine.generate('',
                                                      session_id,
...
@@ -17,7 +17,7 @@ def get_model_list(api_url: str):
 def get_streaming_response(prompt: str,
                            api_url: str,
-                           instance_id: int,
+                           session_id: int,
                            request_output_len: int = 512,
                            stream: bool = True,
                            sequence_start: bool = True,
@@ -28,7 +28,7 @@ def get_streaming_response(prompt: str,
     pload = {
         'prompt': prompt,
         'stream': stream,
-        'instance_id': instance_id,
+        'session_id': session_id,
         'request_output_len': request_output_len,
         'sequence_start': sequence_start,
         'sequence_end': sequence_end,
@@ -41,7 +41,7 @@ def get_streaming_response(prompt: str,
                              stream=stream)
     for chunk in response.iter_lines(chunk_size=8192,
                                      decode_unicode=False,
-                                     delimiter=b'\0'):
+                                     delimiter=b'\n'):
         if chunk:
             data = json.loads(chunk.decode('utf-8'))
             output = data.pop('text', '')
@@ -62,12 +62,20 @@ def main(restful_api_url: str, session_id: int = 0):
     while True:
         prompt = input_prompt()
         if prompt == 'exit':
+            for output, tokens, finish_reason in get_streaming_response(
+                    '',
+                    f'{restful_api_url}/generate',
+                    session_id=session_id,
+                    request_output_len=0,
+                    sequence_start=(nth_round == 1),
+                    sequence_end=True):
+                pass
             exit(0)
         else:
             for output, tokens, finish_reason in get_streaming_response(
                     prompt,
                     f'{restful_api_url}/generate',
-                    instance_id=session_id,
+                    session_id=session_id,
                     request_output_len=512,
                     sequence_start=(nth_round == 1),
                     sequence_end=False):
...
 # Copyright (c) OpenMMLab. All rights reserved.
-import json
 import os
 import time
 from http import HTTPStatus
@@ -7,7 +6,7 @@ from typing import AsyncGenerator, List, Optional

 import fire
 import uvicorn
-from fastapi import BackgroundTasks, FastAPI, Request
+from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
@@ -16,8 +15,8 @@ from lmdeploy.serve.openai.protocol import (  # noqa: E501
     ChatCompletionRequest, ChatCompletionResponse,
     ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse, ChatMessage, DeltaMessage, EmbeddingsRequest,
-    EmbeddingsResponse, ErrorResponse, GenerateRequest, ModelCard, ModelList,
-    ModelPermission, UsageInfo)
+    EmbeddingsResponse, ErrorResponse, GenerateRequest, GenerateResponse,
+    ModelCard, ModelList, ModelPermission, UsageInfo)

 os.environ['TM_LOG_LEVEL'] = 'ERROR'
@@ -73,6 +72,16 @@ async def check_request(request) -> Optional[JSONResponse]:
     return ret

+def ip2id(host_ip: str):
+    """Convert host ip address to session id."""
+    if '.' in host_ip:  # IPv4
+        return int(host_ip.replace('.', '')[-8:])
+    if ':' in host_ip:  # IPv6
+        return int(host_ip.replace(':', '')[-8:], 16)
+    print('Warning, could not get session id from ip, set it 0')
+    return 0
+
+
 @app.post('/v1/chat/completions')
 async def chat_completions_v1(request: ChatCompletionRequest,
                               raw_request: Request = None):
@@ -106,19 +115,18 @@ async def chat_completions_v1(request: ChatCompletionRequest,
     - presence_penalty (replaced with repetition_penalty)
     - frequency_penalty (replaced with repetition_penalty)
     """
-    instance_id = int(raw_request.client.host.replace('.', ''))
+    session_id = ip2id(raw_request.client.host)
     error_check_ret = await check_request(request)
     if error_check_ret is not None:
         return error_check_ret

     model_name = request.model
-    request_id = str(instance_id)
+    request_id = str(session_id)
     created_time = int(time.time())

     result_generator = VariableInterface.async_engine.generate_openai(
         request.messages,
-        instance_id,
+        session_id,
         True,  # always use stream to enable batching
         request.renew_session,
         request_output_len=request.max_tokens if request.max_tokens else 512,
@@ -128,15 +136,6 @@ async def chat_completions_v1(request: ChatCompletionRequest,
         repetition_penalty=request.repetition_penalty,
         ignore_eos=request.ignore_eos)

-    async def abort_request() -> None:
-        async for _ in VariableInterface.async_engine.generate_openai(
-                request.messages,
-                instance_id,
-                True,
-                request.renew_session,
-                stop=True):
-            pass
-
     def create_stream_response_json(
             index: int,
             text: str,
@@ -181,12 +180,8 @@ async def chat_completions_v1(request: ChatCompletionRequest,

     # Streaming response
     if request.stream:
-        background_tasks = BackgroundTasks()
-        # Abort the request if the client disconnects.
-        background_tasks.add_task(abort_request)
         return StreamingResponse(completion_stream_generator(),
-                                 media_type='text/event-stream',
-                                 background=background_tasks)
+                                 media_type='text/event-stream')

     # Non-streaming response
     final_res = None
@@ -194,7 +189,7 @@ async def chat_completions_v1(request: ChatCompletionRequest,
     async for res in result_generator:
         if await raw_request.is_disconnected():
             # Abort the request if the client disconnects.
-            await abort_request()
+            VariableInterface.async_engine.stop_session(session_id)
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          'Client disconnected')
         final_res = res
@@ -257,7 +252,7 @@ async def generate(request: GenerateRequest, raw_request: Request = None):
     The request should be a JSON object with the following fields:
     - prompt: the prompt to use for the generation.
-    - instance_id: determine which instance will be called. If not specified
+    - session_id: determine which instance will be called. If not specified
         with a value other than -1, using host ip directly.
     - sequence_start (bool): indicator for starting a sequence.
     - sequence_end (bool): indicator for ending a sequence
@@ -275,13 +270,13 @@ async def generate(request: GenerateRequest, raw_request: Request = None):
         1.0 means no penalty
     - ignore_eos (bool): indicator for ignoring eos
     """
-    if request.instance_id == -1:
-        instance_id = int(raw_request.client.host.replace('.', ''))
-        request.instance_id = instance_id
+    if request.session_id == -1:
+        session_id = ip2id(raw_request.client.host)
+        request.session_id = session_id

     generation = VariableInterface.async_engine.generate(
         request.prompt,
-        request.instance_id,
+        request.session_id,
         stream_response=True,  # always use stream to enable batching
         sequence_start=request.sequence_start,
         sequence_end=request.sequence_end,
@@ -296,21 +291,26 @@ async def generate(request: GenerateRequest, raw_request: Request = None):
     # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:
         async for out in generation:
-            ret = {
-                'text': out.response,
-                'tokens': out.generate_token_len,
-                'finish_reason': out.finish_reason
-            }
-            yield (json.dumps(ret) + '\0').encode('utf-8')
+            chunk = GenerateResponse(text=out.response,
+                                     tokens=out.generate_token_len,
+                                     finish_reason=out.finish_reason)
+            data = chunk.model_dump_json()
+            yield f'{data}\n'

     if request.stream:
-        return StreamingResponse(stream_results())
+        return StreamingResponse(stream_results(),
+                                 media_type='text/event-stream')
     else:
         ret = {}
         text = ''
         tokens = 0
         finish_reason = None
         async for out in generation:
+            if await raw_request.is_disconnected():
+                # Abort the request if the client disconnects.
+                VariableInterface.async_engine.stop_session(session_id)
+                return create_error_response(HTTPStatus.BAD_REQUEST,
+                                             'Client disconnected')
             text += out.response
             tokens = out.generate_token_len
             finish_reason = out.finish_reason
...
@@ -190,7 +190,7 @@ class EmbeddingsResponse(BaseModel):
 class GenerateRequest(BaseModel):
     """Generate request."""
     prompt: Union[str, List[Dict[str, str]]]
-    instance_id: int = -1
+    session_id: int = -1
     sequence_start: bool = True
     sequence_end: bool = False
     stream: bool = False
@@ -201,3 +201,10 @@ class GenerateRequest(BaseModel):
     temperature: float = 0.8
     repetition_penalty: float = 1.0
     ignore_eos: bool = False
+
+
+class GenerateResponse(BaseModel):
+    """Generate response."""
+    text: str
+    tokens: int
+    finish_reason: Optional[Literal['stop', 'length']] = None
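
Each line emitted by the streaming /generate route above is the JSON dump of this GenerateResponse model, so a client can validate the stream back into typed objects. A small sketch, assuming pydantic v2 (the server already relies on model_dump_json):

    from lmdeploy.serve.openai.protocol import GenerateResponse

    line = '{"text": "Hello!", "tokens": 3, "finish_reason": null}'
    chunk = GenerateResponse.model_validate_json(line)  # pydantic v2 counterpart of model_dump_json
    print(chunk.text, chunk.tokens, chunk.finish_reason)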