Unverified Commit 4279d8ca authored by AllentDan, committed by GitHub

Enable the Gradio server to call inference services through the RESTful API (#287)



* app use async engine

* add stop logic

* app update cancel

* app support restful-api

* update doc and use the right model name

* set doc url root

* add comments

* add an example

* renew_session

* update readme.md

* resolve comments

* Update restful_api.md

* Update restful_api.md

* Update restful_api.md

---------
Co-authored-by: tpoisonooo <khj.application@aliyun.com>
parent 81f29837
...@@ -133,6 +133,32 @@ python3 -m lmdeploy.serve.gradio.app ./workspace

![](https://github.com/InternLM/lmdeploy/assets/67539920/08d1e6f2-3767-44d5-8654-c85767cec2ab)

#### Serving with RESTful API

Launch the inference server with:

```shell
python3 -m lmdeploy.serve.openai.api_server ./workspace server_ip server_port --instance_num 32 --tp 1
```

Then you can communicate with it from the command line:

```shell
# restful_api_url is the URL printed by api_server.py, e.g. http://localhost:23333
python -m lmdeploy.serve.openai.api_client restful_api_url
```

or through the web UI:

```shell
# restful_api_url is the URL printed by api_server.py, e.g. http://localhost:23333
# server_ip and server_port here are for the gradio UI
# example: python -m lmdeploy.serve.gradio.app http://localhost:23333 localhost 6006 --restful_api True
python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True
```

Refer to [restful_api.md](docs/en/restful_api.md) for more details.
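
Once the server is up, any HTTP client can talk to the OpenAI-style routes it exposes. Below is a minimal sketch with `requests`; the address `http://localhost:23333` and the model name `internlm-chat-7b` are assumptions taken from the examples in this commit, and the response is parsed assuming the standard OpenAI chat-completion schema.

```python
# Minimal sketch: query the OpenAI-style /v1/chat/completions route with `requests`.
# The URL and model name are assumptions; adjust them to your deployment.
import requests

url = 'http://localhost:23333/v1/chat/completions'
payload = {
    'model': 'internlm-chat-7b',
    'messages': [{'role': 'user', 'content': 'Hi, how are you?'}],
}
resp = requests.post(url, json=payload, timeout=60)
resp.raise_for_status()
# assuming the OpenAI chat-completion response layout
print(resp.json()['choices'][0]['message']['content'])
```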
#### Serving with Triton Inference Server

Launch the inference server with:
......
...@@ -133,6 +133,32 @@ python3 -m lmdeploy.serve.gradio.app ./workspace

![](https://github.com/InternLM/lmdeploy/assets/67539920/08d1e6f2-3767-44d5-8654-c85767cec2ab)

#### Serving with RESTful API

Launch the inference server with the following command:

```shell
python3 -m lmdeploy.serve.openai.api_server ./workspace server_ip server_port --instance_num 32 --tp 1
```

You can then chat with the inference service from the command line:

```shell
# restful_api_url is the URL printed by api_server.py, e.g. http://localhost:23333
python -m lmdeploy.serve.openai.api_client restful_api_url
```

or through the web UI:

```shell
# restful_api_url is the URL printed by api_server.py, e.g. http://localhost:23333
# server_ip and server_port here are for the gradio UI
# example: python -m lmdeploy.serve.gradio.app http://localhost:23333 localhost 6006 --restful_api True
python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True
```

See [restful_api.md](docs/zh_cn/restful_api.md) for more details.

#### Deploying the inference service with a container

Launch the inference service with the following command:
......
...@@ -3,10 +3,10 @@

### Launch Service

```shell
python3 -m lmdeploy.serve.openai.api_server ./workspace 0.0.0.0 server_port --instance_num 32 --tp 1
```

Then, the user can open the Swagger UI at `http://{server_ip}:{server_port}` for detailed API usage.
We provide four RESTful APIs in total. Three of them follow the OpenAI format. However, we recommend trying
our own `generate` API, which exposes more arguments for users to tune and performs comparatively better.
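
For orientation, the sketch below talks to `/generate` over raw HTTP with `requests`, mirroring what the `get_streaming_response` helper in `lmdeploy/serve/openai/api_client.py` does internally (responses are streamed as `\0`-delimited JSON chunks). The server address is an assumption; the field names follow `GenerateRequest` in `lmdeploy/serve/openai/protocol.py`.

```python
# Hedged sketch: stream tokens from /generate with plain `requests`.
# http://localhost:23333 is an assumed address; adjust to your deployment.
import json

import requests

payload = {
    'prompt': 'Hi, how are you?',
    'instance_id': 0,
    'sequence_start': True,
    'sequence_end': True,
    'stream': True,
    'request_output_len': 512,
}
resp = requests.post('http://localhost:23333/generate', json=payload, stream=True)
# the server separates streamed JSON chunks with a NUL byte
for chunk in resp.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b'\0'):
    if chunk:
        data = json.loads(chunk.decode('utf-8'))
        print(data.get('text', ''), end='', flush=True)
```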
...@@ -50,16 +50,29 @@ def get_streaming_response(prompt: str,

for output, tokens in get_streaming_response(
        "Hi, how are you?", "http://{server_ip}:{server_port}/generate", 0,
        512):
    print(output, end='')
```
### Java/Golang/Rust

You can use [openapi-generator-cli](https://github.com/OpenAPITools/openapi-generator-cli) to convert `http://{server_ip}:{server_port}/openapi.json` into a Java/Rust/Golang client.
Here is an example:

```shell
$ docker run -it --rm -v ${PWD}:/local openapitools/openapi-generator-cli generate -i /local/openapi.json -g rust -o /local/rust

$ ls rust/*
rust/Cargo.toml  rust/git_push.sh  rust/README.md

rust/docs:
ChatCompletionRequest.md  EmbeddingsRequest.md  HttpValidationError.md  LocationInner.md  Prompt.md
DefaultApi.md             GenerateRequest.md    Input.md                Messages.md       ValidationError.md

rust/src:
apis  lib.rs  models
```
### cURL

...@@ -68,13 +81,13 @@ cURL is a tool for observing the output of the api.

List Models:

```bash
curl http://{server_ip}:{server_port}/v1/models
```

Generate:

```bash
curl http://{server_ip}:{server_port}/generate \
  -H "Content-Type: application/json" \
  -d '{
    "model": "internlm-chat-7b",
...@@ -87,7 +100,7 @@ curl http://{server_name}:{server_port}/generate \

Chat Completions:

```bash
curl http://{server_ip}:{server_port}/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "internlm-chat-7b",
...@@ -98,7 +111,7 @@ curl http://{server_name}:{server_port}/v1/chat/completions \

Embeddings:

```bash
curl http://{server_ip}:{server_port}/v1/embeddings \
  -H "Content-Type: application/json" \
  -d '{
    "model": "internlm-chat-7b",
...@@ -106,6 +119,26 @@ curl http://{server_name}:{server_port}/v1/embeddings \
  }'
```
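
The same calls can be made from Python. The sketch below lists the served models with the `get_model_list` helper added in `lmdeploy/serve/openai/api_client.py` and then posts a chat-completion request with `requests`; the server address is an assumption and the response layout is the standard OpenAI one.

```python
# Hedged Python counterpart of the curl calls above.
# http://localhost:23333 is an assumed address; adjust to your deployment.
import requests

from lmdeploy.serve.openai.api_client import get_model_list

base = 'http://localhost:23333'
model_names = get_model_list(f'{base}/v1/models')  # e.g. ['internlm-chat-7b']
print(model_names)

resp = requests.post(f'{base}/v1/chat/completions',
                     json={
                         'model': model_names[0],  # assumes at least one model is served
                         'messages': [{'role': 'user', 'content': 'Hello!'}]
                     })
print(resp.json())
```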
### CLI client

There is a client script for the RESTful API server.

```shell
# restful_api_url is the URL printed by api_server.py, e.g. http://localhost:23333
python -m lmdeploy.serve.openai.api_client restful_api_url
```

### webui

You can also test the RESTful API through the web UI.

```shell
# restful_api_url is the URL printed by api_server.py, e.g. http://localhost:23333
# server_ip and server_port here are for the gradio UI
# example: python -m lmdeploy.serve.gradio.app http://localhost:23333 localhost 6006 --restful_api True
python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True
```
### FAQ

1. If you get `"finish_reason":"length"`, the session is too long to be continued.
......
...@@ -5,10 +5,10 @@

Run the script:

```shell
python3 -m lmdeploy.serve.openai.api_server ./workspace 0.0.0.0 server_port --instance_num 32 --tp 1
```

Then you can open the Swagger UI at `http://{server_ip}:{server_port}` to browse all the APIs and how to use them.
We provide four RESTful APIs in total. Three of them follow the OpenAI format. However, we recommend the `generate` API we provide:
it performs better and exposes more arguments for customization.
...@@ -52,15 +52,29 @@ def get_streaming_response(prompt: str,

for output, tokens in get_streaming_response(
        "Hi, how are you?", "http://{server_ip}:{server_port}/generate", 0,
        512):
    print(output, end='')
```

### Java/Golang/Rust

You can use the code-generation tool [openapi-generator-cli](https://github.com/OpenAPITools/openapi-generator-cli) to convert `http://{server_ip}:{server_port}/openapi.json` into a Java/Rust/Golang client.
Here is a usage example:

```shell
$ docker run -it --rm -v ${PWD}:/local openapitools/openapi-generator-cli generate -i /local/openapi.json -g rust -o /local/rust

$ ls rust/*
rust/Cargo.toml  rust/git_push.sh  rust/README.md

rust/docs:
ChatCompletionRequest.md  EmbeddingsRequest.md  HttpValidationError.md  LocationInner.md  Prompt.md
DefaultApi.md             GenerateRequest.md    Input.md                Messages.md       ValidationError.md

rust/src:
apis  lib.rs  models
```

### cURL

...@@ -69,13 +83,13 @@ cURL can also be used to inspect the API output.

List models:

```bash
curl http://{server_ip}:{server_port}/v1/models
```

Use generate:

```bash
curl http://{server_ip}:{server_port}/generate \
  -H "Content-Type: application/json" \
  -d '{
    "model": "internlm-chat-7b",
...@@ -88,7 +102,7 @@ curl http://{server_name}:{server_port}/generate \

Chat Completions:

```bash
curl http://{server_ip}:{server_port}/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "internlm-chat-7b",
...@@ -99,7 +113,7 @@ curl http://{server_name}:{server_port}/v1/chat/completions \

Embeddings:

```bash
curl http://{server_ip}:{server_port}/v1/embeddings \
  -H "Content-Type: application/json" \
  -d '{
    "model": "internlm-chat-7b",
...@@ -107,6 +121,26 @@ curl http://{server_name}:{server_port}/v1/embeddings \
  }'
```

### CLI client

The RESTful API server can be tested with the client script, for example:

```shell
# restful_api_url is the URL printed by api_server.py, e.g. http://localhost:23333
python -m lmdeploy.serve.openai.api_client restful_api_url
```

### webui

You can also test the RESTful API directly through the web UI.

```shell
# restful_api_url is the URL printed by api_server.py, e.g. http://localhost:23333
# server_ip and server_port here are for the gradio UI
# example: python -m lmdeploy.serve.gradio.app http://localhost:23333 localhost 6006 --restful_api True
python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True
```

### FAQ

1. If the result ends with `"finish_reason":"length"`, the session length has exceeded the maximum.
......
...@@ -3,11 +3,11 @@ import asyncio

import dataclasses
import os.path as osp
import random
from contextlib import contextmanager
from typing import Literal, Optional

from lmdeploy import turbomind as tm
from lmdeploy.model import MODELS, BaseModel


@dataclasses.dataclass
...@@ -30,6 +30,7 @@ class AsyncEngine:
    """

    def __init__(self, model_path, instance_num=32, tp=1) -> None:
        from lmdeploy.turbomind.tokenizer import Tokenizer
        tokenizer_model_path = osp.join(model_path, 'triton_models',
                                        'tokenizer')
        tokenizer = Tokenizer(tokenizer_model_path)
...@@ -46,13 +47,20 @@ class AsyncEngine:
        self.starts = [None] * instance_num
        self.steps = {}

    @contextmanager
    def safe_run(self, instance_id: int, stop: bool = False):
        self.available[instance_id] = False
        yield
        self.available[instance_id] = True
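
One thing worth noting about `safe_run` as written: if the generation inside the `with` block raises, control never returns past the `yield`, so the instance stays marked unavailable. A hedged variant (not what this commit adds) that releases the instance in a `finally` block would look like this:

```python
# Hedged variant of safe_run, not part of this commit: a try/finally keeps
# `available[instance_id]` consistent even when the wrapped generation raises.
from contextlib import contextmanager


@contextmanager
def safe_run(self, instance_id: int, stop: bool = False):
    self.available[instance_id] = False
    try:
        yield
    finally:
        self.available[instance_id] = True
```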
    async def get_embeddings(self, prompt):
        prompt = self.model.get_prompt(prompt)
        input_ids = self.tokenizer.encode(prompt)
        return input_ids

    async def get_generator(self, instance_id: int, stop: bool = False):
        """Only return the model instance if it is available."""
        if not stop:
            while self.available[instance_id] is False:
                await asyncio.sleep(0.1)
        return self.generators[instance_id]
...@@ -104,14 +112,14 @@ class AsyncEngine:
        prompt = self.model.messages2prompt(messages, sequence_start)
        input_ids = self.tokenizer.encode(prompt)
        finish_reason = 'stop' if stop else None
        if not sequence_end and self.steps[str(session_id)] + len(
                input_ids) >= self.tm_model.session_len:
            finish_reason = 'length'
            yield GenOut('', self.steps[str(session_id)], len(input_ids), 0,
                         finish_reason)
        else:
            generator = await self.get_generator(instance_id, stop)
            with self.safe_run(instance_id):
                response_size = 0
                async for outputs in generator.async_stream_infer(
                        session_id=session_id,
...@@ -130,17 +138,17 @@ class AsyncEngine:
                        random_seed=seed if sequence_start else None):
                    res, tokens = outputs[0]
                    # decode res
                    response = self.tokenizer.decode(res)[response_size:]
                    # response, history token len,
                    # input token len, gen token len
                    yield GenOut(response, self.steps[str(session_id)],
                                 len(input_ids), tokens, finish_reason)
                    response_size += len(response)
                # update step
                self.steps[str(session_id)] += len(input_ids) + tokens
                if sequence_end or stop:
                    self.steps[str(session_id)] = 0

    async def generate_openai(
            self,
...@@ -180,13 +188,11 @@ class AsyncEngine:
            sequence_start = False
        generator = await self.get_generator(instance_id)
        self.available[instance_id] = False
        if renew_session:  # renew a session
            empty_input_ids = self.tokenizer.encode('')
            for outputs in generator.stream_infer(session_id=session_id,
                                                  input_ids=[empty_input_ids],
                                                  request_output_len=0,
                                                  sequence_start=False,
                                                  sequence_end=True):
                pass
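
Taken together, the reworked engine can be driven directly from asyncio. Below is a minimal usage sketch modelled on how `gradio/app.py` calls it in this commit; the workspace path and the prompt are placeholders.

```python
# Minimal usage sketch based on how gradio/app.py drives AsyncEngine in this
# commit; './workspace' and the prompt are placeholders.
import asyncio

from lmdeploy.serve.async_engine import AsyncEngine


async def main():
    engine = AsyncEngine(model_path='./workspace', instance_num=4, tp=1)
    # arguments: prompt, session id, then streaming/session flags
    async for out in engine.generate('Hi, how are you?',
                                     0,
                                     stream_response=True,
                                     sequence_start=True,
                                     sequence_end=True):
        print(out.response, end='', flush=True)


asyncio.run(main())
```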
......
# Copyright (c) OpenMMLab. All rights reserved.
import os
import threading
import time
from functools import partial
from typing import Sequence

import fire
import gradio as gr

from lmdeploy.serve.async_engine import AsyncEngine
from lmdeploy.serve.gradio.css import CSS
from lmdeploy.serve.openai.api_client import (get_model_list,
                                              get_streaming_response)
from lmdeploy.serve.turbomind.chatbot import Chatbot

THEME = gr.themes.Soft(
...@@ -19,6 +19,9 @@ THEME = gr.themes.Soft(
    secondary_hue=gr.themes.colors.sky,
    font=[gr.themes.GoogleFont('Inconsolata'), 'Arial', 'sans-serif'])

enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)


def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
                request: gr.Request):
...@@ -141,21 +144,17 @@ def run_server(triton_server_addr: str,
    )


# an IO interface managing variables
class InterFace:
    async_engine: AsyncEngine = None  # for run_local
    restful_api_url: str = None  # for run_restful
def chat_stream_restful(
    instruction: str,
    state_chatbot: Sequence,
    cancel_btn: gr.Button,
    reset_btn: gr.Button,
    request: gr.Request,
):
    """Chat with AI assistant.
...@@ -163,150 +162,343 @@ def chat_stream_local(

    Args:
        instruction (str): user's prompt
        state_chatbot (Sequence): the chatting history
        request (gr.Request): the request from a user
    """
    session_id = threading.current_thread().ident
    if request is not None:
        session_id = int(request.kwargs['client']['host'].replace('.', ''))
    bot_summarized_response = ''
    state_chatbot = state_chatbot + [(instruction, None)]

    yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
           f'{bot_summarized_response}'.strip())

    for response, tokens, finish_reason in get_streaming_response(
            instruction,
            f'{InterFace.restful_api_url}/generate',
            instance_id=session_id,
            request_output_len=512,
            sequence_start=(len(state_chatbot) == 1),
            sequence_end=False):
        if finish_reason == 'length':
            gr.Warning('WARNING: exceed session max length.'
                       ' Please restart the session by reset button.')
        if tokens < 0:
            gr.Warning('WARNING: running on the old session.'
                       ' Please restart the session by reset button.')
        if state_chatbot[-1][-1] is None:
            state_chatbot[-1] = (state_chatbot[-1][0], response)
        else:
            state_chatbot[-1] = (state_chatbot[-1][0],
                                 state_chatbot[-1][1] + response
                                 )  # piece by piece
        yield (state_chatbot, state_chatbot, enable_btn, disable_btn,
               f'{bot_summarized_response}'.strip())

    yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
           f'{bot_summarized_response}'.strip())


def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
                       request: gr.Request):
    """reset the session.

    Args:
        instruction_txtbox (str): user's prompt
        state_chatbot (Sequence): the chatting history
        request (gr.Request): the request from a user
    """
    state_chatbot = []

    session_id = threading.current_thread().ident
    if request is not None:
        session_id = int(request.kwargs['client']['host'].replace('.', ''))
    # end the session
    for response, tokens, finish_reason in get_streaming_response(
            '',
            f'{InterFace.restful_api_url}/generate',
            instance_id=session_id,
            request_output_len=0,
            sequence_start=False,
            sequence_end=True):
        pass

    return (
        state_chatbot,
        state_chatbot,
        gr.Textbox.update(value=''),
    )


def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button,
                        reset_btn: gr.Button, request: gr.Request):
    """stop the session.

    Args:
        instruction_txtbox (str): user's prompt
        state_chatbot (Sequence): the chatting history
        request (gr.Request): the request from a user
    """
    session_id = threading.current_thread().ident
    if request is not None:
        session_id = int(request.kwargs['client']['host'].replace('.', ''))
    # end the session
    for out in get_streaming_response('',
                                      f'{InterFace.restful_api_url}/generate',
                                      instance_id=session_id,
                                      request_output_len=0,
                                      sequence_start=False,
                                      sequence_end=False,
                                      stop=True):
        pass
    time.sleep(0.5)
    messages = []
    for qa in state_chatbot:
        messages.append(dict(role='user', content=qa[0]))
        if qa[1] is not None:
            messages.append(dict(role='assistant', content=qa[1]))
    for out in get_streaming_response(messages,
                                      f'{InterFace.restful_api_url}/generate',
                                      instance_id=session_id,
                                      request_output_len=0,
                                      sequence_start=True,
                                      sequence_end=False):
        pass
    return (state_chatbot, disable_btn, enable_btn)
def run_restful(restful_api_url: str,
                server_name: str = 'localhost',
                server_port: int = 6006,
                batch_size: int = 32):
    """chat with AI assistant through web ui.

    Args:
        restful_api_url (str): restful api url
        server_name (str): the ip address of gradio server
        server_port (int): the port of gradio server
        batch_size (int): batch size for running Turbomind directly
    """
    InterFace.restful_api_url = restful_api_url
    model_names = get_model_list(f'{restful_api_url}/v1/models')
    model_name = ''
    if isinstance(model_names, list) and len(model_names) > 0:
        model_name = model_names[0]
    else:
        raise ValueError(
            'gradio cannot find a suitable model from restful-api')

    with gr.Blocks(css=CSS, theme=THEME) as demo:
        state_chatbot = gr.State([])

        with gr.Column(elem_id='container'):
            gr.Markdown('## LMDeploy Playground')

            chatbot = gr.Chatbot(elem_id='chatbot', label=model_name)
            instruction_txtbox = gr.Textbox(
                placeholder='Please input the instruction',
                label='Instruction')
            with gr.Row():
                cancel_btn = gr.Button(value='Cancel', interactive=False)
                reset_btn = gr.Button(value='Reset')

        send_event = instruction_txtbox.submit(
            chat_stream_restful,
            [instruction_txtbox, state_chatbot, cancel_btn, reset_btn],
            [state_chatbot, chatbot, cancel_btn, reset_btn])
        instruction_txtbox.submit(
            lambda: gr.Textbox.update(value=''),
            [],
            [instruction_txtbox],
        )
        cancel_btn.click(cancel_restful_func,
                         [state_chatbot, cancel_btn, reset_btn],
                         [state_chatbot, cancel_btn, reset_btn],
                         cancels=[send_event])
        reset_btn.click(reset_restful_func,
                        [instruction_txtbox, state_chatbot],
                        [state_chatbot, chatbot, instruction_txtbox],
                        cancels=[send_event])

    print(f'server is gonna mount on: http://{server_name}:{server_port}')
    demo.queue(concurrency_count=batch_size, max_size=100,
               api_open=True).launch(
                   max_threads=10,
                   share=True,
                   server_port=server_port,
                   server_name=server_name,
               )
async def chat_stream_local(
    instruction: str,
    state_chatbot: Sequence,
    cancel_btn: gr.Button,
    reset_btn: gr.Button,
    request: gr.Request,
):
    """Chat with AI assistant.

    Args:
        instruction (str): user's prompt
        state_chatbot (Sequence): the chatting history
        request (gr.Request): the request from a user
    """
    session_id = threading.current_thread().ident
    if request is not None:
        session_id = int(request.kwargs['client']['host'].replace('.', ''))
    bot_summarized_response = ''
    state_chatbot = state_chatbot + [(instruction, None)]

    yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
           f'{bot_summarized_response}'.strip())

    async for outputs in InterFace.async_engine.generate(
            instruction,
            session_id,
            stream_response=True,
            sequence_start=(len(state_chatbot) == 1)):
        response = outputs.response
        if outputs.finish_reason == 'length':
            gr.Warning('WARNING: exceed session max length.'
                       ' Please restart the session by reset button.')
        if outputs.generate_token_len < 0:
            gr.Warning('WARNING: running on the old session.'
                       ' Please restart the session by reset button.')
        if state_chatbot[-1][-1] is None:
            state_chatbot[-1] = (state_chatbot[-1][0], response)
        else:
            state_chatbot[-1] = (state_chatbot[-1][0],
                                 state_chatbot[-1][1] + response
                                 )  # piece by piece
        yield (state_chatbot, state_chatbot, enable_btn, disable_btn,
               f'{bot_summarized_response}'.strip())

    yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
           f'{bot_summarized_response}'.strip())
async def reset_local_func(instruction_txtbox: gr.Textbox,
                           state_chatbot: gr.State, request: gr.Request):
    """reset the session.

    Args:
        instruction_txtbox (str): user's prompt
        state_chatbot (Sequence): the chatting history
        request (gr.Request): the request from a user
    """
    state_chatbot = []

    session_id = threading.current_thread().ident
    if request is not None:
        session_id = int(request.kwargs['client']['host'].replace('.', ''))
    # end the session
    async for out in InterFace.async_engine.generate('',
                                                     session_id,
                                                     request_output_len=1,
                                                     stream_response=True,
                                                     sequence_start=False,
                                                     sequence_end=True):
        pass
    return (
        state_chatbot,
        state_chatbot,
        gr.Textbox.update(value=''),
    )


async def cancel_local_func(state_chatbot: gr.State, cancel_btn: gr.Button,
                            reset_btn: gr.Button, request: gr.Request):
    """stop the session.

    Args:
        instruction_txtbox (str): user's prompt
        state_chatbot (Sequence): the chatting history
        request (gr.Request): the request from a user
    """
    session_id = threading.current_thread().ident
    if request is not None:
        session_id = int(request.kwargs['client']['host'].replace('.', ''))
    # end the session
    async for out in InterFace.async_engine.generate('',
                                                     session_id,
                                                     request_output_len=0,
                                                     stream_response=True,
                                                     sequence_start=False,
                                                     sequence_end=False,
                                                     stop=True):
        pass
    messages = []
    for qa in state_chatbot:
        messages.append(dict(role='user', content=qa[0]))
        if qa[1] is not None:
            messages.append(dict(role='assistant', content=qa[1]))
    async for out in InterFace.async_engine.generate(messages,
                                                     session_id,
                                                     request_output_len=0,
                                                     stream_response=True,
                                                     sequence_start=True,
                                                     sequence_end=False):
        pass
    return (state_chatbot, disable_btn, enable_btn)
def run_local(model_path: str,
              server_name: str = 'localhost',
              server_port: int = 6006,
              batch_size: int = 4,
              tp: int = 1):
    """chat with AI assistant through web ui.

    Args:
        model_path (str): the path of the deployed model
        server_name (str): the ip address of gradio server
        server_port (int): the port of gradio server
        batch_size (int): batch size for running Turbomind directly
        tp (int): tensor parallel for Turbomind
    """
    InterFace.async_engine = AsyncEngine(model_path=model_path,
                                         instance_num=batch_size,
                                         tp=tp)

    with gr.Blocks(css=CSS, theme=THEME) as demo:
        state_chatbot = gr.State([])

        with gr.Column(elem_id='container'):
            gr.Markdown('## LMDeploy Playground')

            chatbot = gr.Chatbot(
                elem_id='chatbot',
                label=InterFace.async_engine.tm_model.model_name)
            instruction_txtbox = gr.Textbox(
                placeholder='Please input the instruction',
                label='Instruction')
            with gr.Row():
                cancel_btn = gr.Button(value='Cancel', interactive=False)
                reset_btn = gr.Button(value='Reset')

        send_event = instruction_txtbox.submit(
            chat_stream_local,
            [instruction_txtbox, state_chatbot, cancel_btn, reset_btn],
            [state_chatbot, chatbot, cancel_btn, reset_btn])
        instruction_txtbox.submit(
            lambda: gr.Textbox.update(value=''),
            [],
            [instruction_txtbox],
        )
        cancel_btn.click(cancel_local_func,
                         [state_chatbot, cancel_btn, reset_btn],
                         [state_chatbot, cancel_btn, reset_btn],
                         cancels=[send_event])
        reset_btn.click(reset_local_func, [instruction_txtbox, state_chatbot],
                        [state_chatbot, chatbot, instruction_txtbox],
                        cancels=[send_event])

    print(f'server is gonna mount on: http://{server_name}:{server_port}')
    demo.queue(concurrency_count=batch_size, max_size=100,
               api_open=True).launch(
                   max_threads=10,
                   share=True,
                   server_port=server_port,
...@@ -316,20 +508,33 @@ def run_local(model_path: str,
def run(model_path_or_server: str,
        server_name: str = 'localhost',
        server_port: int = 6006,
        batch_size: int = 32,
        tp: int = 1,
        restful_api: bool = False):
    """chat with AI assistant through web ui.

    Args:
        model_path_or_server (str): the path of the deployed model or the
            tritonserver URL or restful api URL. The former is for directly
            running service with gradio. The latter is for running with
            tritonserver by default. If the input URL is a restful api URL,
            please enable the `restful_api` flag as well.
        server_name (str): the ip address of gradio server
        server_port (int): the port of gradio server
        batch_size (int): batch size for running Turbomind directly
        tp (int): tensor parallel for Turbomind
        restful_api (bool): a flag for model_path_or_server
    """
    if ':' in model_path_or_server:
        if restful_api:
            run_restful(model_path_or_server, server_name, server_port,
                        batch_size)
        else:
            run_server(model_path_or_server, server_name, server_port)
    else:
        run_local(model_path_or_server, server_name, server_port, batch_size,
                  tp)


if __name__ == '__main__':
......
...@@ -6,6 +6,15 @@ import fire

import requests


def get_model_list(api_url: str):
    response = requests.get(api_url)
    if hasattr(response, 'text'):
        model_list = json.loads(response.text)
        model_list = model_list.pop('data', [])
        return [item['id'] for item in model_list]
    return None


def get_streaming_response(prompt: str,
                           api_url: str,
                           instance_id: int,
...@@ -13,7 +22,8 @@ def get_streaming_response(prompt: str,
                           stream: bool = True,
                           sequence_start: bool = True,
                           sequence_end: bool = True,
                           ignore_eos: bool = False,
                           stop: bool = False) -> Iterable[List[str]]:
    headers = {'User-Agent': 'Test Client'}
    pload = {
        'prompt': prompt,
...@@ -22,7 +32,8 @@ def get_streaming_response(prompt: str,
        'request_output_len': request_output_len,
        'sequence_start': sequence_start,
        'sequence_end': sequence_end,
        'ignore_eos': ignore_eos,
        'stop': stop
    }
    response = requests.post(api_url,
                             headers=headers,
...@@ -33,9 +44,9 @@ def get_streaming_response(prompt: str,
                                     delimiter=b'\0'):
        if chunk:
            data = json.loads(chunk.decode('utf-8'))
            output = data.pop('text', '')
            tokens = data.pop('tokens', 0)
            finish_reason = data.pop('finish_reason', None)
            yield output, tokens, finish_reason
...@@ -46,7 +57,7 @@ def input_prompt():
    return '\n'.join(iter(input, sentinel))


def main(restful_api_url: str, session_id: int = 0):
    nth_round = 1
    while True:
        prompt = input_prompt()
...@@ -55,7 +66,7 @@ def main(server_name: str, server_port: int, session_id: int = 0):
        else:
            for output, tokens, finish_reason in get_streaming_response(
                    prompt,
                    f'{restful_api_url}/generate',
                    instance_id=session_id,
                    request_output_len=512,
                    sequence_start=(nth_round == 1),
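
The new `stop` flag is what the Gradio cancel button relies on. A hedged sketch of cancelling an in-flight generation for a session, modelled on `cancel_restful_func` in `lmdeploy/serve/gradio/app.py`, looks like this (the URL and instance id are placeholders):

```python
# Hedged sketch: ask the server to stop generating for a given session,
# mirroring cancel_restful_func in gradio/app.py; URL and id are placeholders.
from lmdeploy.serve.openai.api_client import get_streaming_response

for _ in get_streaming_response('',
                                'http://localhost:23333/generate',
                                instance_id=0,
                                request_output_len=0,
                                sequence_start=False,
                                sequence_end=False,
                                stop=True):
    pass
```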
......
...@@ -10,7 +10,7 @@ import uvicorn

from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse

from lmdeploy.serve.async_engine import AsyncEngine
from lmdeploy.serve.openai.protocol import (  # noqa: E501
    ChatCompletionRequest, ChatCompletionResponse,
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
...@@ -27,7 +27,7 @@ class VariableInterface:
    request_hosts = []


app = FastAPI(docs_url='/')


def get_model_list():
...@@ -253,11 +253,12 @@ async def generate(request: GenerateRequest, raw_request: Request = None):

    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - instance_id: determine which instance will be called. If not specified
        with a value other than -1, using host ip directly.
    - sequence_start (bool): indicator for starting a sequence.
    - sequence_end (bool): indicator for ending a sequence
    - stream: whether to stream the results or not.
    - stop: whether to stop the session response or not.
    - request_output_len (int): output token nums
    - step (int): the offset of the k/v cache
    - top_p (float): If set to float < 1, only the smallest set of most
...@@ -283,6 +284,7 @@ async def generate(request: GenerateRequest, raw_request: Request = None):
        request_output_len=request.request_output_len,
        top_p=request.top_p,
        top_k=request.top_k,
        stop=request.stop,
        temperature=request.temperature,
        repetition_penalty=request.repetition_penalty,
        ignore_eos=request.ignore_eos)
......
...@@ -189,11 +189,12 @@ class EmbeddingsResponse(BaseModel):


class GenerateRequest(BaseModel):
    """Generate request."""
    prompt: Union[str, List[Dict[str, str]]]
    instance_id: int = -1
    sequence_start: bool = True
    sequence_end: bool = False
    stream: bool = False
    stop: bool = False
    request_output_len: int = 512
    top_p: float = 0.8
    top_k: int = 40
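
With the widened `prompt` type, `/generate` now accepts either a plain string or an OpenAI-style message list. Two illustrative payloads (values are placeholders) could look like this:

```python
# Illustrative payloads for GenerateRequest; values are placeholders.
plain_request = {
    'prompt': 'Hi, how are you?',
    'instance_id': 0,
    'sequence_start': True,
    'sequence_end': True,
}

messages_request = {
    'prompt': [{'role': 'user', 'content': 'Hi, how are you?'},
               {'role': 'assistant', 'content': 'Fine, thanks. And you?'},
               {'role': 'user', 'content': 'Great. Tell me a joke.'}],
    'instance_id': 0,
    'stop': False,  # set True to cancel the ongoing generation for this session
    'request_output_len': 512,
}
```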
......