"docs/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "987d34b0cf8d6cd8725258332fcfc8c54529b1ab"
Unverified Commit 373bd013 authored by AllentDan's avatar AllentDan Committed by GitHub
Browse files

Improve api_server and webui usage (#544)

* make IPv6 compatible, safe run for coroutine interrupting

* instance_id -> session_id and fix api_client.py

* update doc

* remove useless faq

* safe ip mapping

* update app.py

* WIP completion

* completion

* update doc

* disable interactive mode for /v1/chat/completions

* docstring

* docstring

* refactor gradio

* update gradio

* update

* update doc

* rename

* session_id default -1

* missed two files

* add a APIClient

* add chat func for APIClient

* refine

* add concurrent function

* sequence_start, sequence_end --> interactive_mode

* update doc

* comments

* doc

* better text completion

* remove /v1/embeddings

* comments

* deprecate generate and use /v1/interactive/completions

* /v1/interactive/completion -> /v1/chat/interactive

* embeddings

* rename

* remove wrong arg description

* docstring

* fix

* update cli

* update doc

* strict session_len limit condition

* pass model args to api_server
parent 56942c43
......@@ -157,16 +157,16 @@ Then, you can communicate with it by command line,
```shell
# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333
lmdeploy serve api_client restful_api_url
lmdeploy serve api_client api_server_url
```
or webui,
```shell
# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333
# api_server_url is what printed in api_server.py, e.g. http://localhost:23333
# server_ip and server_port here are for gradio ui
# example: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006 --restful_api True
lmdeploy serve gradio restful_api_url --server_name ${server_ip} --server_port ${server_port} --restful_api True
# example: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006
lmdeploy serve gradio api_server_url --server_name ${gradio_ui_ip} --server_port ${gradio_ui_port}
```
Refer to [restful_api.md](docs/en/restful_api.md) for more details.
......
......@@ -157,16 +157,16 @@ lmdeploy serve api_server ./workspace --server_name 0.0.0.0 --server_port ${serv
```shell
# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333
lmdeploy serve api_client restful_api_url
lmdeploy serve api_client api_server_url
```
也可以通过 WebUI 方式来对话:
```shell
# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333
# api_server_url is what printed in api_server.py, e.g. http://localhost:23333
# server_ip and server_port here are for gradio ui
# example: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006 --restful_api True
lmdeploy serve gradio restful_api_url --server_name ${server_ip} --server_port${server_port} --restful_api True
# example: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006
lmdeploy serve gradio api_server_url --server_name ${gradio_ui_ip} --server_port ${gradio_ui_port}
```
更多详情可以查阅 [restful_api.md](docs/zh_cn/restful_api.md)
......
......@@ -2,48 +2,15 @@ import json
import multiprocessing as mp
import random
import time
from typing import Iterable, List
import fire
import numpy as np
import requests
from lmdeploy.serve.openai.api_client import get_streaming_response
from lmdeploy.tokenizer import Tokenizer
from lmdeploy.utils import get_logger
def get_streaming_response(prompt: str,
api_url: str,
session_id: int,
request_output_len: int,
stream: bool = True,
sequence_start: bool = True,
sequence_end: bool = False,
ignore_eos: bool = False) -> Iterable[List[str]]:
headers = {'User-Agent': 'Test Client'}
pload = {
'prompt': prompt,
'stream': stream,
'session_id': session_id,
'request_output_len': request_output_len,
'sequence_start': sequence_start,
'sequence_end': sequence_end,
'ignore_eos': ignore_eos
}
response = requests.post(api_url,
headers=headers,
json=pload,
stream=stream)
for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b'\n'):
if chunk:
data = json.loads(chunk.decode('utf-8'))
output = data['text']
tokens = data['tokens']
yield output, tokens
def infer(server_addr: str, session_id: int, req_queue: mp.Queue,
res_que: mp.Queue):
stats = []
......@@ -55,13 +22,12 @@ def infer(server_addr: str, session_id: int, req_queue: mp.Queue,
timestamps = []
tokens = []
start = time.perf_counter()
for res, token in get_streaming_response(
for res, token, status in get_streaming_response(
prompt,
server_addr,
session_id,
request_output_len=output_seqlen,
sequence_start=True,
sequence_end=True):
interactive_mode=False):
timestamps.append(time.perf_counter())
tokens.append(token)
......@@ -80,13 +46,11 @@ def warmup(server_addr: str,
def _infer(server_addr, session_id):
for _ in range(warmup_round):
for _, _ in get_streaming_response(
'',
server_addr,
session_id,
request_output_len=output_seqlen,
sequence_start=True,
sequence_end=True):
for _ in get_streaming_response('',
server_addr,
session_id,
request_output_len=output_seqlen,
interactive_mode=False):
continue
_start = time.perf_counter()
......@@ -150,7 +114,7 @@ def main(server_addr: str,
concurrency: int = 1,
session_len: int = 2048,
samples: int = 1000):
api_url = server_addr + '/generate'
api_url = server_addr + '/v1/chat/interactive'
warmup(api_url, concurrency, session_len - 1)
req_queue, n_req = read_dataset(tokenizer_path, dataset_path, samples,
session_len)
......
......@@ -24,7 +24,8 @@ def sample_requests(
dataset = [data for data in dataset if len(data['conversations']) >= 2]
# Only keep the first two turns of each conversation.
dataset = [(data['conversations'][0]['value'],
data['conversations'][1]['value']) for data in dataset]
data['conversations'][1]['value'])
for data in dataset][:num_requests * 2] # speed up encoding
# Tokenize the prompts and completions.
prompts = [prompt for prompt, _ in dataset]
......
......@@ -7,52 +7,57 @@ lmdeploy serve api_server ./workspace --server_name 0.0.0.0 --server_port ${serv
```
Then, the user can open the swagger UI: `http://{server_ip}:{server_port}` for the detailed api usage.
We provide four restful api in total. Three of them are in OpenAI format. However, we recommend users try
our own api which provides more arguments for users to modify. The performance is comparatively better.
We provide four restful apis in total. Three of them are in OpenAI format.
- /v1/chat/completions
- /v1/models
- /v1/completions
However, we recommend users try
our own api `/v1/chat/interactive`, which provides more arguments for users to modify and comparatively better performance.
**Note**: if you want to launch multiple requests, please set a different `session_id` for each request to the
`/v1/chat/completions` and `/v1/chat/interactive` apis. Otherwise, random values will be assigned.
### python
Here is an example for our own api `generate`.
We have integrated the client-side functionalities of these services into the `APIClient` class. Below are some examples demonstrating how to invoke the `api_server` service on the client side.
If you want to use the `/v1/chat/completions` endpoint, you can try the following code:
```python
from lmdeploy.serve.openai.api_client import APIClient
api_client = APIClient('http://{server_ip}:{server_port}')
model_name = api_client.available_models[0]
messages = [{"role": "user", "content": "Say this is a test!"}]
for item in api_client.chat_completions_v1(model=model_name, messages=messages):
print(item)
```
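As the note above suggests, each concurrent request should carry its own `session_id`. Below is a minimal sketch of doing so with a thread pool; the pool size, session ids and prompts are illustrative:
```python
from concurrent.futures import ThreadPoolExecutor

from lmdeploy.serve.openai.api_client import APIClient

api_client = APIClient('http://{server_ip}:{server_port}')
model_name = api_client.available_models[0]


def ask(session_id, question):
    messages = [{"role": "user", "content": question}]
    # chat_completions_v1 is a generator; with stream=False it typically
    # yields a single json object, which we collect into a list
    return list(
        api_client.chat_completions_v1(model=model_name,
                                       messages=messages,
                                       session_id=session_id))


with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [
        pool.submit(ask, sid, q)
        for sid, q in [(1, 'Hi!'), (2, 'Say this is a test!')]
    ]
    for future in futures:
        print(future.result())
```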
If you want to use the `/v1/completions` endpoint, you can try:
```python
from lmdeploy.serve.openai.api_client import APIClient
api_client = APIClient('http://{server_ip}:{server_port}')
model_name = api_client.available_models[0]
for item in api_client.completions_v1(model=model_name, prompt='hi'):
print(item)
```
LMDeploy supports maintaining session histories on the server for the `/v1/chat/interactive` api. The feature is
disabled by default.
- In interactive mode, the chat history is kept on the server. For a multi-round conversation, set
`interactive_mode = True` and use the same `session_id` (it cannot be -1, which is the default value) in every request to `/v1/chat/interactive`.
- In normal mode, no chat history is kept on the server.
The interactive mode is controlled by the `interactive_mode` boolean parameter. The following is an example of normal mode; to experience the interactive mode, simply pass in `interactive_mode=True` (see the sketch after the code block below).
```python
import json
import requests
from typing import Iterable, List
def get_streaming_response(prompt: str,
api_url: str,
session_id: int,
request_output_len: int,
stream: bool = True,
sequence_start: bool = True,
sequence_end: bool = True,
ignore_eos: bool = False) -> Iterable[List[str]]:
headers = {'User-Agent': 'Test Client'}
pload = {
'prompt': prompt,
'stream': stream,
'session_id': session_id,
'request_output_len': request_output_len,
'sequence_start': sequence_start,
'sequence_end': sequence_end,
'ignore_eos': ignore_eos
}
response = requests.post(
api_url, headers=headers, json=pload, stream=stream)
for chunk in response.iter_lines(
chunk_size=8192, decode_unicode=False, delimiter=b'\n'):
if chunk:
data = json.loads(chunk.decode('utf-8'))
output = data['text']
tokens = data['tokens']
yield output, tokens
for output, tokens in get_streaming_response(
"Hi, how are you?", "http://{server_ip}:{server_port}/generate", 0,
512):
print(output, end='')
from lmdeploy.serve.openai.api_client import APIClient
api_client = APIClient('http://{server_ip}:{server_port}')
for item in api_client.chat_interactive_v1(prompt='hi'):
print(item)
```
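To try the interactive mode, here is a minimal sketch that keeps a two-round conversation on the server by reusing one `session_id`; the fixed id and the prompts are arbitrary:
```python
from lmdeploy.serve.openai.api_client import APIClient

api_client = APIClient('http://{server_ip}:{server_port}')
# reuse the same non-default session_id so the server appends each new
# prompt to the stored history
for prompt in ['Hi, my name is Alice.', 'What is my name?']:
    for item in api_client.chat_interactive_v1(prompt=prompt,
                                               session_id=1,
                                               interactive_mode=True):
        print(item)
```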
### Java/Golang/Rust
......@@ -84,16 +89,15 @@ List Models:
curl http://{server_ip}:{server_port}/v1/models
```
Generate:
Interactive Chat:
```bash
curl http://{server_ip}:{server_port}/generate \
curl http://{server_ip}:{server_port}/v1/chat/interactive \
-H "Content-Type: application/json" \
-d '{
"prompt": "Hello! How are you?",
"session_id": 1,
"sequence_start": true,
"sequence_end": true
"interactive_mode": true
}'
```
......@@ -104,19 +108,19 @@ curl http://{server_ip}:{server_port}/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "internlm-chat-7b",
"messages": [{"role": "user", "content": "Hello! Ho are you?"}]
"messages": [{"role": "user", "content": "Hello! How are you?"}]
}'
```
Embeddings:
Text Completions:
```bash
curl http://{server_ip}:{server_port}/v1/embeddings \
-H "Content-Type: application/json" \
```shell
curl http://{server_ip}:{server_port}/v1/completions \
-H 'Content-Type: application/json' \
-d '{
"model": "internlm-chat-7b",
"input": "Hello world!"
}'
"model": "llama",
"prompt": "two steps to build a house:"
}'
```
### CLI client
......@@ -125,7 +129,7 @@ There is a client script for restful api server.
```shell
# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333
lmdeploy serve api_client restful_api_url
lmdeploy serve api_client api_server_url
```
### webui
......@@ -133,10 +137,10 @@ lmdeploy serve api_client restful_api_url
You can also test restful-api through webui.
```shell
# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333
# api_server_url is what printed in api_server.py, e.g. http://localhost:23333
# server_ip and server_port here are for gradio ui
# example: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006 --restful_api True
lmdeploy serve gradio restful_api_url --server_name ${server_ip} --server_port ${server_port} --restful_api True
# example: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006
lmdeploy serve gradio api_server_url --server_name ${gradio_ui_ip} --server_port ${gradio_ui_port}
```
### FAQ
......@@ -146,10 +150,6 @@ lmdeploy serve gradio restful_api_url --server_name ${server_ip} --server_port $
2. When OOM appears at the server side, please reduce the number of `instance_num` when launching the service.
3. When the request with the same `session_id` to `generate` got a empty return value and a negative `tokens`, please consider setting `sequence_start=false` for the second question and the same for the afterwards.
4. Requests were previously being handled sequentially rather than concurrently. To resolve this issue,
- kindly provide unique session_id values when calling the `generate` API or else your requests may be associated with client IP addresses
3. When a request with the same `session_id` to `/v1/chat/interactive` gets an empty return value and a negative `tokens`, please consider setting `interactive_mode=false` to restart the session.
5. Both `generate` api and `v1/chat/completions` upport engaging in multiple rounds of conversation, where input `prompt` or `messages` consists of either single strings or entire chat histories.These inputs are interpreted using multi-turn dialogue modes. However, ff you want to turn the mode of and manage the chat history in clients, please the parameter `sequence_end: true` when utilizing the `generate` function, or specify `renew_session: true` when making use of `v1/chat/completions`
4. The `/v1/chat/interactive` api disables engaging in multiple rounds of conversation by default. The input argument `prompt` consists of either a single string or an entire chat history.
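For illustration, a sketch of passing an entire chat history (rather than a single string) as the `prompt` of `/v1/chat/interactive`; the messages themselves are made up:
```python
from lmdeploy.serve.openai.api_client import APIClient

api_client = APIClient('http://{server_ip}:{server_port}')
# `prompt` may also be an OpenAI-style list of messages
history = [{'role': 'user', 'content': 'Hi, how are you?'},
           {'role': 'assistant', 'content': 'Fine. How can I help you?'},
           {'role': 'user', 'content': 'Tell me a joke.'}]
for item in api_client.chat_interactive_v1(prompt=history):
    print(item)
```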
......@@ -97,16 +97,16 @@ Then, you can communicate with it by command line,
```shell
# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333
lmdeploy serve api_client restful_api_url
lmdeploy serve api_client api_server_url
```
or through webui after launching gradio,
```shell
# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333
# api_server_url is what printed in api_server.py, e.g. http://localhost:23333
# server_ip and server_port here are for gradio ui
# example: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006 --restful_api True
lmdeploy serve gradio restful_api_url --server_name ${server_ip} --server_port ${server_port} --restful_api True
# example: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006
lmdeploy serve gradio api_server_url --server_name ${gradio_ui_ip} --server_port ${gradio_ui_port}
```
Regarding the detailed information of RESTful API, you can refer to [restful_api.md](../restful_api.md).
......@@ -9,52 +9,52 @@ lmdeploy serve api_server ./workspace 0.0.0.0 --server_port ${server_port} --ins
```
然后用户可以打开 swagger UI: `http://{server_ip}:{server_port}` 详细查看所有的 API 及其使用方法。
我们一共提供四个 restful api,其中三个仿照 OpenAI 的形式。不过,我们建议用户用我们提供的另一个 API: `generate`
我们一共提供四个 restful api,其中三个仿照 OpenAI 的形式。
- /v1/chat/completions
- /v1/models
- /v1/completions
不过,我们建议用户用我们提供的另一个 API: `/v1/chat/interactive`
它有更好的性能,提供更多的参数让用户自定义修改。
### python
这是一个 python 示例,展示如何使用 `generate`
我们将这些服务的客户端功能集成在 `APIClient` 类中。下面是一些例子,展示如何在客户端调用 `api_server` 服务。
如果你想用 `/v1/chat/completions` 接口,你可以尝试下面代码:
```python
from lmdeploy.serve.openai.api_client import APIClient
api_client = APIClient('http://{server_ip}:{server_port}')
model_name = api_client.available_models[0]
messages = [{"role": "user", "content": "Say this is a test!"}]
for item in api_client.chat_completions_v1(model=model_name, messages=messages):
print(item)
```
如果你想用 `/v1/completions` 接口,你可以尝试:
```python
import json
import requests
from typing import Iterable, List
def get_streaming_response(prompt: str,
api_url: str,
session_id: int,
request_output_len: int,
stream: bool = True,
sequence_start: bool = True,
sequence_end: bool = True,
ignore_eos: bool = False) -> Iterable[List[str]]:
headers = {'User-Agent': 'Test Client'}
pload = {
'prompt': prompt,
'stream': stream,
'session_id': session_id,
'request_output_len': request_output_len,
'sequence_start': sequence_start,
'sequence_end': sequence_end,
'ignore_eos': ignore_eos
}
response = requests.post(
api_url, headers=headers, json=pload, stream=stream)
for chunk in response.iter_lines(
chunk_size=8192, decode_unicode=False, delimiter=b'\n'):
if chunk:
data = json.loads(chunk.decode('utf-8'))
output = data['text']
tokens = data['tokens']
yield output, tokens
for output, tokens in get_streaming_response(
"Hi, how are you?", "http://{server_ip}:{server_port}/generate", 0,
512):
print(output, end='')
from lmdeploy.serve.openai.api_client import APIClient
api_client = APIClient('http://{server_ip}:{server_port}')
model_name = api_client.available_models[0]
for item in api_client.completions_v1(model=model_name, prompt='hi'):
print(item)
```
LMDeploy 的 `/v1/chat/interactive` api 支持将对话内容管理在服务端,但是我们默认关闭。如果想尝试,请阅读以下介绍:
- 交互模式下,对话历史保存在 server。在一次完整的多轮对话中,所有请求设置`interactive_mode = True`, `session_id`保持相同 (不为 -1,这是缺省值)。
- 非交互模式下,server 不保存历史记录。
交互模式可以通过 `interactive_mode` 布尔量参数控制。下面是一个普通模式的例子,
如果要体验交互模式,将 `interactive_mode=True` 传入即可。
```python
from lmdeploy.serve.openai.api_client import APIClient
api_client = APIClient('http://{server_ip}:{server_port}')
for item in api_client.chat_interactive_v1(prompt='hi'):
print(item)
```
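A small sketch of the interactive mode through the `APIClient.chat` helper added in this PR; the fixed `session_id` and the prompt are illustrative:
```python
from lmdeploy.serve.openai.api_client import APIClient

api_client = APIClient('http://{server_ip}:{server_port}')
# APIClient.chat wraps /v1/chat/interactive with interactive_mode=True,
# so session_id must be a value other than the default -1
for text, tokens, finish_reason in api_client.chat('hi',
                                                   session_id=1,
                                                   request_output_len=512,
                                                   stream=True):
    print(text, end='')
```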
### Java/Golang/Rust
......@@ -86,16 +86,15 @@ cURL 也可以用于查看 API 的输出结果
curl http://{server_ip}:{server_port}/v1/models
```
使用 generate:
Interactive Chat:
```bash
curl http://{server_ip}:{server_port}/generate \
curl http://{server_ip}:{server_port}/v1/chat/interactive \
-H "Content-Type: application/json" \
-d '{
"prompt": "Hello! How are you?",
"session_id": 1,
"sequence_start": true,
"sequence_end": true
"interactive_mode": true
}'
```
......@@ -106,19 +105,19 @@ curl http://{server_ip}:{server_port}/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "internlm-chat-7b",
"messages": [{"role": "user", "content": "Hello! Ho are you?"}]
"messages": [{"role": "user", "content": "Hello! How are you?"}]
}'
```
Embeddings:
Text Completions:
```bash
curl http://{server_ip}:{server_port}/v1/embeddings \
-H "Content-Type: application/json" \
```shell
curl http://{server_ip}:{server_port}/v1/completions \
-H 'Content-Type: application/json' \
-d '{
"model": "internlm-chat-7b",
"input": "Hello world!"
}'
"model": "llama",
"prompt": "two steps to build a house:"
}'
```
### CLI client
......@@ -126,8 +125,8 @@ curl http://{server_ip}:{server_port}/v1/embeddings \
restful api 服务可以通过客户端测试,例如
```shell
# restful_api_url 就是 api_server 产生的,比如 http://localhost:23333
lmdeploy serve api_client restful_api_url
# api_server_url 就是 api_server 产生的,比如 http://localhost:23333
lmdeploy serve api_client api_server_url
```
### webui
......@@ -135,10 +134,10 @@ lmdeploy serve api_client restful_api_url
也可以直接用 webui 测试使用 restful-api。
```shell
# restful_api_url 就是 api_server 产生的,比如 http://localhost:23333
# server_ip 和 server_port 是用来提供 gradio ui 访问服务的
# 例子: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006 --restful_api True
lmdeploy serve gradio restful_api_url --server_name ${server_ip} --server_port ${server_port} --restful_api True
# api_server_url 就是 api_server 产生的,比如 http://localhost:23333
# server_name 和 server_port 是用来提供 gradio ui 访问服务的
# 例子: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006
lmdeploy serve gradio api_server_url --server_name ${gradio_ui_ip} --server_port ${gradio_ui_port}
```
### FAQ
......@@ -148,12 +147,6 @@ lmdeploy serve gradio restful_api_url --server_name ${server_ip} --server_port $
2. 当服务端显存 OOM 时,可以适当减小启动服务时的 `instance_num` 个数
3. 当同一个 `session_id` 的请求给 `generate` 函数后,出现返回空字符串和负值的 `tokens`,应该是第二次问话没有设置 `sequence_start=false`
4. 如果感觉请求不是并发地被处理,而是一个一个地处理,请设置好以下参数:
- 不同的 session_id 传入 `generate` api。否则,我们将自动绑定会话 id 为请求端的 ip 地址编号。
3. 当同一个 `session_id` 的请求给 `/v1/chat/interactive` 函数后,出现返回空字符串和负值的 `tokens`,应该是 `session_id` 混乱了,可以先将交互模式关闭,再重新开启。
5. `generate` api 和 `v1/chat/completions` 均支持多轮对话。`messages` 或者 `prompt` 参数既可以是一个简单字符串表示用户的单词提问,也可以是一段对话历史。
两个 api 都是默认开启多伦对话的,如果你想关闭这个功能,然后在客户端管理会话记录,请设置 `sequence_end: true` 传入 `generate`,或者设置
`renew_session: true` 传入 `v1/chat/completions`
4. `/v1/chat/interactive` api 支持多轮对话,但是默认关闭。`messages` 或者 `prompt` 参数既可以是一个简单字符串表示用户的单次提问,也可以是一段对话历史。
......@@ -98,17 +98,17 @@ lmdeploy serve api_server ./workspace --server_name 0.0.0.0 --server_port ${serv
你可以用命令行,在控制台与 server 通信:
```shell
# restful_api_url 就是 api_server 产生的,比如 http://localhost:23333
lmdeploy serve api_client restful_api_url
# api_server_url 就是 api_server 产生的,比如 http://localhost:23333
lmdeploy serve api_client api_server_url
```
或者,启动 gradio,在 webui 的聊天对话框中,与 codellama 交流:
```shell
# restful_api_url 就是 api_server 产生的,比如 http://localhost:23333
# api_server_url 就是 api_server 产生的,比如 http://localhost:23333
# server_ip 和 server_port 是用来提供 gradio ui 访问服务的
# 例子: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006 --restful_api True
lmdeploy serve gradio restful_api_url --server_name ${server_ip} --server_port ${server_port} --restful_api True
# 例子: lmdeploy serve gradio http://localhost:23333 --server_name localhost --server_port 6006
lmdeploy serve gradio api_server_url --server_name ${gradio_ui_ip} --server_port ${gradio_ui_port}
```
关于 RESTful API的详细介绍,请参考[这份](../restful_api.md)文档。
......@@ -7,7 +7,7 @@ class SubCliServe(object):
def gradio(self,
model_path_or_server: str,
server_name: str = 'localhost',
server_name: str = '0.0.0.0',
server_port: int = 6006,
batch_size: int = 32,
tp: int = 1,
......@@ -18,8 +18,8 @@ class SubCliServe(object):
lmdeploy serve gradio ./workspace
Example 2:
lmdeploy serve gradio http://localhost:23333
--server_name localhost
lmdeploy serve gradio http://0.0.0.0:23333
--server_name 0.0.0.0
--server_port 6006
--restful_api True
......@@ -48,7 +48,7 @@ class SubCliServe(object):
def api_server(self,
model_path: str,
server_name: str = 'localhost',
server_name: str = '0.0.0.0',
server_port: int = 23333,
instance_num: int = 32,
tp: int = 1,
......
......@@ -28,7 +28,7 @@ class AsyncEngine:
tp (int): tensor parallel
"""
def __init__(self, model_path, instance_num=32, tp=1) -> None:
def __init__(self, model_path, instance_num=32, tp=1, **kwargs) -> None:
from lmdeploy import turbomind as tm
from lmdeploy.tokenizer import Tokenizer
tokenizer_model_path = osp.join(model_path, 'triton_models',
......@@ -42,13 +42,14 @@ class AsyncEngine:
self.tm_model.create_instance() for i in range(instance_num)
]
self.instance_num = instance_num
self.model: BaseModel = MODELS.get(self.tm_model.model_name)()
self.model: BaseModel = MODELS.get(self.tm_model.model_name)(**kwargs)
self.available = [True] * instance_num
self.starts = [None] * instance_num
self.steps = {}
self.loop = asyncio.get_event_loop()
def stop_session(self, session_id: int):
"""Stop a session by a session_id."""
instance_id = session_id % self.instance_num
input_ids = self.tokenizer.encode('')
for outputs in self.generators[instance_id].stream_infer(
......@@ -61,8 +62,24 @@ class AsyncEngine:
pass
self.available[instance_id] = True
def end_session(self, session_id: int):
"""Clear a session by a session_id."""
instance_id = session_id % self.instance_num
input_ids = self.tokenizer.encode('')
for outputs in self.generators[instance_id].stream_infer(
session_id,
input_ids,
request_output_len=0,
sequence_start=False,
sequence_end=True,
stop=True):
pass
self.steps[str(session_id)] = 0
self.available[instance_id] = True
@contextmanager
def safe_run(self, instance_id: int, session_id: Optional[int] = None):
"""A context manager to make sure server's safe running."""
self.available[instance_id] = False
try:
yield
......@@ -142,7 +159,7 @@ class AsyncEngine:
session_id,
stream_response=True,
sequence_start=True,
sequence_end=False,
sequence_end=True, # no interactive mode by default
step=0,
request_output_len=512,
stop=False,
......@@ -151,6 +168,7 @@ class AsyncEngine:
temperature=0.8,
repetition_penalty=1.0,
ignore_eos=False,
do_preprocess=True,
):
"""Generate responses.
......@@ -172,6 +190,7 @@ class AsyncEngine:
repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
ignore_eos (bool): indicator for ignoring eos
do_preprocess (bool): whether pre-process the messages.
"""
instance_id = session_id % self.instance_num
if str(session_id) not in self.steps:
......@@ -179,14 +198,18 @@ class AsyncEngine:
if step != 0:
self.steps[str(session_id)] = step
seed = random.getrandbits(64)
prompt = self.model.messages2prompt(messages, sequence_start)
prompt = messages
if do_preprocess:
prompt = self.model.messages2prompt(prompt, sequence_start)
input_ids = self.tokenizer.encode(prompt)
finish_reason = 'stop' if stop else None
if self.steps[str(session_id)] + len(
input_ids) >= self.tm_model.session_len:
input_ids) + request_output_len >= self.tm_model.session_len:
finish_reason = 'length'
yield GenOut('', self.steps[str(session_id)], len(input_ids), 0,
finish_reason)
if sequence_end is True and sequence_start is False:
self.end_session(session_id)
else:
generator = await self.get_generator(instance_id, stop)
with self.safe_run(instance_id, session_id):
......@@ -225,98 +248,3 @@ class AsyncEngine:
self.steps[str(session_id)] += len(input_ids) + tokens
if sequence_end or stop:
self.steps[str(session_id)] = 0
async def generate_openai(
self,
messages,
instance_id,
stream_response=True,
renew_session=False,
request_output_len=512,
stop=False,
top_k=40,
top_p=0.8,
temperature=0.8,
repetition_penalty=1.0,
ignore_eos=False,
):
"""Generate responses.
Args:
messages (str | List): chat history or prompt
instance_id (int): actually request host ip
stream_response (bool): whether return responses streamingly
renew_session (bool): renew the session
request_output_len (int): output token nums
stop (bool): whether stop inference
top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
temperature (float): to modulate the next token probability
repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
ignore_eos (bool): indicator for ignoring eos
"""
session_id = instance_id
instance_id %= self.instance_num
sequence_start = False
generator = await self.get_generator(instance_id)
if renew_session: # renew a session
empty_input_ids = self.tokenizer.encode('')
for outputs in generator.stream_infer(session_id=session_id,
input_ids=[empty_input_ids],
request_output_len=0,
sequence_start=False,
sequence_end=True,
stop=True):
pass
self.steps[str(session_id)] = 0
if str(session_id) not in self.steps:
self.steps[str(session_id)] = 0
if self.steps[str(session_id)] == 0:
sequence_start = True
seed = random.getrandbits(64)
prompt = self.model.messages2prompt(messages, sequence_start)
input_ids = self.tokenizer.encode(prompt)
finish_reason = 'stop' if stop else None
if self.steps[str(session_id)] + len(
input_ids) >= self.tm_model.session_len:
finish_reason = 'length'
yield GenOut('', self.steps[str(session_id)], len(input_ids), 0,
finish_reason)
else:
with self.safe_run(instance_id, session_id):
response_size = 0
async for outputs in generator.async_stream_infer(
session_id=session_id,
input_ids=[input_ids],
stream_output=stream_response,
request_output_len=request_output_len,
sequence_start=(sequence_start),
sequence_end=False,
step=self.steps[str(session_id)],
stop=stop,
top_k=top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
ignore_eos=ignore_eos,
random_seed=seed if sequence_start else None):
res, tokens = outputs[0]
# decode res
response = self.tokenizer.decode(res.tolist(),
offset=response_size)
# utf-8 char at the end means it's a potential unfinished
# byte sequence, continue to concate it with the next
# sequence and decode them together
if response.endswith('�'):
continue
# response, history len, input len, generation len
yield GenOut(response, self.steps[str(session_id)],
len(input_ids), tokens, finish_reason)
response_size = tokens
# update step
self.steps[str(session_id)] += len(input_ids) + tokens
# Copyright (c) OpenMMLab. All rights reserved.
from .api_server_backend import run_api_server
from .triton_server_backend import run_triton_server
from .turbomind_coupled import run_local
__all__ = ['run_api_server', 'run_triton_server', 'run_local']
# Copyright (c) OpenMMLab. All rights reserved.
import threading
import time
from typing import Sequence
import gradio as gr
from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
from lmdeploy.serve.openai.api_client import (get_model_list,
get_streaming_response)
from lmdeploy.serve.openai.api_server import ip2id
class InterFace:
api_server_url: str = None
def chat_stream_restful(
instruction: str,
state_chatbot: Sequence,
cancel_btn: gr.Button,
reset_btn: gr.Button,
request: gr.Request,
):
"""Chat with AI assistant.
Args:
instruction (str): user's prompt
state_chatbot (Sequence): the chatting history
request (gr.Request): the request from a user
"""
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
bot_summarized_response = ''
state_chatbot = state_chatbot + [(instruction, None)]
yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
f'{bot_summarized_response}'.strip())
for response, tokens, finish_reason in get_streaming_response(
instruction,
f'{InterFace.api_server_url}/v1/chat/interactive',
session_id=session_id,
request_output_len=512,
interactive_mode=True):
if finish_reason == 'length':
gr.Warning('WARNING: exceed session max length.'
' Please restart the session by reset button.')
if tokens < 0:
gr.Warning('WARNING: running on the old session.'
' Please restart the session by reset button.')
if state_chatbot[-1][-1] is None:
state_chatbot[-1] = (state_chatbot[-1][0], response)
else:
state_chatbot[-1] = (state_chatbot[-1][0],
state_chatbot[-1][1] + response
) # piece by piece
yield (state_chatbot, state_chatbot, enable_btn, disable_btn,
f'{bot_summarized_response}'.strip())
yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
f'{bot_summarized_response}'.strip())
def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
request: gr.Request):
"""reset the session.
Args:
instruction_txtbox (str): user's prompt
state_chatbot (Sequence): the chatting history
request (gr.Request): the request from a user
"""
state_chatbot = []
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
# end the session
for response, tokens, finish_reason in get_streaming_response(
'',
f'{InterFace.api_server_url}/v1/chat/interactive',
session_id=session_id,
request_output_len=0,
interactive_mode=False):
pass
return (
state_chatbot,
state_chatbot,
gr.Textbox.update(value=''),
)
def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button,
reset_btn: gr.Button, request: gr.Request):
"""stop the session.
Args:
instruction_txtbox (str): user's prompt
state_chatbot (Sequence): the chatting history
request (gr.Request): the request from a user
"""
yield (state_chatbot, disable_btn, disable_btn)
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
# end the session
for out in get_streaming_response(
'',
f'{InterFace.api_server_url}/v1/chat/interactive',
session_id=session_id,
request_output_len=0,
stop=True):
pass
time.sleep(0.5)
messages = []
for qa in state_chatbot:
messages.append(dict(role='user', content=qa[0]))
if qa[1] is not None:
messages.append(dict(role='assistant', content=qa[1]))
for out in get_streaming_response(
messages,
f'{InterFace.api_server_url}/v1/chat/interactive',
session_id=session_id,
request_output_len=0,
interactive_mode=True):
pass
yield (state_chatbot, disable_btn, enable_btn)
def run_api_server(api_server_url: str,
server_name: str = 'localhost',
server_port: int = 6006,
batch_size: int = 32):
"""chat with AI assistant through web ui.
Args:
api_server_url (str): restful api url
server_name (str): the ip address of gradio server
server_port (int): the port of gradio server
batch_size (int): batch size for running Turbomind directly
"""
InterFace.api_server_url = api_server_url
model_names = get_model_list(f'{api_server_url}/v1/models')
model_name = ''
if isinstance(model_names, list) and len(model_names) > 0:
model_name = model_names[0]
else:
raise ValueError('gradio can not find a suitable model from restful-api')
with gr.Blocks(css=CSS, theme=THEME) as demo:
state_chatbot = gr.State([])
with gr.Column(elem_id='container'):
gr.Markdown('## LMDeploy Playground')
chatbot = gr.Chatbot(elem_id='chatbot', label=model_name)
instruction_txtbox = gr.Textbox(
placeholder='Please input the instruction',
label='Instruction')
with gr.Row():
cancel_btn = gr.Button(value='Cancel', interactive=False)
reset_btn = gr.Button(value='Reset')
send_event = instruction_txtbox.submit(
chat_stream_restful,
[instruction_txtbox, state_chatbot, cancel_btn, reset_btn],
[state_chatbot, chatbot, cancel_btn, reset_btn])
instruction_txtbox.submit(
lambda: gr.Textbox.update(value=''),
[],
[instruction_txtbox],
)
cancel_btn.click(cancel_restful_func,
[state_chatbot, cancel_btn, reset_btn],
[state_chatbot, cancel_btn, reset_btn],
cancels=[send_event])
reset_btn.click(reset_restful_func,
[instruction_txtbox, state_chatbot],
[state_chatbot, chatbot, instruction_txtbox],
cancels=[send_event])
print(f'server is gonna mount on: http://{server_name}:{server_port}')
demo.queue(concurrency_count=batch_size, max_size=100,
api_open=True).launch(
max_threads=10,
share=True,
server_port=server_port,
server_name=server_name,
)
# Copyright (c) OpenMMLab. All rights reserved.
import gradio as gr
CSS = """
#container {
width: 95%;
......@@ -16,3 +18,11 @@ CSS = """
margin-left: 0.5em
}
"""
THEME = gr.themes.Soft(
primary_hue=gr.themes.colors.blue,
secondary_hue=gr.themes.colors.sky,
font=[gr.themes.GoogleFont('Inconsolata'), 'Arial', 'sans-serif'])
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)
# Copyright (c) OpenMMLab. All rights reserved.
import os
import threading
from functools import partial
from typing import Sequence
import gradio as gr
from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
from lmdeploy.serve.openai.api_server import ip2id
from lmdeploy.serve.turbomind.chatbot import Chatbot
def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot,
cancel_btn: gr.Button, reset_btn: gr.Button,
request: gr.Request):
"""Chat with AI assistant.
Args:
instruction (str): user's prompt
state_chatbot (Sequence): the chatting history
llama_chatbot (Chatbot): the instance of a chatbot
cancel_btn (bool): enable the cancel button or not
reset_btn (bool): enable the reset button or not
request (gr.Request): the request from a user
"""
instruction = state_chatbot[-1][0]
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
bot_response = llama_chatbot.stream_infer(
session_id, instruction, f'{session_id}-{len(state_chatbot)}')
for status, tokens, _ in bot_response:
state_chatbot[-1] = (state_chatbot[-1][0], tokens)
yield (state_chatbot, state_chatbot, enable_btn, disable_btn)
yield (state_chatbot, state_chatbot, disable_btn, enable_btn)
def reset_all_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State,
llama_chatbot: gr.State, triton_server_addr: str,
model_name: str):
"""reset the session."""
state_chatbot = []
log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO')
llama_chatbot = Chatbot(triton_server_addr,
model_name,
log_level=log_level,
display=True)
return (
llama_chatbot,
state_chatbot,
state_chatbot,
gr.Textbox.update(value=''),
)
def cancel_func(
state_chatbot: gr.State,
llama_chatbot: gr.State,
cancel_btn: gr.Button,
reset_btn: gr.Button,
):
"""cancel the session."""
yield (llama_chatbot, state_chatbot, disable_btn, disable_btn)
session_id = llama_chatbot._session.session_id
llama_chatbot.cancel(session_id)
yield (llama_chatbot, state_chatbot, disable_btn, enable_btn)
def add_instruction(instruction, state_chatbot):
state_chatbot = state_chatbot + [(instruction, None)]
return ('', state_chatbot)
def run_triton_server(triton_server_addr: str,
server_name: str = 'localhost',
server_port: int = 6006):
"""chat with AI assistant through web ui.
Args:
triton_server_addr (str): the communication address of inference server
server_name (str): the ip address of gradio server
server_port (int): the port of gradio server
"""
with gr.Blocks(css=CSS, theme=THEME) as demo:
log_level = os.environ.get('SERVICE_LOG_LEVEL', 'INFO')
llama_chatbot = gr.State(
Chatbot(triton_server_addr, log_level=log_level, display=True))
state_chatbot = gr.State([])
model_name = llama_chatbot.value.model_name
reset_all = partial(reset_all_func,
model_name=model_name,
triton_server_addr=triton_server_addr)
with gr.Column(elem_id='container'):
gr.Markdown('## LMDeploy Playground')
chatbot = gr.Chatbot(elem_id='chatbot', label=model_name)
instruction_txtbox = gr.Textbox(
placeholder='Please input the instruction',
label='Instruction')
with gr.Row():
cancel_btn = gr.Button(value='Cancel', interactive=False)
reset_btn = gr.Button(value='Reset')
send_event = instruction_txtbox.submit(
add_instruction, [instruction_txtbox, state_chatbot],
[instruction_txtbox, state_chatbot]).then(
chat_stream,
[state_chatbot, llama_chatbot, cancel_btn, reset_btn],
[state_chatbot, chatbot, cancel_btn, reset_btn])
cancel_btn.click(cancel_func,
[state_chatbot, llama_chatbot, cancel_btn, reset_btn],
[llama_chatbot, chatbot, cancel_btn, reset_btn],
cancels=[send_event])
reset_btn.click(
reset_all, [instruction_txtbox, state_chatbot, llama_chatbot],
[llama_chatbot, state_chatbot, chatbot, instruction_txtbox],
cancels=[send_event])
print(f'server is gonna mount on: http://{server_name}:{server_port}')
demo.queue(concurrency_count=4, max_size=100, api_open=True).launch(
max_threads=10,
share=True,
server_port=server_port,
server_name=server_name,
)
# Copyright (c) OpenMMLab. All rights reserved.
import threading
from typing import Sequence
import gradio as gr
from lmdeploy.serve.async_engine import AsyncEngine
from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
from lmdeploy.serve.openai.api_server import ip2id
class InterFace:
async_engine: AsyncEngine = None
async def chat_stream_local(
instruction: str,
state_chatbot: Sequence,
cancel_btn: gr.Button,
reset_btn: gr.Button,
request: gr.Request,
):
"""Chat with AI assistant.
Args:
instruction (str): user's prompt
state_chatbot (Sequence): the chatting history
cancel_btn (bool): enable the cancel button or not
reset_btn (bool): enable the reset button or not
request (gr.Request): the request from a user
"""
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
bot_summarized_response = ''
state_chatbot = state_chatbot + [(instruction, None)]
yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
f'{bot_summarized_response}'.strip())
async for outputs in InterFace.async_engine.generate(
instruction,
session_id,
stream_response=True,
sequence_start=(len(state_chatbot) == 1),
sequence_end=False):
response = outputs.response
if outputs.finish_reason == 'length':
gr.Warning('WARNING: exceed session max length.'
' Please restart the session by reset button.')
if outputs.generate_token_len < 0:
gr.Warning('WARNING: running on the old session.'
' Please restart the session by reset button.')
if state_chatbot[-1][-1] is None:
state_chatbot[-1] = (state_chatbot[-1][0], response)
else:
state_chatbot[-1] = (state_chatbot[-1][0],
state_chatbot[-1][1] + response
) # piece by piece
yield (state_chatbot, state_chatbot, enable_btn, disable_btn,
f'{bot_summarized_response}'.strip())
yield (state_chatbot, state_chatbot, disable_btn, enable_btn,
f'{bot_summarized_response}'.strip())
async def reset_local_func(instruction_txtbox: gr.Textbox,
state_chatbot: gr.State, request: gr.Request):
"""reset the session.
Args:
instruction_txtbox (str): user's prompt
state_chatbot (Sequence): the chatting history
request (gr.Request): the request from a user
"""
state_chatbot = []
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
# end the session
async for out in InterFace.async_engine.generate('',
session_id,
request_output_len=1,
stream_response=True,
sequence_start=False,
sequence_end=True):
pass
return (
state_chatbot,
state_chatbot,
gr.Textbox.update(value=''),
)
async def cancel_local_func(state_chatbot: gr.State, cancel_btn: gr.Button,
reset_btn: gr.Button, request: gr.Request):
"""stop the session.
Args:
state_chatbot (Sequence): the chatting history
cancel_btn (bool): enable the cancel button or not
reset_btn (bool): enable the reset button or not
request (gr.Request): the request from a user
"""
yield (state_chatbot, disable_btn, disable_btn)
session_id = threading.current_thread().ident
if request is not None:
session_id = ip2id(request.kwargs['client']['host'])
# end the session
async for out in InterFace.async_engine.generate('',
session_id,
request_output_len=0,
stream_response=True,
sequence_start=False,
sequence_end=False,
stop=True):
pass
messages = []
for qa in state_chatbot:
messages.append(dict(role='user', content=qa[0]))
if qa[1] is not None:
messages.append(dict(role='assistant', content=qa[1]))
async for out in InterFace.async_engine.generate(messages,
session_id,
request_output_len=0,
stream_response=True,
sequence_start=True,
sequence_end=False):
pass
yield (state_chatbot, disable_btn, enable_btn)
def run_local(model_path: str,
server_name: str = 'localhost',
server_port: int = 6006,
batch_size: int = 4,
tp: int = 1):
"""chat with AI assistant through web ui.
Args:
model_path (str): the path of the deployed model
server_name (str): the ip address of gradio server
server_port (int): the port of gradio server
batch_size (int): batch size for running Turbomind directly
tp (int): tensor parallel for Turbomind
"""
InterFace.async_engine = AsyncEngine(model_path=model_path,
instance_num=batch_size,
tp=tp)
with gr.Blocks(css=CSS, theme=THEME) as demo:
state_chatbot = gr.State([])
with gr.Column(elem_id='container'):
gr.Markdown('## LMDeploy Playground')
chatbot = gr.Chatbot(
elem_id='chatbot',
label=InterFace.async_engine.tm_model.model_name)
instruction_txtbox = gr.Textbox(
placeholder='Please input the instruction',
label='Instruction')
with gr.Row():
cancel_btn = gr.Button(value='Cancel', interactive=False)
reset_btn = gr.Button(value='Reset')
send_event = instruction_txtbox.submit(
chat_stream_local,
[instruction_txtbox, state_chatbot, cancel_btn, reset_btn],
[state_chatbot, chatbot, cancel_btn, reset_btn])
instruction_txtbox.submit(
lambda: gr.Textbox.update(value=''),
[],
[instruction_txtbox],
)
cancel_btn.click(cancel_local_func,
[state_chatbot, cancel_btn, reset_btn],
[state_chatbot, cancel_btn, reset_btn],
cancels=[send_event])
reset_btn.click(reset_local_func, [instruction_txtbox, state_chatbot],
[state_chatbot, chatbot, instruction_txtbox],
cancels=[send_event])
print(f'server is gonna mount on: http://{server_name}:{server_port}')
demo.queue(concurrency_count=batch_size, max_size=100,
api_open=True).launch(
max_threads=10,
share=True,
server_port=server_port,
server_name=server_name,
)
# Copyright (c) OpenMMLab. All rights reserved.
import json
from typing import Iterable, List
from typing import Any, Dict, Iterable, List, Optional, Union
import requests
......@@ -14,13 +14,306 @@ def get_model_list(api_url: str):
return None
class APIClient:
"""Chatbot for LLaMA series models with turbomind as inference engine.
Args:
api_server_url (str): communicating address 'http://<ip>:<port>' of
api_server
"""
def __init__(self, api_server_url: str, **kwargs):
self.api_server_url = api_server_url
self.chat_intractive_v1_url = f'{api_server_url}/v1/chat/interactive'
self.chat_completions_v1_url = f'{api_server_url}/v1/chat/completions'
self.completions_v1_url = f'{api_server_url}/v1/completions'
self.models_v1_url = f'{api_server_url}/v1/models'
self._available_models = None
@property
def available_models(self):
"""Show available models."""
if self._available_models is not None:
return self._available_models
response = requests.get(self.models_v1_url)
if hasattr(response, 'text'):
model_list = json.loads(response.text)
model_list = model_list.pop('data', [])
self._available_models = [item['id'] for item in model_list]
return self._available_models
return None
def chat_completions_v1(self,
model: str,
messages: Union[str, List[Dict[str, str]]],
temperature: Optional[float] = 0.7,
top_p: Optional[float] = 1.0,
n: Optional[int] = 1,
max_tokens: Optional[int] = 512,
stop: Optional[bool] = False,
stream: Optional[bool] = False,
presence_penalty: Optional[float] = 0.0,
frequency_penalty: Optional[float] = 0.0,
user: Optional[str] = None,
repetition_penalty: Optional[float] = 1.0,
session_id: Optional[int] = -1,
ignore_eos: Optional[bool] = False,
**kwargs):
"""Chat completion v1.
Args:
model: model name. Available from self.available_models.
messages: string prompt or chat history in OpenAI format.
temperature (float): to modulate the next token probability
top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or
higher are kept for generation.
n (int): How many chat completion choices to generate for each
input message. Only support one here.
stream: whether to stream the results or not. Default to false.
max_tokens (int): output token nums
repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
ignore_eos (bool): indicator for ignoring eos
session_id (int): if not specified, will set random value
Yields:
json objects in openai formats
"""
pload = {
k: v
for k, v in locals().copy().items()
if k[:2] != '__' and k not in ['self']
}
headers = {'content-type': 'application/json'}
response = requests.post(self.chat_completions_v1_url,
headers=headers,
json=pload,
stream=stream)
for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b'\n'):
if chunk:
if stream:
decoded = chunk.decode('utf-8')
if decoded == 'data: [DONE]':
continue
if decoded[:6] == 'data: ':
decoded = decoded[6:]
output = json.loads(decoded)
yield output
else:
decoded = chunk.decode('utf-8')
output = json.loads(decoded)
yield output
def chat_interactive_v1(self,
prompt: Union[str, List[Dict[str, str]]],
session_id: int = -1,
interactive_mode: bool = False,
stream: bool = False,
stop: bool = False,
request_output_len: int = 512,
top_p: float = 0.8,
top_k: int = 40,
temperature: float = 0.8,
repetition_penalty: float = 1.0,
ignore_eos: bool = False,
**kwargs):
"""Interactive completions.
- On interactive mode, the chat history is kept on the server. Please
set `interactive_mode = True`.
- On normal mode, no chat history is kept on the server. Set
`interactive_mode = False`.
Args:
prompt: the prompt to use for the generation.
session_id: determine which instance will be called.
If not specified with a value other than -1, using random value
directly.
interactive_mode (bool): turn on interactive mode or not. On
interactive mode, session history is kept on the server (and
vice versa).
stream: whether to stream the results or not.
stop: whether to stop the session response or not.
request_output_len (int): output token nums
top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or
higher are kept for generation.
top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
temperature (float): to modulate the next token probability
repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
ignore_eos (bool): indicator for ignoring eos
Yields:
json objects consist of text, tokens, finish_reason
"""
pload = {
k: v
for k, v in locals().copy().items()
if k[:2] != '__' and k not in ['self']
}
headers = {'content-type': 'application/json'}
response = requests.post(self.chat_intractive_v1_url,
headers=headers,
json=pload,
stream=stream)
for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b'\n'):
if chunk:
decoded = chunk.decode('utf-8')
output = json.loads(decoded)
yield output
def completions_v1(
self,
model: str,
prompt: Union[str, List[Any]],
suffix: Optional[str] = None,
temperature: Optional[float] = 0.7,
n: Optional[int] = 1,
max_tokens: Optional[int] = 16,
stream: Optional[bool] = False,
top_p: Optional[float] = 1.0,
user: Optional[str] = None,
# additional argument of lmdeploy
repetition_penalty: Optional[float] = 1.0,
session_id: Optional[int] = -1,
ignore_eos: Optional[bool] = False,
**kwargs):
"""Chat completion v1.
Args:
model (str): model name. Available from /v1/models.
prompt (str): the input prompt.
suffix (str): The suffix that comes after a completion of inserted
text.
max_tokens (int): output token nums
temperature (float): to modulate the next token probability
top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or
higher are kept for generation.
n (int): How many chat completion choices to generate for each
input message. Only support one here.
stream: whether to stream the results or not. Default to false.
repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
user (str): A unique identifier representing your end-user.
ignore_eos (bool): indicator for ignoring eos
session_id (int): if not specified, will set random value
Yields:
json objects in openai formats
"""
pload = {
k: v
for k, v in locals().copy().items()
if k[:2] != '__' and k not in ['self']
}
headers = {'content-type': 'application/json'}
response = requests.post(self.completions_v1_url,
headers=headers,
json=pload,
stream=stream)
for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b'\n'):
if chunk:
if stream:
decoded = chunk.decode('utf-8')
if decoded == 'data: [DONE]':
continue
if decoded[:6] == 'data: ':
decoded = decoded[6:]
output = json.loads(decoded)
yield output
else:
decoded = chunk.decode('utf-8')
output = json.loads(decoded)
yield output
def chat(self,
prompt: str,
session_id: int,
request_output_len: int = 512,
stream: bool = False,
top_p: float = 0.8,
top_k: int = 40,
temperature: float = 0.8,
repetition_penalty: float = 1.0,
ignore_eos: bool = False):
"""Chat with a unique session_id.
Args:
prompt: the prompt to use for the generation.
session_id: determine which instance will be called.
If not specified with a value other than -1, using random value
directly.
stream: whether to stream the results or not.
stop: whether to stop the session response or not.
request_output_len (int): output token nums
top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or
higher are kept for generation.
top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
temperature (float): to modulate the next token probability
repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
ignore_eos (bool): indicator for ignoring eos
Yields:
text, tokens, finish_reason
"""
assert session_id != -1, 'please set a value other than -1'
for outputs in self.chat_interactive_v1(
prompt,
session_id=session_id,
request_output_len=request_output_len,
interactive_mode=True,
stream=stream,
top_k=top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
ignore_eos=ignore_eos):
if outputs['finish_reason'] == 'length':
print('WARNING: exceed session max length.'
' Please end the session.')
yield outputs['text'], outputs['tokens'], outputs['finish_reason']
def end_session(self, session_id: int):
"""End the session with a unique session_id.
Args:
session_id: determine which instance will be called.
If not specified with a value other than -1, using random value
directly.
"""
for out in self.chat_interactive_v1(prompt='',
session_id=session_id,
request_output_len=0,
interactive_mode=False):
pass
def input_prompt():
"""Input a prompt in the consolo interface."""
print('\ndouble enter to end input >>> ', end='')
sentinel = '' # ends when this string is seen
return '\n'.join(iter(input, sentinel))
def get_streaming_response(prompt: str,
api_url: str,
session_id: int,
request_output_len: int = 512,
stream: bool = True,
sequence_start: bool = True,
sequence_end: bool = True,
interactive_mode: bool = False,
ignore_eos: bool = False,
stop: bool = False) -> Iterable[List[str]]:
headers = {'User-Agent': 'Test Client'}
......@@ -29,8 +322,7 @@ def get_streaming_response(prompt: str,
'stream': stream,
'session_id': session_id,
'request_output_len': request_output_len,
'sequence_start': sequence_start,
'sequence_end': sequence_end,
'interactive_mode': interactive_mode,
'ignore_eos': ignore_eos,
'stop': stop
}
......@@ -49,42 +341,23 @@ def get_streaming_response(prompt: str,
yield output, tokens, finish_reason
def input_prompt():
"""Input a prompt in the consolo interface."""
print('\ndouble enter to end input >>> ', end='')
sentinel = '' # ends when this string is seen
return '\n'.join(iter(input, sentinel))
def main(restful_api_url: str, session_id: int = 0):
nth_round = 1
def main(api_server_url: str, session_id: int = 0):
api_client = APIClient(api_server_url)
while True:
prompt = input_prompt()
if prompt == 'exit':
for output, tokens, finish_reason in get_streaming_response(
'',
f'{restful_api_url}/generate',
session_id=session_id,
request_output_len=0,
sequence_start=(nth_round == 1),
sequence_end=True):
pass
exit(0)
if prompt in ['exit', 'end']:
api_client.end_session(session_id)
if prompt == 'exit':
exit(0)
else:
for output, tokens, finish_reason in get_streaming_response(
for text, tokens, finish_reason in api_client.chat(
prompt,
f'{restful_api_url}/generate',
session_id=session_id,
request_output_len=512,
sequence_start=(nth_round == 1),
sequence_end=False):
stream=True):
if finish_reason == 'length':
print('WARNING: exceed session max length.'
' Please end the session.')
continue
print(output, end='')
nth_round += 1
print(text, end='')
if __name__ == '__main__':
......
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import os
import random
import time
from http import HTTPStatus
from typing import AsyncGenerator, List, Optional
......@@ -13,8 +15,10 @@ from lmdeploy.serve.async_engine import AsyncEngine
from lmdeploy.serve.openai.protocol import ( # noqa: E501
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, EmbeddingsRequest,
EmbeddingsResponse, ErrorResponse, GenerateRequest, GenerateResponse,
ChatCompletionStreamResponse, ChatMessage, CompletionRequest,
CompletionResponse, CompletionResponseChoice,
CompletionResponseStreamChoice, CompletionStreamResponse, DeltaMessage,
EmbeddingsRequest, ErrorResponse, GenerateRequest, GenerateResponse,
ModelCard, ModelList, ModelPermission, UsageInfo)
os.environ['TM_LOG_LEVEL'] = 'ERROR'
......@@ -104,9 +108,8 @@ async def chat_completions_v1(request: ChatCompletionRequest,
1.0 means no penalty
Additional arguments supported by LMDeploy:
- renew_session (bool): Whether renew the session. Can be used when the
session length is exceeded.
- ignore_eos (bool): indicator for ignoring eos
- session_id (int): if not specified, will set random value
Currently we do not support the following features:
- function_call (Users should implement this by themselves)
......@@ -114,20 +117,22 @@ async def chat_completions_v1(request: ChatCompletionRequest,
- presence_penalty (replaced with repetition_penalty)
- frequency_penalty (replaced with repetition_penalty)
"""
session_id = ip2id(raw_request.client.host)
if request.session_id == -1:
request.session_id = random.randint(1, 10086)
error_check_ret = await check_request(request)
if error_check_ret is not None:
return error_check_ret
model_name = request.model
request_id = str(session_id)
request_id = str(request.session_id)
created_time = int(time.time())
result_generator = VariableInterface.async_engine.generate_openai(
result_generator = VariableInterface.async_engine.generate(
request.messages,
session_id,
request.session_id,
True, # always use stream to enable batching
request.renew_session,
sequence_start=True,
sequence_end=True,
request_output_len=request.max_tokens if request.max_tokens else 512,
stop=request.stop,
top_p=request.top_p,
......@@ -188,7 +193,7 @@ async def chat_completions_v1(request: ChatCompletionRequest,
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
VariableInterface.async_engine.stop_session(session_id)
VariableInterface.async_engine.stop_session(request.session_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
'Client disconnected')
final_res = res
......@@ -222,51 +227,191 @@ async def chat_completions_v1(request: ChatCompletionRequest,
return response
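A hedged sketch of calling the OpenAI-compatible endpoint with the lmdeploy extras documented above. The model name `internlm-chat-7b` is a placeholder (query `/v1/models` for the real one), and the response is assumed to follow the OpenAI chat-completion shape.

```python
# Non-streaming request to /v1/chat/completions; field names follow the
# docstring in this patch.
import requests

payload = {
    'model': 'internlm-chat-7b',  # placeholder; list models via /v1/models
    'messages': [{'role': 'user', 'content': 'Say hello in one sentence.'}],
    'max_tokens': 64,
    'temperature': 0.7,
    'top_p': 0.8,
    'repetition_penalty': 1.0,
    'session_id': -1,   # -1 lets the server pick a random session id
    'ignore_eos': False,
    'stream': False,
}
resp = requests.post('http://localhost:23333/v1/chat/completions', json=payload)
print(resp.json()['choices'][0]['message']['content'])
```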
@app.post('/v1/embeddings')
async def create_embeddings(request: EmbeddingsRequest,
raw_request: Request = None):
"""Creates embeddings for the text."""
@app.post('/v1/completions')
async def completions_v1(request: CompletionRequest,
raw_request: Request = None):
"""Completion API similar to OpenAI's API.
Go to `https://platform.openai.com/docs/api-reference/completions/create`
for the API specification.
The request should be a JSON object with the following fields:
- model (str): model name. Available from /v1/models.
- prompt (str): the input prompt.
- suffix (str): The suffix that comes after a completion of inserted text.
- max_tokens (int): output token nums
- temperature (float): to modulate the next token probability
- top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
- n (int): How many completion choices to generate for each input
prompt. Only one is supported here.
- stream: whether to stream the results or not. Default to false.
- repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
- user (str): A unique identifier representing your end-user.
Additional arguments supported by LMDeploy:
- ignore_eos (bool): indicator for ignoring eos
- session_id (int): if not specified, will set random value
Currently we do not support the following features:
- logprobs (not supported yet)
- presence_penalty (replaced with repetition_penalty)
- frequency_penalty (replaced with repetition_penalty)
"""
if request.session_id == -1:
request.session_id = random.randint(1, 10086)
error_check_ret = await check_request(request)
if error_check_ret is not None:
return error_check_ret
if isinstance(request.input, str):
request.input = [request.input]
data = []
token_num = 0
for i, prompt in enumerate(request.input):
embedding = await VariableInterface.async_engine.get_embeddings(prompt)
data.append({
'object': 'embedding',
'embedding': embedding,
'index': i
})
token_num += len(embedding)
return EmbeddingsResponse(
data=data,
model=request.model,
usage=UsageInfo(
prompt_tokens=token_num,
total_tokens=token_num,
completion_tokens=None,
),
).dict(exclude_none=True)
@app.post('/generate')
async def generate(request: GenerateRequest, raw_request: Request = None):
model_name = request.model
request_id = str(request.session_id)
created_time = int(time.time())
if isinstance(request.prompt, str):
request.prompt = [request.prompt]
generators = []
for i in range(len(request.prompt)):
result_generator = VariableInterface.async_engine.generate(
request.prompt[i],
request.session_id + i,
True, # always use stream to enable batching
sequence_start=True,
sequence_end=True,
request_output_len=request.max_tokens
if request.max_tokens else 512,
stop=False,
top_p=request.top_p,
temperature=request.temperature,
repetition_penalty=request.repetition_penalty,
ignore_eos=request.ignore_eos,
do_preprocess=False)
generators.append(result_generator)
def create_stream_response_json(
index: int,
text: str,
finish_reason: Optional[str] = None,
) -> str:
choice_data = CompletionResponseStreamChoice(
index=index,
text=text,
finish_reason=finish_reason,
)
response = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[choice_data],
)
response_json = response.model_dump_json()
return response_json
async def completion_stream_generator() -> AsyncGenerator[str, None]:
# First chunk with role
for generator in generators:
for i in range(request.n):
choice_data = CompletionResponseStreamChoice(
index=i,
text='',
finish_reason=None,
)
chunk = CompletionStreamResponse(id=request_id,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f'data: {data}\n\n'
async for res in generator:
response_json = create_stream_response_json(
index=0,
text=res.response,
)
yield f'data: {response_json}\n\n'
yield 'data: [DONE]\n\n'
# Streaming response
if request.stream:
return StreamingResponse(completion_stream_generator(),
media_type='text/event-stream')
# Non-streaming response
usage = UsageInfo()
choices = []
async def _inner_call(i, generator):
final_res = None
text = ''
async for res in generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
VariableInterface.async_engine.stop_session(request.session_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
'Client disconnected')
final_res = res
text += res.response
assert final_res is not None
choice_data = CompletionResponseChoice(
index=0,
text=text,
finish_reason=final_res.finish_reason,
)
choices.append(choice_data)
total_tokens = sum([
final_res.history_token_len, final_res.input_token_len,
final_res.generate_token_len
])
usage.prompt_tokens += final_res.input_token_len
usage.completion_tokens += final_res.generate_token_len
usage.total_tokens += total_tokens
await asyncio.gather(
*[_inner_call(i, generators[i]) for i in range(len(generators))])
response = CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
return response
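For the new `/v1/completions` route, a hedged sketch of a plain text-completion request; field names follow the docstring above and the model name is again a placeholder.

```python
# Non-streaming text completion against /v1/completions.
import requests

resp = requests.post('http://localhost:23333/v1/completions',
                     json={
                         'model': 'internlm-chat-7b',  # placeholder model name
                         'prompt': 'The capital of France is',
                         'max_tokens': 32,
                         'top_p': 0.8,
                         'session_id': -1,  # -1 -> random session id on the server
                         'stream': False,
                     })
print(resp.json()['choices'][0]['text'])
```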
@app.post('/v1/embeddings', tags=['unsupported'])
async def create_embeddings(request: EmbeddingsRequest,
raw_request: Request = None):
"""Creates embeddings for the text."""
return create_error_response(HTTPStatus.BAD_REQUEST,
'Unsupported by turbomind.')
@app.post('/generate',
tags=['deprecated'],
description='please use /v1/chat/interactive')
@app.post('/v1/chat/interactive')
async def chat_interactive_v1(request: GenerateRequest,
raw_request: Request = None):
"""Generate completion for the request.
- In interactive mode, the chat history is kept on the server. Please set
`interactive_mode = True`.
- In normal mode, no chat history is kept on the server. Set
`interactive_mode = False`.
The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- session_id: determine which instance will be called. If not specified
with a value other than -1, using host ip directly.
- sequence_start (bool): indicator for starting a sequence.
- sequence_end (bool): indicator for ending a sequence
with a value other than -1, a random value will be used.
- interactive_mode (bool): turn on interactive mode or not. In interactive
mode, session history is kept on the server (and vice versa).
- stream: whether to stream the results or not.
- stop: whether to stop the session response or not.
- request_output_len (int): output token nums
- step (int): the offset of the k/v cache
- top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
......@@ -278,15 +423,18 @@ async def generate(request: GenerateRequest, raw_request: Request = None):
- ignore_eos (bool): indicator for ignoring eos
"""
if request.session_id == -1:
session_id = ip2id(raw_request.client.host)
request.session_id = session_id
request.session_id = random.randint(10087, 23333)
async_engine = VariableInterface.async_engine
sequence_start = async_engine.steps.get(str(request.session_id), 0) == 0
sequence_end = not request.interactive_mode
generation = VariableInterface.async_engine.generate(
generation = async_engine.generate(
request.prompt,
request.session_id,
stream_response=True, # always use stream to enable batching
sequence_start=request.sequence_start,
sequence_end=request.sequence_end,
sequence_start=sequence_start,
sequence_end=sequence_end,
request_output_len=request.request_output_len,
top_p=request.top_p,
top_k=request.top_k,
......@@ -315,7 +463,7 @@ async def generate(request: GenerateRequest, raw_request: Request = None):
async for out in generation:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
VariableInterface.async_engine.stop_session(session_id)
async_engine.stop_session(request.session_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
'Client disconnected')
text += out.response
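A hedged sketch of two turns against `/v1/chat/interactive` with `interactive_mode=True`, so the history stays on the server and only the new user input is sent each turn. Non-streaming is used here, and the response is assumed to follow `GenerateResponse` (text/tokens/finish_reason).

```python
# Two-turn interactive chat; only the latest user prompt is sent because
# the k/v cache and history live on the server while interactive_mode=True.
import requests


def interactive_turn(prompt: str, session_id: int = 1) -> str:
    resp = requests.post('http://localhost:23333/v1/chat/interactive',
                         json={
                             'prompt': prompt,
                             'session_id': session_id,
                             'interactive_mode': True,
                             'request_output_len': 512,
                             'stream': False,
                         })
    return resp.json()['text']


print(interactive_turn('My name is Alice.'))
print(interactive_turn('What is my name?'))  # answered from server-side history
```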
......@@ -326,14 +474,15 @@ async def generate(request: GenerateRequest, raw_request: Request = None):
def main(model_path: str,
server_name: str = 'localhost',
server_name: str = '0.0.0.0',
server_port: int = 23333,
instance_num: int = 32,
tp: int = 1,
allow_origins: List[str] = ['*'],
allow_credentials: bool = True,
allow_methods: List[str] = ['*'],
allow_headers: List[str] = ['*']):
allow_headers: List[str] = ['*'],
**kwargs):
"""An example to perform model inference through the command line
interface.
......@@ -359,7 +508,8 @@ def main(model_path: str,
VariableInterface.async_engine = AsyncEngine(model_path=model_path,
instance_num=instance_num,
tp=tp)
tp=tp,
**kwargs)
uvicorn.run(app=app, host=server_name, port=server_port, log_level='info')
......
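A hedged sketch of starting the server in-process; extra keyword arguments are forwarded to `AsyncEngine` via `**kwargs` as shown above. The keyword `max_batch_size` below is hypothetical — the accepted names depend on the engine/model config, not on this patch.

```python
# Launch api_server from Python instead of the CLI; **kwargs is passed
# straight through to AsyncEngine per the change above.
from lmdeploy.serve.openai.api_server import main

main('./workspace',
     server_name='0.0.0.0',
     server_port=23333,
     instance_num=32,
     tp=1,
     max_batch_size=64)  # hypothetical model argument forwarded via **kwargs
```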
......@@ -70,7 +70,9 @@ class ChatCompletionRequest(BaseModel):
user: Optional[str] = None
# additional argument of lmdeploy
repetition_penalty: Optional[float] = 1.0
renew_session: Optional[bool] = False
session_id: Optional[int] = -1
renew_session: Optional[
bool] = False  # legacy and unused, will be removed
ignore_eos: Optional[bool] = False
......@@ -135,6 +137,10 @@ class CompletionRequest(BaseModel):
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
user: Optional[str] = None
# additional argument of lmdeploy
repetition_penalty: Optional[float] = 1.0
session_id: Optional[int] = -1
ignore_eos: Optional[bool] = False
class CompletionResponseChoice(BaseModel):
......@@ -191,8 +197,7 @@ class GenerateRequest(BaseModel):
"""Generate request."""
prompt: Union[str, List[Dict[str, str]]]
session_id: int = -1
sequence_start: bool = True
sequence_end: bool = False
interactive_mode: bool = False
stream: bool = False
stop: bool = False
request_output_len: int = 512
......
......@@ -69,8 +69,9 @@ def get_gen_param(cap,
def main(model_path,
session_id: int = 1,
cap: str = 'chat',
tp=1,
stream_output=True,
tp: int = 1,
stream_output: bool = True,
request_output_len: int = 512,
**kwargs):
"""An example to perform model inference through the command line
interface.
......@@ -106,12 +107,13 @@ def main(model_path,
elif prompt == 'end':
prompt = model.get_prompt('', nth_round == 1)
input_ids = tokenizer.encode(prompt)
for outputs in generator.stream_infer(session_id=session_id,
input_ids=[input_ids],
request_output_len=512,
sequence_start=False,
sequence_end=True,
stream_output=stream_output):
for outputs in generator.stream_infer(
session_id=session_id,
input_ids=[input_ids],
request_output_len=request_output_len,
sequence_start=False,
sequence_end=True,
stream_output=stream_output):
pass
nth_round = 1
step = 0
......@@ -119,13 +121,14 @@ def main(model_path,
else:
prompt = model.get_prompt(prompt, nth_round == 1)
input_ids = tokenizer.encode(prompt)
if step + len(input_ids) >= tm_model.session_len:
if step + len(
input_ids) + request_output_len >= tm_model.session_len:
print('WARNING: exceed session max length.'
' Please end the session.')
continue
gen_param = get_gen_param(cap, model.sampling_param, nth_round,
step, **kwargs)
step, request_output_len, **kwargs)
print(f'{prompt} ', end='', flush=True)
response_size = 0
......
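A hedged sketch of the stricter length condition introduced above: a turn is rejected as soon as the cached steps plus the new prompt tokens plus the requested output length would reach the model's session window.

```python
# Stand-alone illustration of the session-length guard; names are local
# to this sketch, not part of the patched module.
def exceeds_session_len(step: int, num_input_ids: int,
                        request_output_len: int, session_len: int) -> bool:
    return step + num_input_ids + request_output_len >= session_len


# e.g. session_len=2048, 1500 cached steps, 100 prompt tokens, 512 requested:
# 1500 + 100 + 512 = 2112 >= 2048 -> warn and skip the turn
assert exceeds_session_len(1500, 100, 512, 2048)
assert not exceeds_session_len(1000, 100, 512, 2048)
```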