Unverified Commit 759e1ddf authored by AllentDan, committed by GitHub

make IPv6 compatible, safe run for coroutine interrupting (#487)

* make IPv6 compatible, safe run for coroutine interrupting

* instance_id -> session_id and fix api_client.py

* update doc

* remove useless faq

* safe ip mapping

* update app.py

* remove print

* update doc
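
The "make IPv6 compatible" and "safe ip mapping" items above refer to the new `ip2id` helper added to `api_server.py` later in this diff. It replaces `int(host_ip.replace('.', ''))`, which raises `ValueError` for IPv6 hosts such as `::1`. A minimal sketch of the mapping, with illustrative sample addresses:

```python
def ip2id(host_ip: str) -> int:
    """Map a client ip address to a numeric session id (mirrors the
    `ip2id` helper this commit adds to api_server.py)."""
    if '.' in host_ip:  # IPv4: drop the dots, keep at most 8 trailing digits
        return int(host_ip.replace('.', '')[-8:])
    if ':' in host_ip:  # IPv6: drop the colons, parse trailing hex digits
        return int(host_ip.replace(':', '')[-8:], 16)
    return 0  # fall back to session id 0 for unrecognized addresses


# Illustrative addresses, not taken from the commit:
assert ip2id('127.0.0.1') == 127001
assert ip2id('::1') == 1             # previously raised ValueError
print(ip2id('2001:db8::abcd'))       # a 32-bit id derived from the hex tail
```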
parent fbd9770a
......@@ -14,7 +14,7 @@ from lmdeploy.utils import get_logger
def get_streaming_response(prompt: str,
api_url: str,
instance_id: int,
session_id: int,
request_output_len: int,
stream: bool = True,
sequence_start: bool = True,
......@@ -24,7 +24,7 @@ def get_streaming_response(prompt: str,
pload = {
'prompt': prompt,
'stream': stream,
'instance_id': instance_id,
'session_id': session_id,
'request_output_len': request_output_len,
'sequence_start': sequence_start,
'sequence_end': sequence_end,
......@@ -36,7 +36,7 @@ def get_streaming_response(prompt: str,
stream=stream)
for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b'\0'):
delimiter=b'\n'):
if chunk:
data = json.loads(chunk.decode('utf-8'))
output = data['text']
......
......@@ -22,7 +22,7 @@ from typing import Iterable, List
def get_streaming_response(prompt: str,
api_url: str,
instance_id: int,
session_id: int,
request_output_len: int,
stream: bool = True,
sequence_start: bool = True,
......@@ -32,7 +32,7 @@ def get_streaming_response(prompt: str,
pload = {
'prompt': prompt,
'stream': stream,
'instance_id': instance_id,
'session_id': session_id,
'request_output_len': request_output_len,
'sequence_start': sequence_start,
'sequence_end': sequence_end,
......@@ -41,7 +41,7 @@ def get_streaming_response(prompt: str,
response = requests.post(
api_url, headers=headers, json=pload, stream=stream)
for chunk in response.iter_lines(
chunk_size=8192, decode_unicode=False, delimiter=b'\0'):
chunk_size=8192, decode_unicode=False, delimiter=b'\n'):
if chunk:
data = json.loads(chunk.decode('utf-8'))
output = data['text']
......@@ -91,7 +91,7 @@ curl http://{server_ip}:{server_port}/generate \
-H "Content-Type: application/json" \
-d '{
"prompt": "Hello! How are you?",
"instance_id": 1,
"session_id": 1,
"sequence_start": true,
"sequence_end": true
}'
......@@ -146,11 +146,10 @@ python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True
2. If OOM occurs on the server side, please reduce `instance_num` when launching the service.
3. When a request with the same `instance_id` to `generate` gets an empty return value and a negative `tokens`, please consider setting `sequence_start=false` for the second question and for the ones afterwards.
3. When a request with the same `session_id` to `generate` gets an empty return value and a negative `tokens`, please consider setting `sequence_start=false` for the second question and for the ones afterwards.
4. Requests were previously being handled sequentially rather than concurrently. To resolve this issue,
- kindly provide unique instance_id values when calling the `generate` API; otherwise your requests will fall back to ids derived from the client IP address
- additionally, setting `stream=true` enables processing multiple requests simultaneously
- kindly provide unique session_id values when calling the `generate` API; otherwise your requests will fall back to ids derived from the client IP address (see the client sketch after this list)
5. Both the `generate` api and `v1/chat/completions` support multi-round conversation: the input `prompt` or `messages` may be either a single string or an entire chat history, and such inputs are interpreted in multi-turn dialogue mode. However, if you want to turn this mode off and manage the chat history on the client side, please pass the parameter `sequence_end: true` when calling `generate`, or specify `renew_session: true` when using `v1/chat/completions`.
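
A minimal client sketch tying items 3-5 together. The server address is a placeholder, and the helper simply mirrors `get_streaming_response` from `lmdeploy.serve.openai.api_client` after this change (newline-delimited JSON, `session_id` instead of `instance_id`):

```python
import json

import requests

api_url = 'http://{server_ip}:{server_port}/generate'  # placeholder address


def stream_generate(prompt, session_id, sequence_start, sequence_end=False):
    # The server now delimits chunks with '\n', one JSON object per line.
    payload = {
        'prompt': prompt,
        'session_id': session_id,      # was `instance_id` before this change
        'stream': True,
        'request_output_len': 512,
        'sequence_start': sequence_start,
        'sequence_end': sequence_end,
    }
    response = requests.post(api_url, json=payload, stream=True)
    for chunk in response.iter_lines(chunk_size=8192,
                                     decode_unicode=False,
                                     delimiter=b'\n'):
        if chunk:
            data = json.loads(chunk.decode('utf-8'))
            yield data['text'], data['tokens'], data['finish_reason']


# First turn of a session uses sequence_start=True; later turns reuse the same
# session_id with sequence_start=False so the server keeps the history.
for text, tokens, finish_reason in stream_generate('Hello! How are you?',
                                                   session_id=1,
                                                   sequence_start=True):
    print(text, end='', flush=True)
```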
......@@ -24,7 +24,7 @@ from typing import Iterable, List
def get_streaming_response(prompt: str,
api_url: str,
instance_id: int,
session_id: int,
request_output_len: int,
stream: bool = True,
sequence_start: bool = True,
......@@ -34,7 +34,7 @@ def get_streaming_response(prompt: str,
pload = {
'prompt': prompt,
'stream': stream,
'instance_id': instance_id,
'session_id': session_id,
'request_output_len': request_output_len,
'sequence_start': sequence_start,
'sequence_end': sequence_end,
......@@ -43,7 +43,7 @@ def get_streaming_response(prompt: str,
response = requests.post(
api_url, headers=headers, json=pload, stream=stream)
for chunk in response.iter_lines(
chunk_size=8192, decode_unicode=False, delimiter=b'\0'):
chunk_size=8192, decode_unicode=False, delimiter=b'\n'):
if chunk:
data = json.loads(chunk.decode('utf-8'))
output = data['text']
......@@ -93,7 +93,7 @@ curl http://{server_ip}:{server_port}/generate \
-H "Content-Type: application/json" \
-d '{
"prompt": "Hello! How are you?",
"instance_id": 1,
"session_id": 1,
"sequence_start": true,
"sequence_end": true
}'
......@@ -148,12 +148,11 @@ python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True
2. If the server runs out of GPU memory (OOM), reduce the `instance_num` used when launching the service.
3. If requests with the same `instance_id` sent to `generate` return an empty string and a negative `tokens`, the second question was most likely sent without `sequence_start=false`.
3. If requests with the same `session_id` sent to `generate` return an empty string and a negative `tokens`, the second question was most likely sent without `sequence_start=false`.
4. If requests seem to be processed one by one instead of concurrently, please set the following parameters:
- pass a different instance_id to the `generate` api for each session; otherwise the session id is automatically bound to a number derived from the client's ip address
- set `stream=true` so that other requests can be accepted while the model is running a forward pass
- pass a different session_id to the `generate` api for each session; otherwise the session id is automatically bound to a number derived from the client's ip address
5. Both the `generate` api and `v1/chat/completions` support multi-round conversation. The `messages` or `prompt` argument can be either a plain string carrying a single user question or an entire chat history.
Both apis enable multi-turn dialogue by default. If you want to turn this off and manage the chat history on the client side, pass `sequence_end: true` to `generate`, or set `renew_session: true` for `v1/chat/completions`, as sketched below.
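
A minimal sketch of the client-side session management mentioned in item 5. The server address is a placeholder, and the payload mirrors the reset logic in `lmdeploy/serve/gradio/app.py` from this diff: an empty prompt with `request_output_len=0` and `sequence_end: true` ends the session, so the client can keep or discard the chat history itself.

```python
import requests

api_url = 'http://{server_ip}:{server_port}/generate'  # placeholder address


def end_session(session_id: int):
    # Mirrors reset_restful_func in lmdeploy/serve/gradio/app.py: closing a
    # session is just a `generate` call that produces no tokens and sets
    # sequence_end=True, after which the server forgets the history.
    payload = {
        'prompt': '',
        'session_id': session_id,
        'request_output_len': 0,
        'sequence_start': False,
        'sequence_end': True,
    }
    requests.post(api_url, json=payload)


end_session(session_id=1)
```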
......
......@@ -47,13 +47,30 @@ class AsyncEngine:
self.starts = [None] * instance_num
self.steps = {}
def stop_session(self, session_id: int):
instance_id = session_id % self.instance_num
input_ids = self.tokenizer.encode('')
for outputs in self.generators[instance_id].stream_infer(
session_id,
input_ids,
request_output_len=0,
sequence_start=False,
sequence_end=False,
stop=True):
pass
self.available[instance_id] = True
@contextmanager
def safe_run(self, instance_id: int, stop: bool = False):
def safe_run(self, instance_id: int, session_id: Optional[int] = None):
self.available[instance_id] = False
try:
yield
except (Exception, asyncio.CancelledError) as e: # noqa
self.stop_session(session_id)
self.available[instance_id] = True
async def get_embeddings(self, prompt):
async def get_embeddings(self, prompt, do_prerpocess=False):
if do_prerpocess:
prompt = self.model.get_prompt(prompt)
input_ids = self.tokenizer.encode(prompt)
return input_ids
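
A condensed, runnable illustration of the "safe run for coroutine interrupting" pattern that the hunk above adds; the `ToyEngine` names are illustrative only, not lmdeploy's API. The context manager marks an instance busy, and if the wrapped coroutine is interrupted it stops the session and releases the instance instead of leaving it flagged unavailable.

```python
import asyncio
from contextlib import contextmanager


class ToyEngine:
    """Illustrative stand-in for AsyncEngine, not the real class."""

    def __init__(self, instance_num: int = 2):
        self.instance_num = instance_num
        self.available = [True] * instance_num

    def stop_session(self, session_id):
        # The real engine sends a stop request to the inference instance here.
        print(f'stopping session {session_id}')

    @contextmanager
    def safe_run(self, instance_id: int, session_id=None):
        self.available[instance_id] = False
        try:
            yield
        except (Exception, asyncio.CancelledError):  # noqa
            # CancelledError is not an Exception subclass in Python >= 3.8,
            # so it must be listed explicitly to catch an interrupted await.
            self.stop_session(session_id)
        self.available[instance_id] = True


async def toy_generate(engine: ToyEngine, session_id: int):
    instance_id = session_id % engine.instance_num
    with engine.safe_run(instance_id, session_id):
        await asyncio.sleep(10)  # stands in for async_stream_infer


async def main():
    engine = ToyEngine()
    task = asyncio.create_task(toy_generate(engine, session_id=1))
    await asyncio.sleep(0.1)
    task.cancel()  # e.g. the HTTP client disconnected mid-generation
    await asyncio.gather(task, return_exceptions=True)
    print(engine.available)  # [True, True]: the instance was released


asyncio.run(main())
```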
......@@ -68,7 +85,7 @@ class AsyncEngine:
async def generate(
self,
messages,
instance_id,
session_id,
stream_response=True,
sequence_start=True,
sequence_end=False,
......@@ -85,7 +102,7 @@ class AsyncEngine:
Args:
messages (str | List): chat history or prompt
instance_id (int): actually request host ip
session_id (int): the session id
stream_response (bool): whether to return responses in a streaming fashion
request_output_len (int): output token nums
sequence_start (bool): indicator for starting a sequence
......@@ -102,8 +119,7 @@ class AsyncEngine:
1.0 means no penalty
ignore_eos (bool): indicator for ignoring eos
"""
session_id = instance_id
instance_id %= self.instance_num
instance_id = session_id % self.instance_num
if str(session_id) not in self.steps:
self.steps[str(session_id)] = 0
if step != 0:
......@@ -119,7 +135,7 @@ class AsyncEngine:
finish_reason)
else:
generator = await self.get_generator(instance_id, stop)
with self.safe_run(instance_id):
with self.safe_run(instance_id, session_id):
response_size = 0
async for outputs in generator.async_stream_infer(
session_id=session_id,
......@@ -188,14 +204,14 @@ class AsyncEngine:
instance_id %= self.instance_num
sequence_start = False
generator = await self.get_generator(instance_id)
self.available[instance_id] = False
if renew_session: # renew a session
empty_input_ids = self.tokenizer.encode('')
for outputs in generator.stream_infer(session_id=session_id,
input_ids=[empty_input_ids],
request_output_len=0,
sequence_start=False,
sequence_end=True):
sequence_end=True,
stop=True):
pass
self.steps[str(session_id)] = 0
if str(session_id) not in self.steps:
......@@ -212,6 +228,7 @@ class AsyncEngine:
yield GenOut('', self.steps[str(session_id)], len(input_ids), 0,
finish_reason)
else:
with self.safe_run(instance_id, session_id):
response_size = 0
async for outputs in generator.async_stream_infer(
session_id=session_id,
......@@ -232,11 +249,10 @@ class AsyncEngine:
# decode res
response = self.tokenizer.decode(res.tolist(),
offset=response_size)
# response, history token len, input token len, gen token len
# response, history len, input len, generation len
yield GenOut(response, self.steps[str(session_id)],
len(input_ids), tokens, finish_reason)
response_size = tokens
# update step
self.steps[str(session_id)] += len(input_ids) + tokens
self.available[instance_id] = True
......@@ -12,6 +12,7 @@ from lmdeploy.serve.async_engine import AsyncEngine
from lmdeploy.serve.gradio.css import CSS
from lmdeploy.serve.openai.api_client import (get_model_list,
get_streaming_response)
from lmdeploy.serve.openai.api_server import ip2id
from lmdeploy.serve.turbomind.chatbot import Chatbot
THEME = gr.themes.Soft(
......@@ -37,7 +38,7 @@ def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
instruction = state_chatbot[-1][0]
session_id = threading.current_thread().ident
if request is not None:
session_id = int(request.kwargs['client']['host'].replace('.', ''))
session_id = ip2id(request.kwargs['client']['host'])
bot_response = llama_chatbot.stream_infer(
session_id, instruction, f'{session_id}-{len(state_chatbot)}')
......@@ -166,7 +167,7 @@ def chat_stream_restful(
"""
session_id = threading.current_thread().ident
if request is not None:
session_id = int(request.kwargs['client']['host'].replace('.', ''))
session_id = ip2id(request.kwargs['client']['host'])
bot_summarized_response = ''
state_chatbot = state_chatbot + [(instruction, None)]
......@@ -176,7 +177,7 @@ def chat_stream_restful(
for response, tokens, finish_reason in get_streaming_response(
instruction,
f'{InterFace.restful_api_url}/generate',
instance_id=session_id,
session_id=session_id,
request_output_len=512,
sequence_start=(len(state_chatbot) == 1),
sequence_end=False):
......@@ -212,12 +213,12 @@ def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
session_id = threading.current_thread().ident
if request is not None:
session_id = int(request.kwargs['client']['host'].replace('.', ''))
session_id = ip2id(request.kwargs['client']['host'])
# end the session
for response, tokens, finish_reason in get_streaming_response(
'',
f'{InterFace.restful_api_url}/generate',
instance_id=session_id,
session_id=session_id,
request_output_len=0,
sequence_start=False,
sequence_end=True):
......@@ -241,11 +242,11 @@ def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button,
"""
session_id = threading.current_thread().ident
if request is not None:
session_id = int(request.kwargs['client']['host'].replace('.', ''))
session_id = ip2id(request.kwargs['client']['host'])
# end the session
for out in get_streaming_response('',
f'{InterFace.restful_api_url}/generate',
instance_id=session_id,
session_id=session_id,
request_output_len=0,
sequence_start=False,
sequence_end=False,
......@@ -259,7 +260,7 @@ def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button,
messages.append(dict(role='assistant', content=qa[1]))
for out in get_streaming_response(messages,
f'{InterFace.restful_api_url}/generate',
instance_id=session_id,
session_id=session_id,
request_output_len=0,
sequence_start=True,
sequence_end=False):
......@@ -346,7 +347,7 @@ async def chat_stream_local(
"""
session_id = threading.current_thread().ident
if request is not None:
session_id = int(request.kwargs['client']['host'].replace('.', ''))
session_id = ip2id(request.kwargs['client']['host'])
bot_summarized_response = ''
state_chatbot = state_chatbot + [(instruction, None)]
......@@ -391,7 +392,7 @@ async def reset_local_func(instruction_txtbox: gr.Textbox,
session_id = threading.current_thread().ident
if request is not None:
session_id = int(request.kwargs['client']['host'].replace('.', ''))
session_id = ip2id(request.kwargs['client']['host'])
# end the session
async for out in InterFace.async_engine.generate('',
session_id,
......@@ -419,7 +420,7 @@ async def cancel_local_func(state_chatbot: gr.State, cancel_btn: gr.Button,
"""
session_id = threading.current_thread().ident
if request is not None:
session_id = int(request.kwargs['client']['host'].replace('.', ''))
session_id = ip2id(request.kwargs['client']['host'])
# end the session
async for out in InterFace.async_engine.generate('',
session_id,
......
......@@ -17,7 +17,7 @@ def get_model_list(api_url: str):
def get_streaming_response(prompt: str,
api_url: str,
instance_id: int,
session_id: int,
request_output_len: int = 512,
stream: bool = True,
sequence_start: bool = True,
......@@ -28,7 +28,7 @@ def get_streaming_response(prompt: str,
pload = {
'prompt': prompt,
'stream': stream,
'instance_id': instance_id,
'session_id': session_id,
'request_output_len': request_output_len,
'sequence_start': sequence_start,
'sequence_end': sequence_end,
......@@ -41,7 +41,7 @@ def get_streaming_response(prompt: str,
stream=stream)
for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b'\0'):
delimiter=b'\n'):
if chunk:
data = json.loads(chunk.decode('utf-8'))
output = data.pop('text', '')
......@@ -62,12 +62,20 @@ def main(restful_api_url: str, session_id: int = 0):
while True:
prompt = input_prompt()
if prompt == 'exit':
for output, tokens, finish_reason in get_streaming_response(
'',
f'{restful_api_url}/generate',
session_id=session_id,
request_output_len=0,
sequence_start=(nth_round == 1),
sequence_end=True):
pass
exit(0)
else:
for output, tokens, finish_reason in get_streaming_response(
prompt,
f'{restful_api_url}/generate',
instance_id=session_id,
session_id=session_id,
request_output_len=512,
sequence_start=(nth_round == 1),
sequence_end=False):
......
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os
import time
from http import HTTPStatus
......@@ -7,7 +6,7 @@ from typing import AsyncGenerator, List, Optional
import fire
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
......@@ -16,8 +15,8 @@ from lmdeploy.serve.openai.protocol import ( # noqa: E501
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, EmbeddingsRequest,
EmbeddingsResponse, ErrorResponse, GenerateRequest, ModelCard, ModelList,
ModelPermission, UsageInfo)
EmbeddingsResponse, ErrorResponse, GenerateRequest, GenerateResponse,
ModelCard, ModelList, ModelPermission, UsageInfo)
os.environ['TM_LOG_LEVEL'] = 'ERROR'
......@@ -73,6 +72,16 @@ async def check_request(request) -> Optional[JSONResponse]:
return ret
def ip2id(host_ip: str):
"""Convert host ip address to session id."""
if '.' in host_ip: # IPv4
return int(host_ip.replace('.', '')[-8:])
if ':' in host_ip: # IPv6
return int(host_ip.replace(':', '')[-8:], 16)
print('Warning: could not get a session id from the ip, falling back to 0')
return 0
@app.post('/v1/chat/completions')
async def chat_completions_v1(request: ChatCompletionRequest,
raw_request: Request = None):
......@@ -106,19 +115,18 @@ async def chat_completions_v1(request: ChatCompletionRequest,
- presence_penalty (replaced with repetition_penalty)
- frequency_penalty (replaced with repetition_penalty)
"""
instance_id = int(raw_request.client.host.replace('.', ''))
session_id = ip2id(raw_request.client.host)
error_check_ret = await check_request(request)
if error_check_ret is not None:
return error_check_ret
model_name = request.model
request_id = str(instance_id)
request_id = str(session_id)
created_time = int(time.time())
result_generator = VariableInterface.async_engine.generate_openai(
request.messages,
instance_id,
session_id,
True, # always use stream to enable batching
request.renew_session,
request_output_len=request.max_tokens if request.max_tokens else 512,
......@@ -128,15 +136,6 @@ async def chat_completions_v1(request: ChatCompletionRequest,
repetition_penalty=request.repetition_penalty,
ignore_eos=request.ignore_eos)
async def abort_request() -> None:
async for _ in VariableInterface.async_engine.generate_openai(
request.messages,
instance_id,
True,
request.renew_session,
stop=True):
pass
def create_stream_response_json(
index: int,
text: str,
......@@ -181,12 +180,8 @@ async def chat_completions_v1(request: ChatCompletionRequest,
# Streaming response
if request.stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(completion_stream_generator(),
media_type='text/event-stream',
background=background_tasks)
media_type='text/event-stream')
# Non-streaming response
final_res = None
......@@ -194,7 +189,7 @@ async def chat_completions_v1(request: ChatCompletionRequest,
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await abort_request()
VariableInterface.async_engine.stop_session(session_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
'Client disconnected')
final_res = res
......@@ -257,7 +252,7 @@ async def generate(request: GenerateRequest, raw_request: Request = None):
The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- instance_id: determines which instance will be called. If not specified
- session_id: determines which instance will be called. If not specified
with a value other than -1, an id derived from the host ip is used.
- sequence_start (bool): indicator for starting a sequence.
- sequence_end (bool): indicator for ending a sequence
......@@ -275,13 +270,13 @@ async def generate(request: GenerateRequest, raw_request: Request = None):
1.0 means no penalty
- ignore_eos (bool): indicator for ignoring eos
"""
if request.instance_id == -1:
instance_id = int(raw_request.client.host.replace('.', ''))
request.instance_id = instance_id
if request.session_id == -1:
session_id = ip2id(raw_request.client.host)
request.session_id = session_id
generation = VariableInterface.async_engine.generate(
request.prompt,
request.instance_id,
request.session_id,
stream_response=True, # always use stream to enable batching
sequence_start=request.sequence_start,
sequence_end=request.sequence_end,
......@@ -296,21 +291,26 @@ async def generate(request: GenerateRequest, raw_request: Request = None):
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
async for out in generation:
ret = {
'text': out.response,
'tokens': out.generate_token_len,
'finish_reason': out.finish_reason
}
yield (json.dumps(ret) + '\0').encode('utf-8')
chunk = GenerateResponse(text=out.response,
tokens=out.generate_token_len,
finish_reason=out.finish_reason)
data = chunk.model_dump_json()
yield f'{data}\n'
if request.stream:
return StreamingResponse(stream_results())
return StreamingResponse(stream_results(),
media_type='text/event-stream')
else:
ret = {}
text = ''
tokens = 0
finish_reason = None
async for out in generation:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
VariableInterface.async_engine.stop_session(session_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
'Client disconnected')
text += out.response
tokens = out.generate_token_len
finish_reason = out.finish_reason
......
......@@ -190,7 +190,7 @@ class EmbeddingsResponse(BaseModel):
class GenerateRequest(BaseModel):
"""Generate request."""
prompt: Union[str, List[Dict[str, str]]]
instance_id: int = -1
session_id: int = -1
sequence_start: bool = True
sequence_end: bool = False
stream: bool = False
......@@ -201,3 +201,10 @@ class GenerateRequest(BaseModel):
temperature: float = 0.8
repetition_penalty: float = 1.0
ignore_eos: bool = False
class GenerateResponse(BaseModel):
"""Generate response."""
text: str
tokens: int
finish_reason: Optional[Literal['stop', 'length']] = None