Unverified Commit 4279d8ca authored by AllentDan, committed by GitHub

Enable the Gradio server to call inference services through the RESTful API (#287)



* app use async engine

* add stop logic

* app update cancel

* app support restful-api

* update doc and use the right model name

* set doc url root

* add comments

* add an example

* renew_session

* update readme.md

* resolve comments

* Update restful_api.md

* Update restful_api.md

* Update restful_api.md

---------
Co-authored-by: tpoisonooo <khj.application@aliyun.com>
parent 81f29837
......@@ -133,6 +133,32 @@ python3 -m lmdeploy.serve.gradio.app ./workspace
![](https://github.com/InternLM/lmdeploy/assets/67539920/08d1e6f2-3767-44d5-8654-c85767cec2ab)
#### Serving with RESTful API
Launch the inference server with:
```shell
python3 -m lmdeploy.serve.openai.api_server ./workspace server_ip server_port --instance_num 32 --tp 1
```
Then, you can communicate with it from the command line:
```shell
# restful_api_url is the URL printed by api_server.py, e.g. http://localhost:23333
python -m lmdeploy.serve.openai.api_client restful_api_url
```
or through the web UI:
```shell
# restful_api_url is the URL printed by api_server.py, e.g. http://localhost:23333
# server_ip and server_port here are for the gradio UI
# example: python -m lmdeploy.serve.gradio.app http://localhost:23333 localhost 6006 --restful_api True
python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True
```
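Because the server exposes OpenAI-style endpoints, you can also call it directly from Python. The snippet below is only a minimal sketch: it assumes the server above is reachable at `http://localhost:23333`, that the deployed model is named `internlm-chat-7b` (query `/v1/models` for the actual name), and that the response follows the OpenAI chat-completion schema.
```python
# Minimal sketch (not part of lmdeploy): call the OpenAI-style chat endpoint directly.
# Assumes the api_server is reachable at http://localhost:23333 and serves
# a model named "internlm-chat-7b"; check /v1/models for the real model name.
import requests

api_base = 'http://localhost:23333'
resp = requests.post(
    f'{api_base}/v1/chat/completions',
    json={
        'model': 'internlm-chat-7b',
        'messages': [{'role': 'user', 'content': 'Hi, how are you?'}],
    })
resp.raise_for_status()
# Response is assumed to follow the OpenAI chat-completion schema.
print(resp.json()['choices'][0]['message']['content'])
```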
Refer to [restful_api.md](docs/en/restful_api.md) for more details.
#### Serving with Triton Inference Server
Launch the inference server with:
......
......@@ -133,6 +133,32 @@ python3 -m lmdeploy.serve.gradio.app ./workspace
![](https://github.com/InternLM/lmdeploy/assets/67539920/08d1e6f2-3767-44d5-8654-c85767cec2ab)
#### Serving with RESTful API
Launch the inference server with the following command:
```shell
python3 -m lmdeploy.serve.openai.api_server ./workspace server_ip server_port --instance_num 32 --tp 1
```
You can chat with the inference service from the command line:
```shell
# restful_api_url is the URL printed by api_server.py, e.g. http://localhost:23333
python -m lmdeploy.serve.openai.api_client restful_api_url
```
You can also chat through the web UI:
```shell
# restful_api_url is the URL printed by api_server.py, e.g. http://localhost:23333
# server_ip and server_port here are for the gradio UI
# example: python -m lmdeploy.serve.gradio.app http://localhost:23333 localhost 6006 --restful_api True
python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True
```
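To verify that the RESTful server is reachable before launching the Gradio UI, you can list the served models from Python. A minimal sketch, assuming the server runs at `http://localhost:23333`:
```python
# Minimal sketch: list the models served by the RESTful API.
# Assumes the api_server is reachable at http://localhost:23333.
import requests

data = requests.get('http://localhost:23333/v1/models').json()
# The response carries a "data" list of {"id": ...} entries, the same
# structure parsed by lmdeploy.serve.openai.api_client.get_model_list.
print([item['id'] for item in data.get('data', [])])
```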
Refer to [restful_api.md](docs/zh_cn/restful_api.md) for more details.
#### Serving with a container
Launch the inference server with the following command:
......
......@@ -3,10 +3,10 @@
### Launch Service
```shell
python3 -m lmdeploy.serve.openai.api_server ./workspace server_name server_port --instance_num 32 --tp 1
python3 -m lmdeploy.serve.openai.api_server ./workspace 0.0.0.0 server_port --instance_num 32 --tp 1
```
Then, the user can open the swagger UI: http://{server_name}:{server_port}/docs for the detailed api usage.
Then, the user can open the Swagger UI at `http://{server_ip}:{server_port}` for detailed API usage.
We provide four RESTful APIs in total. Three of them follow the OpenAI format. However, we recommend users try
our own `generate` API, which provides more arguments for users to modify and delivers comparatively better performance.
......@@ -50,16 +50,29 @@ def get_streaming_response(prompt: str,
for output, tokens in get_streaming_response(
"Hi, how are you?", "http://{server_name}:{server_port}/generate", 0,
"Hi, how are you?", "http://{server_ip}:{server_port}/generate", 0,
512):
print(output, end='')
```
### Golang/Rust
### Java/Golang/Rust
Golang can also build a http request to use the service. You may refer
to [the blog](https://pkg.go.dev/net/http) for details to build own client.
Besides, Rust supports building a client in [many ways](https://blog.logrocket.com/best-rust-http-client/).
You may use [openapi-generator-cli](https://github.com/OpenAPITools/openapi-generator-cli) to convert `http://{server_ip}:{server_port}/openapi.json` into a Java/Rust/Golang client.
Here is an example:
```shell
$ docker run -it --rm -v ${PWD}:/local openapitools/openapi-generator-cli generate -i /local/openapi.json -g rust -o /local/rust
$ ls rust/*
rust/Cargo.toml rust/git_push.sh rust/README.md
rust/docs:
ChatCompletionRequest.md EmbeddingsRequest.md HttpValidationError.md LocationInner.md Prompt.md
DefaultApi.md GenerateRequest.md Input.md Messages.md ValidationError.md
rust/src:
apis lib.rs models
```
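Note that the docker command above reads `/local/openapi.json` from the mounted working directory, so download the schema from the running server first. A minimal sketch, assuming the server address:
```python
# Minimal sketch: save the server's OpenAPI schema into the current directory
# so the openapi-generator-cli container can read it as /local/openapi.json.
# Assumes the api_server is reachable at http://localhost:23333.
import requests

resp = requests.get('http://localhost:23333/openapi.json')
resp.raise_for_status()
with open('openapi.json', 'wb') as f:
    f.write(resp.content)
```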
### cURL
......@@ -68,13 +81,13 @@ cURL is a tool for observing the output of the api.
List Models:
```bash
curl http://{server_name}:{server_port}/v1/models
curl http://{server_ip}:{server_port}/v1/models
```
Generate:
```bash
curl http://{server_name}:{server_port}/generate \
curl http://{server_ip}:{server_port}/generate \
-H "Content-Type: application/json" \
-d '{
"model": "internlm-chat-7b",
......@@ -87,7 +100,7 @@ curl http://{server_name}:{server_port}/generate \
Chat Completions:
```bash
curl http://{server_name}:{server_port}/v1/chat/completions \
curl http://{server_ip}:{server_port}/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "internlm-chat-7b",
......@@ -98,7 +111,7 @@ curl http://{server_name}:{server_port}/v1/chat/completions \
Embeddings:
```bash
curl http://{server_name}:{server_port}/v1/embeddings \
curl http://{server_ip}:{server_port}/v1/embeddings \
-H "Content-Type: application/json" \
-d '{
"model": "internlm-chat-7b",
......@@ -106,6 +119,26 @@ curl http://{server_name}:{server_port}/v1/embeddings \
}'
```
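The same embeddings call can be made from Python. This is only a sketch: the field names assume the request follows the OpenAI embeddings schema (`model` plus `input`), and the server address is an example.
```python
# Minimal sketch of the embeddings endpoint; the "input" field name is assumed
# to follow the OpenAI embeddings schema. Assumes http://localhost:23333.
import requests

resp = requests.post(
    'http://localhost:23333/v1/embeddings',
    json={
        'model': 'internlm-chat-7b',
        'input': 'Hello world!',
    })
resp.raise_for_status()
print(resp.json())
```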
### CLI client
There is a client script for the RESTful API server.
```shell
# restful_api_url is the URL printed by api_server.py, e.g. http://localhost:23333
python -m lmdeploy.serve.openai.api_client restful_api_url
```
### webui
You can also test the RESTful API through the web UI.
```shell
# restful_api_url is the URL printed by api_server.py, e.g. http://localhost:23333
# server_ip and server_port here are for the gradio UI
# example: python -m lmdeploy.serve.gradio.app http://localhost:23333 localhost 6006 --restful_api True
python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True
```
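Under the hood, the web UI's Cancel button stops an in-flight generation by sending a `/generate` request with `stop=True` for the same `instance_id` (and then replays the chat history to rebuild the session). A client can do the same; a minimal sketch, assuming the example server address and instance id 0:
```python
# Minimal sketch: stop an in-flight generation on instance/session id 0.
# Assumes the api_server is reachable at http://localhost:23333.
from lmdeploy.serve.openai.api_client import get_streaming_response

for _ in get_streaming_response(
        '',
        'http://localhost:23333/generate',
        instance_id=0,
        request_output_len=0,
        sequence_start=False,
        sequence_end=False,
        stop=True):
    pass
```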
### FAQ
1. When the user gets `"finish_reason":"length"`, it means the session is too long to be continued.
......
......@@ -5,10 +5,10 @@
Run the script:
```shell
python3 -m lmdeploy.serve.openai.api_server ./workspace server_name server_port --instance_num 32 --tp 1
python3 -m lmdeploy.serve.openai.api_server ./workspace 0.0.0.0 server_port --instance_num 32 --tp 1
```
Then the user can open the Swagger UI at http://{server_name}:{server_port}/docs to check all the APIs and their usage in detail.
Then the user can open the Swagger UI at `http://{server_ip}:{server_port}` to check all the APIs and their usage in detail.
We provide four RESTful APIs in total, three of which follow the OpenAI format. However, we recommend users try the other API we provide, `generate`,
which delivers better performance and exposes more arguments for users to customize.
......@@ -52,15 +52,29 @@ def get_streaming_response(prompt: str,
for output, tokens in get_streaming_response(
"Hi, how are you?", "http://{server_name}:{server_port}/generate", 0,
"Hi, how are you?", "http://{server_ip}:{server_port}/generate", 0,
512):
print(output, end='')
```
### Golang/Rust
### Java/Golang/Rust
Golang can also issue HTTP requests to use the launched service; users may refer to [this blog](https://pkg.go.dev/net/http) to build their own client.
Rust also offers [many ways](https://blog.logrocket.com/best-rust-http-client/) to build a client for the service.
You may use the code-generation tool [openapi-generator-cli](https://github.com/OpenAPITools/openapi-generator-cli) to convert `http://{server_ip}:{server_port}/openapi.json` into a Java/Rust/Golang client.
Here is a usage example:
```shell
$ docker run -it --rm -v ${PWD}:/local openapitools/openapi-generator-cli generate -i /local/openapi.json -g rust -o /local/rust
$ ls rust/*
rust/Cargo.toml rust/git_push.sh rust/README.md
rust/docs:
ChatCompletionRequest.md EmbeddingsRequest.md HttpValidationError.md LocationInner.md Prompt.md
DefaultApi.md GenerateRequest.md Input.md Messages.md ValidationError.md
rust/src:
apis lib.rs models
```
### cURL
......@@ -69,13 +83,13 @@ cURL can also be used to view the API output
List models:
```bash
curl http://{server_name}:{server_port}/v1/models
curl http://{server_ip}:{server_port}/v1/models
```
Use generate:
```bash
curl http://{server_name}:{server_port}/generate \
curl http://{server_ip}:{server_port}/generate \
-H "Content-Type: application/json" \
-d '{
"model": "internlm-chat-7b",
......@@ -88,7 +102,7 @@ curl http://{server_name}:{server_port}/generate \
Chat Completions:
```bash
curl http://{server_name}:{server_port}/v1/chat/completions \
curl http://{server_ip}:{server_port}/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "internlm-chat-7b",
......@@ -99,7 +113,7 @@ curl http://{server_name}:{server_port}/v1/chat/completions \
Embeddings:
```bash
curl http://{server_name}:{server_port}/v1/embeddings \
curl http://{server_ip}:{server_port}/v1/embeddings \
-H "Content-Type: application/json" \
-d '{
"model": "internlm-chat-7b",
......@@ -107,6 +121,26 @@ curl http://{server_name}:{server_port}/v1/embeddings \
}'
```
### CLI client
The RESTful API service can be tested with the client script, for example:
```shell
# restful_api_url is the URL printed by api_server.py, e.g. http://localhost:23333
python -m lmdeploy.serve.openai.api_client restful_api_url
```
### webui
You can also test the RESTful API directly through the web UI.
```shell
# restful_api_url is the URL printed by api_server.py, e.g. http://localhost:23333
# server_ip and server_port here are for the gradio UI
# example: python -m lmdeploy.serve.gradio.app http://localhost:23333 localhost 6006 --restful_api True
python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True
```
### FAQ
1. When the returned `"finish_reason"` is `"length"`, it means the session is too long to be continued.
......
......@@ -3,11 +3,11 @@ import asyncio
import dataclasses
import os.path as osp
import random
from contextlib import contextmanager
from typing import Literal, Optional
from lmdeploy import turbomind as tm
from lmdeploy.model import MODELS, BaseModel
from lmdeploy.turbomind.tokenizer import Tokenizer
@dataclasses.dataclass
......@@ -30,6 +30,7 @@ class AsyncEngine:
"""
def __init__(self, model_path, instance_num=32, tp=1) -> None:
from lmdeploy.turbomind.tokenizer import Tokenizer
tokenizer_model_path = osp.join(model_path, 'triton_models',
'tokenizer')
tokenizer = Tokenizer(tokenizer_model_path)
......@@ -46,15 +47,22 @@ class AsyncEngine:
self.starts = [None] * instance_num
self.steps = {}
@contextmanager
def safe_run(self, instance_id: int, stop: bool = False):
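"""Mark the instance as busy while the caller runs a request on it."""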
self.available[instance_id] = False
yield
self.available[instance_id] = True
async def get_embeddings(self, prompt):
prompt = self.model.get_prompt(prompt)
input_ids = self.tokenizer.encode(prompt)
return input_ids
async def get_generator(self, instance_id):
async def get_generator(self, instance_id: int, stop: bool = False):
"""Only return the model instance if it is available."""
while self.available[instance_id] is False:
await asyncio.sleep(0.1)
if not stop:
while self.available[instance_id] is False:
await asyncio.sleep(0.1)
return self.generators[instance_id]
async def generate(
......@@ -104,43 +112,43 @@ class AsyncEngine:
prompt = self.model.messages2prompt(messages, sequence_start)
input_ids = self.tokenizer.encode(prompt)
finish_reason = 'stop' if stop else None
if self.steps[str(session_id)] + len(
if not sequence_end and self.steps[str(session_id)] + len(
input_ids) >= self.tm_model.session_len:
finish_reason = 'length'
yield GenOut('', self.steps[str(session_id)], len(input_ids), 0,
finish_reason)
else:
generator = await self.get_generator(instance_id)
self.available[instance_id] = False
response_size = 0
async for outputs in generator.async_stream_infer(
session_id=session_id,
input_ids=[input_ids],
stream_output=stream_response,
request_output_len=request_output_len,
sequence_start=(sequence_start),
sequence_end=sequence_end,
step=self.steps[str(session_id)],
stop=stop,
top_k=top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
ignore_eos=ignore_eos,
random_seed=seed if sequence_start else None):
res, tokens = outputs[0]
# decode res
response = self.tokenizer.decode(res[response_size:])
# response, history token len, input token len, gen token len
yield GenOut(response, self.steps[str(session_id)],
len(input_ids), tokens, finish_reason)
response_size = tokens
generator = await self.get_generator(instance_id, stop)
with self.safe_run(instance_id):
response_size = 0
async for outputs in generator.async_stream_infer(
session_id=session_id,
input_ids=[input_ids],
stream_output=stream_response,
request_output_len=request_output_len,
sequence_start=(sequence_start),
sequence_end=sequence_end,
step=self.steps[str(session_id)],
stop=stop,
top_k=top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
ignore_eos=ignore_eos,
random_seed=seed if sequence_start else None):
res, tokens = outputs[0]
# decode res
response = self.tokenizer.decode(res)[response_size:]
# response, history token len,
# input token len, gen token len
yield GenOut(response, self.steps[str(session_id)],
len(input_ids), tokens, finish_reason)
response_size += len(response)
# update step
self.steps[str(session_id)] += len(input_ids) + tokens
if sequence_end:
self.steps[str(session_id)] = 0
self.available[instance_id] = True
# update step
self.steps[str(session_id)] += len(input_ids) + tokens
if sequence_end or stop:
self.steps[str(session_id)] = 0
async def generate_openai(
self,
......@@ -180,13 +188,11 @@ class AsyncEngine:
sequence_start = False
generator = await self.get_generator(instance_id)
self.available[instance_id] = False
if renew_session and str(session_id) in self.steps and self.steps[str(
session_id)] > 0: # renew a session
empty_prompt = self.model.messages2prompt('', False)
empty_input_ids = self.tokenizer.encode(empty_prompt)
if renew_session: # renew a session
empty_input_ids = self.tokenizer.encode('')
for outputs in generator.stream_infer(session_id=session_id,
input_ids=[empty_input_ids],
request_output_len=1,
request_output_len=0,
sequence_start=False,
sequence_end=True):
pass
......
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import random
import threading
import time
from functools import partial
from typing import Sequence
import fire
import gradio as gr
from lmdeploy import turbomind as tm
from lmdeploy.model import MODELS
from lmdeploy.serve.async_engine import AsyncEngine
from lmdeploy.serve.gradio.css import CSS
from lmdeploy.serve.openai.api_client import (get_model_list,
get_streaming_response)
from lmdeploy.serve.turbomind.chatbot import Chatbot
THEME = gr.themes.Soft(
......@@ -19,6 +19,9 @@ THEME = gr.themes.Soft(
secondary_hue=gr.themes.colors.sky,
font=[gr.themes.GoogleFont('Inconsolata'), 'Arial', 'sans-serif'])
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)
def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
request: gr.Request):
......@@ -141,21 +144,17 @@ def run_server(triton_server_addr: str,
)
# an IO interface managing global variables
# an IO interface managing variables
class InterFace:
tokenizer_model_path = None
tokenizer = None
tm_model = None
request2instance = None
model_name = None
model = None
async_engine: AsyncEngine = None # for run_local
restful_api_url: str = None # for run_restful
def chat_stream_local(
def chat_stream_restful(
instruction: str,
state_chatbot: Sequence,
step: gr.State,
nth_round: gr.State,
cancel_btn: gr.Button,
reset_btn: gr.Button,
request: gr.Request,
):
"""Chat with AI assistant.
......@@ -163,173 +162,379 @@ def chat_stream_local(
Args:
instruction (str): user's prompt
state_chatbot (Sequence): the chatting history
step (gr.State): chat history length
nth_round (gr.State): round num
request (gr.Request): the request from a user
"""
from lmdeploy.turbomind.chat import valid_str
session_id = threading.current_thread().ident
if request is not None:
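# derive a per-user session id from the client's IP address (dots stripped)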
session_id = int(request.kwargs['client']['host'].replace('.', ''))
if str(session_id) not in InterFace.request2instance:
InterFace.request2instance[str(
session_id)] = InterFace.tm_model.create_instance()
llama_chatbot = InterFace.request2instance[str(session_id)]
seed = random.getrandbits(64)
bot_summarized_response = ''
state_chatbot = state_chatbot + [(instruction, None)]
instruction = InterFace.model.get_prompt(instruction, nth_round == 1)
if step >= InterFace.tm_model.session_len:
raise gr.Error('WARNING: exceed session max length.'
' Please end the session.')
input_ids = InterFace.tokenizer.encode(instruction)
bot_response = llama_chatbot.stream_infer(
session_id, [input_ids],
stream_output=True,
request_output_len=512,
sequence_start=(nth_round == 1),
sequence_end=False,
step=step,
stop=False,
top_k=40,
top_p=0.8,
temperature=0.8,
repetition_penalty=1.0,
ignore_eos=False,
random_seed=seed if nth_round == 1 else None)
yield (state_chatbot, state_chatbot, step, nth_round,
yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
f'{bot_summarized_response}'.strip())
for response, tokens, finish_reason in get_streaming_response(
instruction,
f'{InterFace.restful_api_url}/generate',
instance_id=session_id,
request_output_len=512,
sequence_start=(len(state_chatbot) == 1),
sequence_end=False):
if finish_reason == 'length':
gr.Warning('WARNING: exceed session max length.'
' Please restart the session by reset button.')
if tokens < 0:
gr.Warning('WARNING: running on the old session.'
' Please restart the session by reset button.')
if state_chatbot[-1][-1] is None:
state_chatbot[-1] = (state_chatbot[-1][0], response)
else:
state_chatbot[-1] = (state_chatbot[-1][0],
state_chatbot[-1][1] + response
) # piece by piece
yield (state_chatbot, state_chatbot, enable_btn, disable_btn,
f'{bot_summarized_response}'.strip())
yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
f'{bot_summarized_response}'.strip())
def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
request: gr.Request):
"""reset the session.
Args:
instruction_txtbox (str): user's prompt
state_chatbot (Sequence): the chatting history
request (gr.Request): the request from a user
"""
state_chatbot = []
session_id = threading.current_thread().ident
if request is not None:
session_id = int(request.kwargs['client']['host'].replace('.', ''))
# end the session
for response, tokens, finish_reason in get_streaming_response(
'',
f'{InterFace.restful_api_url}/generate',
instance_id=session_id,
request_output_len=0,
sequence_start=False,
sequence_end=True):
pass
return (
state_chatbot,
state_chatbot,
gr.Textbox.update(value=''),
)
def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button,
reset_btn: gr.Button, request: gr.Request):
"""stop the session.
Args:
state_chatbot (Sequence): the chatting history
cancel_btn (gr.Button): the cancel button
reset_btn (gr.Button): the reset button
request (gr.Request): the request from a user
"""
session_id = threading.current_thread().ident
if request is not None:
session_id = int(request.kwargs['client']['host'].replace('.', ''))
# end the session
for out in get_streaming_response('',
f'{InterFace.restful_api_url}/generate',
instance_id=session_id,
request_output_len=0,
sequence_start=False,
sequence_end=False,
stop=True):
pass
time.sleep(0.5)
messages = []
for qa in state_chatbot:
messages.append(dict(role='user', content=qa[0]))
if qa[1] is not None:
messages.append(dict(role='assistant', content=qa[1]))
for out in get_streaming_response(messages,
f'{InterFace.restful_api_url}/generate',
instance_id=session_id,
request_output_len=0,
sequence_start=True,
sequence_end=False):
pass
return (state_chatbot, disable_btn, enable_btn)
def run_restful(restful_api_url: str,
server_name: str = 'localhost',
server_port: int = 6006,
batch_size: int = 32):
"""chat with AI assistant through web ui.
Args:
restful_api_url (str): restful api url
server_name (str): the ip address of gradio server
server_port (int): the port of gradio server
batch_size (int): batch size for running Turbomind directly
"""
InterFace.restful_api_url = restful_api_url
model_names = get_model_list(f'{restful_api_url}/v1/models')
model_name = ''
if isinstance(model_names, list) and len(model_names) > 0:
model_name = model_names[0]
else:
raise ValueError('gradio cannot find a suitable model from restful-api')
with gr.Blocks(css=CSS, theme=THEME) as demo:
state_chatbot = gr.State([])
with gr.Column(elem_id='container'):
gr.Markdown('## LMDeploy Playground')
chatbot = gr.Chatbot(elem_id='chatbot', label=model_name)
instruction_txtbox = gr.Textbox(
placeholder='Please input the instruction',
label='Instruction')
with gr.Row():
cancel_btn = gr.Button(value='Cancel', interactive=False)
reset_btn = gr.Button(value='Reset')
send_event = instruction_txtbox.submit(
chat_stream_restful,
[instruction_txtbox, state_chatbot, cancel_btn, reset_btn],
[state_chatbot, chatbot, cancel_btn, reset_btn])
instruction_txtbox.submit(
lambda: gr.Textbox.update(value=''),
[],
[instruction_txtbox],
)
cancel_btn.click(cancel_restful_func,
[state_chatbot, cancel_btn, reset_btn],
[state_chatbot, cancel_btn, reset_btn],
cancels=[send_event])
reset_btn.click(reset_restful_func,
[instruction_txtbox, state_chatbot],
[state_chatbot, chatbot, instruction_txtbox],
cancels=[send_event])
print(f'server is gonna mount on: http://{server_name}:{server_port}')
demo.queue(concurrency_count=batch_size, max_size=100,
api_open=True).launch(
max_threads=10,
share=True,
server_port=server_port,
server_name=server_name,
)
async def chat_stream_local(
instruction: str,
state_chatbot: Sequence,
cancel_btn: gr.Button,
reset_btn: gr.Button,
request: gr.Request,
):
"""Chat with AI assistant.
Args:
instruction (str): user's prompt
state_chatbot (Sequence): the chatting history
request (gr.Request): the request from a user
"""
session_id = threading.current_thread().ident
if request is not None:
session_id = int(request.kwargs['client']['host'].replace('.', ''))
bot_summarized_response = ''
state_chatbot = state_chatbot + [(instruction, None)]
yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
f'{bot_summarized_response}'.strip())
response_size = 0
for outputs in bot_response:
res, tokens = outputs[0]
# decode res
response = InterFace.tokenizer.decode(res)[response_size:]
response = valid_str(response)
response_size += len(response)
async for outputs in InterFace.async_engine.generate(
instruction,
session_id,
stream_response=True,
sequence_start=(len(state_chatbot) == 1)):
response = outputs.response
if outputs.finish_reason == 'length':
gr.Warning('WARNING: exceed session max length.'
' Please restart the session by reset button.')
if outputs.generate_token_len < 0:
gr.Warning('WARNING: running on the old session.'
' Please restart the session by reset button.')
if state_chatbot[-1][-1] is None:
state_chatbot[-1] = (state_chatbot[-1][0], response)
else:
state_chatbot[-1] = (state_chatbot[-1][0],
state_chatbot[-1][1] + response
) # piece by piece
yield (state_chatbot, state_chatbot, step, nth_round,
yield (state_chatbot, state_chatbot, enable_btn, disable_btn,
f'{bot_summarized_response}'.strip())
step += len(input_ids) + tokens
nth_round += 1
yield (state_chatbot, state_chatbot, step, nth_round,
yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
f'{bot_summarized_response}'.strip())
def reset_local_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
step: gr.State, nth_round: gr.State, request: gr.Request):
async def reset_local_func(instruction_txtbox: gr.Textbox,
state_chatbot: gr.State, request: gr.Request):
"""reset the session.
Args:
instruction_txtbox (str): user's prompt
state_chatbot (Sequence): the chatting history
step (gr.State): chat history length
nth_round (gr.State): round num
request (gr.Request): the request from a user
"""
state_chatbot = []
step = 0
nth_round = 1
session_id = threading.current_thread().ident
if request is not None:
session_id = int(request.kwargs['client']['host'].replace('.', ''))
InterFace.request2instance[str(
session_id)] = InterFace.tm_model.create_instance()
# end the session
async for out in InterFace.async_engine.generate('',
session_id,
request_output_len=1,
stream_response=True,
sequence_start=False,
sequence_end=True):
pass
return (
state_chatbot,
state_chatbot,
step,
nth_round,
gr.Textbox.update(value=''),
)
async def cancel_local_func(state_chatbot: gr.State, cancel_btn: gr.Button,
reset_btn: gr.Button, request: gr.Request):
"""stop the session.
Args:
state_chatbot (Sequence): the chatting history
cancel_btn (gr.Button): the cancel button
reset_btn (gr.Button): the reset button
request (gr.Request): the request from a user
"""
session_id = threading.current_thread().ident
if request is not None:
session_id = int(request.kwargs['client']['host'].replace('.', ''))
# end the session
async for out in InterFace.async_engine.generate('',
session_id,
request_output_len=0,
stream_response=True,
sequence_start=False,
sequence_end=False,
stop=True):
pass
messages = []
for qa in state_chatbot:
messages.append(dict(role='user', content=qa[0]))
if qa[1] is not None:
messages.append(dict(role='assistant', content=qa[1]))
async for out in InterFace.async_engine.generate(messages,
session_id,
request_output_len=0,
stream_response=True,
sequence_start=True,
sequence_end=False):
pass
return (state_chatbot, disable_btn, enable_btn)
def run_local(model_path: str,
server_name: str = 'localhost',
server_port: int = 6006):
server_port: int = 6006,
batch_size: int = 4,
tp: int = 1):
"""chat with AI assistant through web ui.
Args:
model_path (str): the path of the deployed model
server_name (str): the ip address of gradio server
server_port (int): the port of gradio server
batch_size (int): batch size for running Turbomind directly
tp (int): tensor parallel for Turbomind
"""
from lmdeploy.turbomind.tokenizer import Tokenizer
InterFace.tokenizer_model_path = osp.join(model_path, 'triton_models',
'tokenizer')
InterFace.tokenizer = Tokenizer(InterFace.tokenizer_model_path)
InterFace.tm_model = tm.TurboMind(model_path,
eos_id=InterFace.tokenizer.eos_token_id)
InterFace.request2instance = dict()
InterFace.model_name = InterFace.tm_model.model_name
InterFace.model = MODELS.get(InterFace.model_name)()
InterFace.async_engine = AsyncEngine(model_path=model_path,
instance_num=batch_size,
tp=tp)
with gr.Blocks(css=CSS, theme=THEME) as demo:
state_chatbot = gr.State([])
nth_round = gr.State(1)
step = gr.State(0)
with gr.Column(elem_id='container'):
gr.Markdown('## LMDeploy Playground')
chatbot = gr.Chatbot(elem_id='chatbot', label=InterFace.model_name)
chatbot = gr.Chatbot(
elem_id='chatbot',
label=InterFace.async_engine.tm_model.model_name)
instruction_txtbox = gr.Textbox(
placeholder='Please input the instruction',
label='Instruction')
with gr.Row():
gr.Button(value='Cancel') # noqa: E501
cancel_btn = gr.Button(value='Cancel', interactive=False)
reset_btn = gr.Button(value='Reset')
send_event = instruction_txtbox.submit(
chat_stream_local,
[instruction_txtbox, state_chatbot, step, nth_round],
[state_chatbot, chatbot, step, nth_round])
[instruction_txtbox, state_chatbot, cancel_btn, reset_btn],
[state_chatbot, chatbot, cancel_btn, reset_btn])
instruction_txtbox.submit(
lambda: gr.Textbox.update(value=''),
[],
[instruction_txtbox],
)
cancel_btn.click(cancel_local_func,
[state_chatbot, cancel_btn, reset_btn],
[state_chatbot, cancel_btn, reset_btn],
cancels=[send_event])
reset_btn.click(
reset_local_func,
[instruction_txtbox, state_chatbot, step, nth_round],
[state_chatbot, chatbot, step, nth_round, instruction_txtbox],
cancels=[send_event])
reset_btn.click(reset_local_func, [instruction_txtbox, state_chatbot],
[state_chatbot, chatbot, instruction_txtbox],
cancels=[send_event])
print(f'server is gonna mount on: http://{server_name}:{server_port}')
demo.queue(concurrency_count=4, max_size=100, api_open=True).launch(
max_threads=10,
share=True,
server_port=server_port,
server_name=server_name,
)
demo.queue(concurrency_count=batch_size, max_size=100,
api_open=True).launch(
max_threads=10,
share=True,
server_port=server_port,
server_name=server_name,
)
def run(model_path_or_server: str,
server_name: str = 'localhost',
server_port: int = 6006):
server_port: int = 6006,
batch_size: int = 32,
tp: int = 1,
restful_api: bool = False):
"""chat with AI assistant through web ui.
Args:
model_path_or_server (str): the path of the deployed model or the
tritonserver URL. The former is for directly running service with
gradio. The latter is for running with tritonserver
tritonserver URL or restful api URL. The former is for directly
running the service with gradio. The latter is for running with
tritonserver by default. If the input URL is a restful api URL,
please enable the `restful_api` flag as well.
server_name (str): the ip address of gradio server
server_port (int): the port of gradio server
batch_size (int): batch size for running Turbomind directly
tp (int): tensor parallel for Turbomind
restful_api (bool): a flag indicating whether `model_path_or_server` is a restful api url
"""
if ':' in model_path_or_server:
run_server(model_path_or_server, server_name, server_port)
if restful_api:
run_restful(model_path_or_server, server_name, server_port,
batch_size)
else:
run_server(model_path_or_server, server_name, server_port)
else:
run_local(model_path_or_server, server_name, server_port)
run_local(model_path_or_server, server_name, server_port, batch_size,
tp)
if __name__ == '__main__':
......
......@@ -6,6 +6,15 @@ import fire
import requests
def get_model_list(api_url: str):
response = requests.get(api_url)
if hasattr(response, 'text'):
model_list = json.loads(response.text)
model_list = model_list.pop('data', [])
return [item['id'] for item in model_list]
return None
def get_streaming_response(prompt: str,
api_url: str,
instance_id: int,
......@@ -13,7 +22,8 @@ def get_streaming_response(prompt: str,
stream: bool = True,
sequence_start: bool = True,
sequence_end: bool = True,
ignore_eos: bool = False) -> Iterable[List[str]]:
ignore_eos: bool = False,
stop: bool = False) -> Iterable[List[str]]:
headers = {'User-Agent': 'Test Client'}
pload = {
'prompt': prompt,
......@@ -22,7 +32,8 @@ def get_streaming_response(prompt: str,
'request_output_len': request_output_len,
'sequence_start': sequence_start,
'sequence_end': sequence_end,
'ignore_eos': ignore_eos
'ignore_eos': ignore_eos,
'stop': stop
}
response = requests.post(api_url,
headers=headers,
......@@ -33,9 +44,9 @@ def get_streaming_response(prompt: str,
delimiter=b'\0'):
if chunk:
data = json.loads(chunk.decode('utf-8'))
output = data['text']
tokens = data['tokens']
finish_reason = data['finish_reason']
output = data.pop('text', '')
tokens = data.pop('tokens', 0)
finish_reason = data.pop('finish_reason', None)
yield output, tokens, finish_reason
......@@ -46,7 +57,7 @@ def input_prompt():
return '\n'.join(iter(input, sentinel))
def main(server_name: str, server_port: int, session_id: int = 0):
def main(restful_api_url: str, session_id: int = 0):
nth_round = 1
while True:
prompt = input_prompt()
......@@ -55,7 +66,7 @@ def main(server_name: str, server_port: int, session_id: int = 0):
else:
for output, tokens, finish_reason in get_streaming_response(
prompt,
f'http://{server_name}:{server_port}/generate',
f'{restful_api_url}/generate',
instance_id=session_id,
request_output_len=512,
sequence_start=(nth_round == 1),
......
......@@ -10,7 +10,7 @@ import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from lmdeploy.serve.openai.async_engine import AsyncEngine
from lmdeploy.serve.async_engine import AsyncEngine
from lmdeploy.serve.openai.protocol import ( # noqa: E501
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
......@@ -27,7 +27,7 @@ class VariableInterface:
request_hosts = []
app = FastAPI()
app = FastAPI(docs_url='/')
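# docs_url='/' serves the Swagger UI at the root path instead of the default '/docs'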
def get_model_list():
......@@ -253,11 +253,12 @@ async def generate(request: GenerateRequest, raw_request: Request = None):
The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- stream: whether to stream the results or not.
- sequence_start (bool): indicator for starting a sequence.
- sequence_end (bool): indicator for ending a sequence
- instance_id: determine which instance will be called. If not specified
with a value other than -1, using host ip directly.
- sequence_start (bool): indicator for starting a sequence.
- sequence_end (bool): indicator for ending a sequence
- stream: whether to stream the results or not.
- stop: whether to stop the session response or not.
- request_output_len (int): output token nums
- step (int): the offset of the k/v cache
- top_p (float): If set to float < 1, only the smallest set of most
......@@ -283,6 +284,7 @@ async def generate(request: GenerateRequest, raw_request: Request = None):
request_output_len=request.request_output_len,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop,
temperature=request.temperature,
repetition_penalty=request.repetition_penalty,
ignore_eos=request.ignore_eos)
......
......@@ -189,11 +189,12 @@ class EmbeddingsResponse(BaseModel):
class GenerateRequest(BaseModel):
"""Generate request."""
prompt: str
prompt: Union[str, List[Dict[str, str]]]
instance_id: int = -1
sequence_start: bool = True
sequence_end: bool = False
stream: bool = False
stop: bool = False
request_output_len: int = 512
top_p: float = 0.8
top_k: int = 40
......