Unverified Commit 18c386d9 authored by AllentDan, committed by GitHub

Support serving with gradio without communicating with TIS (#162)



* use local model for webui

* local model for app.py

* lint

* remove print

* add seed

* comments

* fixed session_id

* support turbomind batch inference

* update app.py

* lint and docstring

* move webui to serve/gradio

* update doc

* update doc

* update docstring and remove conversation print

* log

* Update docs/zh_cn/build.md
Co-authored-by: Chen Xin <xinchen.tju@gmail.com>

* Update docs/en/build.md
Co-authored-by: Chen Xin <xinchen.tju@gmail.com>

* use latest gradio

* fix

* replace partial with InterFace

* use host ip instead of cookie

---------
Co-authored-by: Chen Xin <xinchen.tju@gmail.com>
parent 7a2128be
@@ -50,11 +50,9 @@ And the request throughput of TurboMind is 30% higher than vLLM.
### Installation
Below are quick steps for installation:
Install lmdeploy with pip (python 3.8+) or [from source](./docs/en/build.md)
```shell
conda create -n lmdeploy python=3.10 -y
conda activate lmdeploy
pip install lmdeploy
```
@@ -92,7 +90,15 @@ python -m lmdeploy.turbomind.chat ./workspace
> **Note**<br />
> Tensor parallel is available to perform inference on multiple GPUs. Add `--tp=<num_gpu>` on `chat` to enable runtime TP.
#### Serving
#### Serving with gradio
```shell
python3 -m lmdeploy.serve.gradio.app ./workspace
```
![](https://github.com/InternLM/lmdeploy/assets/67539920/08d1e6f2-3767-44d5-8654-c85767cec2ab)
#### Serving with Triton Inference Server
Launch the inference server by:
@@ -109,11 +115,9 @@ python3 -m lmdeploy.serve.client {server_ip_address}:33337
or through the webui,
```shell
python3 -m lmdeploy.app {server_ip_address}:33337
python3 -m lmdeploy.serve.gradio.app {server_ip_address}:33337
```
![](https://github.com/InternLM/lmdeploy/assets/67539920/08d1e6f2-3767-44d5-8654-c85767cec2ab)
For the deployment of other supported models, such as LLaMA, LLaMA-2, vicuna and so on, you can find the guide [here](docs/en/serving.md)
### Inference with PyTorch
......
@@ -51,9 +51,9 @@ TurboMind's output token throughput exceeds 2000 token/s; overall, compared with DeepSpeed
### Installation
Install LMDeploy with pip (python 3.8+), or [install from source](./docs/zh_cn/build.md)
```shell
conda create -n lmdeploy python=3.10 -y
conda activate lmdeploy
pip install lmdeploy
```
@@ -90,7 +90,15 @@ python3 -m lmdeploy.turbomind.chat ./workspace
> **Note**<br />
> Tensor parallelism can be used to run inference on multiple GPUs. Add `--tp=<num_gpu>` to `chat` to enable runtime TP.
#### Deploying the inference service
#### Launching the gradio server
```shell
python3 -m lmdeploy.serve.gradio.app ./workspace
```
![](https://github.com/InternLM/lmdeploy/assets/67539920/08d1e6f2-3767-44d5-8654-c85767cec2ab)
#### Deploying the inference service with a container
Launch the inference service with the following command:
@@ -107,11 +115,9 @@ python3 -m lmdeploy.serve.client {server_ip_address}:33337
You can also chat through the WebUI:
```shell
python3 -m lmdeploy.app {server_ip_address}:33337
python3 -m lmdeploy.serve.gradio.app {server_ip_address}:33337
```
![](https://github.com/InternLM/lmdeploy/assets/67539920/08d1e6f2-3767-44d5-8654-c85767cec2ab)
For the deployment of other supported models, such as LLaMA, LLaMA-2, vicuna and so on, please refer to [this guide](docs/zh_cn/serving.md)
### Inference with PyTorch
......
## Build from source
- make sure the local gcc version is no less than 9, which can be confirmed by `gcc --version`.
- install packages for compiling and running:
```shell
pip install -r requirements.txt
```
- install [nccl](https://docs.nvidia.com/deeplearning/nccl/install-guide/index.html) and set the environment variables:
```shell
export NCCL_ROOT_DIR=/path/to/nccl/build
export NCCL_LIBRARIES=/path/to/nccl/build/lib
```
- install rapidjson
- install openmpi; installing from source is recommended:
```shell
wget https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-4.1.5.tar.gz
tar -xzf openmpi-*.tar.gz && cd openmpi-*
./configure --with-cuda
make -j$(nproc)
make install
```
- build and install lmdeploy:
```shell
mkdir build && cd build
sh ../generate.sh
```
### Install from source
- make sure the gcc version on the host machine is no less than 9, which can be confirmed by `gcc --version`.
- install packages for compiling and running:
```shell
pip install -r requirements.txt
```
- install [nccl](https://docs.nvidia.com/deeplearning/nccl/install-guide/index.html) and set the environment variables:
```shell
export NCCL_ROOT_DIR=/path/to/nccl/build
export NCCL_LIBRARIES=/path/to/nccl/build/lib
```
- install rapidjson
- install openmpi; installing from source is recommended:
```shell
wget https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-4.1.5.tar.gz
tar -xzf openmpi-*.tar.gz && cd openmpi-*
./configure --with-cuda
make -j$(nproc)
make install
```
- build and install lmdeploy:
```shell
mkdir build && cd build
sh ../generate.sh
```
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import random
import threading
from functools import partial
from typing import Sequence
......@@ -7,24 +9,12 @@ from typing import Sequence
import fire
import gradio as gr
from lmdeploy import turbomind as tm
from lmdeploy.model import MODELS
from lmdeploy.serve.gradio.css import CSS
from lmdeploy.serve.turbomind.chatbot import Chatbot
CSS = """
#container {
width: 95%;
margin-left: auto;
margin-right: auto;
}
#chatbot {
height: 500px;
overflow: auto;
}
.chat_wrap_space {
margin-left: 0.5em
}
"""
from lmdeploy.turbomind.chat import valid_str
from lmdeploy.turbomind.tokenizer import Tokenizer
THEME = gr.themes.Soft(
primary_hue=gr.themes.colors.blue,
......@@ -32,38 +22,30 @@ THEME = gr.themes.Soft(
font=[gr.themes.GoogleFont('Inconsolata'), 'Arial', 'sans-serif'])
def chat_stream(instruction: str,
state_chatbot: Sequence,
llama_chatbot: Chatbot,
model_name: str = None):
def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
request: gr.Request):
"""Chat with AI assistant.
Args:
instruction (str): user's prompt
state_chatbot (Sequence): the chatting history
llama_chatbot (Chatbot): the instance of a chatbot
request (gr.Request): the request from a user
model_name (str): the name of deployed model
"""
bot_summarized_response = ''
model_type = 'turbomind'
state_chatbot = state_chatbot + [(instruction, None)]
instruction = state_chatbot[-1][0]
session_id = threading.current_thread().ident
if request is not None:
session_id = int(request.kwargs['client']['host'].replace('.', ''))
bot_response = llama_chatbot.stream_infer(
session_id, instruction, f'{session_id}-{len(state_chatbot)}')
yield (state_chatbot, state_chatbot, f'{bot_summarized_response}'.strip())
for status, tokens, _ in bot_response:
if state_chatbot[-1][-1] is None or model_type != 'fairscale':
state_chatbot[-1] = (state_chatbot[-1][0], tokens)
else:
state_chatbot[-1] = (state_chatbot[-1][0],
state_chatbot[-1][1] + tokens
) # piece by piece
yield (state_chatbot, state_chatbot,
f'{bot_summarized_response}'.strip())
yield (state_chatbot, state_chatbot, '')
yield (state_chatbot, state_chatbot, f'{bot_summarized_response}'.strip())
return (state_chatbot, state_chatbot, '')
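As the commit message notes ("use host ip instead of cookie"), the session id here is derived from the client host IP carried by the Gradio request, with the worker thread id as a fallback when no request is available. A minimal, self-contained sketch of that pattern (the helper name is illustrative and not part of this diff):

```python
import threading

import gradio as gr


def derive_session_id(request: gr.Request = None) -> int:
    """Illustrative helper: map a client to a stable integer session id."""
    if request is not None:
        # same pattern as in the diff: '10.1.52.36' -> 1015236
        host = request.kwargs['client']['host']
        return int(host.replace('.', ''))
    # fallback when no gr.Request is available (e.g. local testing)
    return threading.current_thread().ident
```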
def reset_all_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
......@@ -100,7 +82,12 @@ def cancel_func(
)
def run(triton_server_addr: str,
def add_instruction(instruction, state_chatbot):
state_chatbot = state_chatbot + [(instruction, None)]
return ('', state_chatbot)
def run_server(triton_server_addr: str,
server_name: str = 'localhost',
server_port: int = 6006):
"""chat with AI assistant through web ui.
......@@ -112,16 +99,13 @@ def run(triton_server_addr: str,
"""
with gr.Blocks(css=CSS, theme=THEME) as demo:
log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO')
_chatbot = Chatbot(triton_server_addr,
log_level=log_level,
display=True)
model_name = _chatbot.model_name
chat_interface = partial(chat_stream, model_name=model_name)
llama_chatbot = gr.State(
Chatbot(triton_server_addr, log_level=log_level, display=True))
state_chatbot = gr.State([])
model_name = llama_chatbot.value.model_name
reset_all = partial(reset_all_func,
model_name=model_name,
triton_server_addr=triton_server_addr)
llama_chatbot = gr.State(_chatbot)
state_chatbot = gr.State([])
with gr.Column(elem_id='container'):
gr.Markdown('## LMDeploy Playground')
......@@ -135,17 +119,10 @@ def run(triton_server_addr: str,
reset_btn = gr.Button(value='Reset')
send_event = instruction_txtbox.submit(
chat_interface,
[instruction_txtbox, state_chatbot, llama_chatbot],
[state_chatbot, chatbot],
batch=False,
max_batch_size=1,
)
instruction_txtbox.submit(
lambda: gr.Textbox.update(value=''),
[],
[instruction_txtbox],
)
add_instruction, [instruction_txtbox, state_chatbot],
[instruction_txtbox, state_chatbot]).then(
chat_stream, [state_chatbot, llama_chatbot],
[state_chatbot, chatbot])
cancel_btn.click(cancel_func,
[instruction_txtbox, state_chatbot, llama_chatbot],
......@@ -157,6 +134,178 @@ def run(triton_server_addr: str,
[llama_chatbot, state_chatbot, chatbot, instruction_txtbox],
cancels=[send_event])
print(f'server is gonna mount on: http://{server_name}:{server_port}')
demo.queue(concurrency_count=4, max_size=100, api_open=True).launch(
max_threads=10,
share=True,
server_port=server_port,
server_name=server_name,
)
# an IO interface managing global variables
class InterFace:
tokenizer_model_path = None
tokenizer = None
tm_model = None
request2instance = None
model_name = None
model = None
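`InterFace` is used purely as a namespace for module-level state; in particular, `request2instance` lazily creates one inference instance per session id and reuses it on later requests. A small standalone sketch of that lookup pattern, with a hypothetical `Engine` class standing in for `tm.TurboMind`:

```python
class Engine:
    """Hypothetical stand-in for tm.TurboMind in this sketch."""

    def create_instance(self):
        return object()  # placeholder for a per-session generator


class SharedState:
    """Class used only as a namespace for globals, like InterFace above."""
    engine = Engine()
    request2instance = {}


def get_instance(session_id: int):
    # create the per-session instance on first use, then reuse it
    key = str(session_id)
    if key not in SharedState.request2instance:
        SharedState.request2instance[key] = SharedState.engine.create_instance()
    return SharedState.request2instance[key]
```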
def chat_stream_local(
instruction: str,
state_chatbot: Sequence,
step: gr.State,
nth_round: gr.State,
request: gr.Request,
):
"""Chat with AI assistant.
Args:
instruction (str): user's prompt
state_chatbot (Sequence): the chatting history
step (gr.State): chat history length
nth_round (gr.State): round num
request (gr.Request): the request from a user
"""
session_id = threading.current_thread().ident
if request is not None:
session_id = int(request.kwargs['client']['host'].replace('.', ''))
if str(session_id) not in InterFace.request2instance:
InterFace.request2instance[str(
session_id)] = InterFace.tm_model.create_instance()
llama_chatbot = InterFace.request2instance[str(session_id)]
seed = random.getrandbits(64)
bot_summarized_response = ''
state_chatbot = state_chatbot + [(instruction, None)]
instruction = InterFace.model.get_prompt(instruction, nth_round == 1)
if step >= InterFace.tm_model.session_len:
raise gr.Error('WARNING: exceed session max length.'
' Please end the session.')
input_ids = InterFace.tokenizer.encode(instruction)
bot_response = llama_chatbot.stream_infer(
session_id, [input_ids],
stream_output=True,
request_output_len=512,
sequence_start=(nth_round == 1),
sequence_end=False,
step=step,
stop=False,
top_k=40,
top_p=0.8,
temperature=0.8,
repetition_penalty=1.0,
ignore_eos=False,
random_seed=seed if nth_round == 1 else None)
yield (state_chatbot, state_chatbot, step, nth_round,
f'{bot_summarized_response}'.strip())
response_size = 0
for outputs in bot_response:
res, tokens = outputs[0]
# decode res
response = InterFace.tokenizer.decode(res)[response_size:]
response = valid_str(response)
response_size += len(response)
if state_chatbot[-1][-1] is None:
state_chatbot[-1] = (state_chatbot[-1][0], response)
else:
state_chatbot[-1] = (state_chatbot[-1][0],
state_chatbot[-1][1] + response
) # piece by piece
yield (state_chatbot, state_chatbot, step, nth_round,
f'{bot_summarized_response}'.strip())
step += len(input_ids) + tokens
nth_round += 1
yield (state_chatbot, state_chatbot, step, nth_round,
f'{bot_summarized_response}'.strip())
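The loop above re-decodes the full token sequence on every streaming step and, using `response_size`, emits only the suffix that has not been shown yet. A standalone sketch of that incremental decode pattern, with a toy decoder in place of the real tokenizer:

```python
from typing import Callable, Iterable, List


def stream_decode(decode: Callable[[List[int]], str],
                  token_chunks: Iterable[List[int]]):
    """Yield only the newly generated text for each growing token list."""
    emitted = 0
    for tokens in token_chunks:
        text = decode(tokens)      # decode everything generated so far
        new_text = text[emitted:]  # keep only the not-yet-emitted suffix
        emitted += len(new_text)
        yield new_text


# toy decoder that joins token ids as strings
chunks = [[7], [7, 8], [7, 8, 9]]
print(list(stream_decode(lambda ids: ''.join(map(str, ids)), chunks)))
# -> ['7', '8', '9']
```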
def reset_local_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
step: gr.State, nth_round: gr.State, request: gr.Request):
"""reset the session.
Args:
instruction_txtbox (str): user's prompt
state_chatbot (Sequence): the chatting history
step (gr.State): chat history length
nth_round (gr.State): round num
request (gr.Request): the request from a user
"""
state_chatbot = []
step = 0
nth_round = 1
session_id = threading.current_thread().ident
if request is not None:
session_id = int(request.kwargs['client']['host'].replace('.', ''))
InterFace.request2instance[str(
session_id)] = InterFace.tm_model.create_instance()
return (
state_chatbot,
state_chatbot,
step,
nth_round,
gr.Textbox.update(value=''),
)
def run_local(model_path: str,
server_name: str = 'localhost',
server_port: int = 6006):
"""chat with AI assistant through web ui.
Args:
model_path (str): the path of the deployed model
server_name (str): the ip address of gradio server
server_port (int): the port of gradio server
"""
InterFace.tokenizer_model_path = osp.join(model_path, 'triton_models',
'tokenizer')
InterFace.tokenizer = Tokenizer(InterFace.tokenizer_model_path)
InterFace.tm_model = tm.TurboMind(model_path,
eos_id=InterFace.tokenizer.eos_token_id)
InterFace.request2instance = dict()
InterFace.model_name = InterFace.tm_model.model_name
InterFace.model = MODELS.get(InterFace.model_name)()
with gr.Blocks(css=CSS, theme=THEME) as demo:
state_chatbot = gr.State([])
nth_round = gr.State(1)
step = gr.State(0)
with gr.Column(elem_id='container'):
gr.Markdown('## LMDeploy Playground')
chatbot = gr.Chatbot(elem_id='chatbot', label=InterFace.model_name)
instruction_txtbox = gr.Textbox(
placeholder='Please input the instruction',
label='Instruction')
with gr.Row():
gr.Button(value='Cancel') # noqa: E501
reset_btn = gr.Button(value='Reset')
send_event = instruction_txtbox.submit(
chat_stream_local,
[instruction_txtbox, state_chatbot, step, nth_round],
[state_chatbot, chatbot, step, nth_round])
instruction_txtbox.submit(
lambda: gr.Textbox.update(value=''),
[],
[instruction_txtbox],
)
reset_btn.click(
reset_local_func,
[instruction_txtbox, state_chatbot, step, nth_round],
[state_chatbot, chatbot, step, nth_round, instruction_txtbox],
cancels=[send_event])
print(f'server is gonna mount on: http://{server_name}:{server_port}')
demo.queue(concurrency_count=4, max_size=100, api_open=True).launch(
max_threads=10,
share=True,
......@@ -165,5 +314,23 @@ def run(triton_server_addr: str,
)
def run(model_path_or_server: str,
server_name: str = 'localhost',
server_port: int = 6006):
"""chat with AI assistant through web ui.
Args:
model_path_or_server (str): the path of the deployed model or the
tritonserver URL. The former runs the service directly with gradio,
while the latter serves through tritonserver
server_name (str): the ip address of gradio server
server_port (int): the port of gradio server
"""
if ':' in model_path_or_server:
run_server(model_path_or_server, server_name, server_port)
else:
run_local(model_path_or_server, server_name, server_port)
if __name__ == '__main__':
fire.Fire(run)
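For reference, a usage sketch of the `run` entry point above: a local workspace path serves directly through TurboMind, while an `ip:port` address connects to a Triton Inference Server. The workspace path and address below are placeholders:

```python
from lmdeploy.serve.gradio.app import run

# serve a locally converted model directly with TurboMind
run('./workspace', server_name='0.0.0.0', server_port=6006)

# or point the web UI at a running Triton Inference Server
# run('<server_ip>:33337', server_name='0.0.0.0', server_port=6006)
```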
# Copyright (c) OpenMMLab. All rights reserved.
CSS = """
#container {
width: 95%;
margin-left: auto;
margin-right: auto;
}
#chatbot {
height: 500px;
overflow: auto;
}
.chat_wrap_space {
margin-left: 0.5em
}
"""