Unverified Commit 6b00f623 authored by Chen Xin, committed by GitHub

Support loading hf model directly (#685)

* turbomind support export model params

* fix overflow

* support turbomind.from_pretrained

* fix tp

* support AutoModel

* support load kv qparams

* update auto_awq

* update docstring

* export lmdeploy version

* update doc

* remove download_hf_repo

* LmdeployForCausalLM -> LMDeployForCausalLM

* refactor turbomind.py

* update comment

* add bfloat16 convert back

* support gradio run_local load hf

* support restful api server load hf

* add docs

* support loading previous quantized model

* adapt pr 690

* update docs

* not export turbomind config when quantize a model

* check model_name when can not get it from config.json

* update readme

* remove model_name in auto_awq

* update

* update

* update

* fix build

* absolute import
parent 42e57c8b
......@@ -58,6 +58,7 @@ work_dir*/
*.bin
*config.json
*generate_config.json
!lmdeploy/turbomind/hf_repo/config.json
# Pytorch
*.pth
......
......@@ -20,6 +20,7 @@ ______________________________________________________________________
## News 🎉
- \[2023/11\] TurboMind supports loading HF models directly. Click [here](./docs/en/load_hf.md) for details.
- \[2023/11\] TurboMind major upgrades, including: Paged Attention, faster attention kernels without sequence length limitation, 2x faster KV8 kernels, Split-K decoding (Flash Decoding), and W4A16 inference for sm_75
- \[2023/09\] TurboMind supports Qwen-14B
- \[2023/09\] TurboMind supports InternLM-20B
......@@ -114,30 +115,18 @@ pip install lmdeploy
### Deploy InternLM
#### Get InternLM model
To use the TurboMind inference engine, the model must first be converted into TurboMind format. Currently, both online and offline conversion are supported. With online conversion, TurboMind loads the Huggingface model directly; with offline conversion, you save the converted model first and then load it.
```shell
# 1. Download InternLM model
# Make sure you have git-lfs installed (https://git-lfs.com)
git lfs install
git clone https://huggingface.co/internlm/internlm-chat-7b-v1_1 /path/to/internlm-chat-7b
# if you want to clone without large files – just their pointers
# prepend your git clone with the following env var:
GIT_LFS_SKIP_SMUDGE=1
# 2. Convert InternLM model to turbomind's format, which will be in "./workspace" by default
lmdeploy convert internlm-chat-7b /path/to/internlm-chat-7b
```
The following uses [internlm/internlm-chat-7b-v1_1](https://huggingface.co/internlm/internlm-chat-7b-v1_1) as an example to show how to use TurboMind with online conversion. Refer to [load_hf.md](docs/en/load_hf.md) for other methods.
#### Inference by TurboMind
```shell
lmdeploy chat turbomind ./workspace
lmdeploy chat turbomind internlm/internlm-chat-7b-v1_1 --model-name internlm-chat-7b
```
> **Note**<br /> The internlm/internlm-chat-7b-v1_1 model will be downloaded to the `.cache` folder. You can also use a local path here.
> **Note**<br />
> When inferring with FP16 precision, the InternLM-7B model requires at least 15.7G of GPU memory overhead on TurboMind. <br />
> It is recommended to use NVIDIA cards such as 3090, V100, A100, etc.
......@@ -152,7 +141,7 @@ lmdeploy chat turbomind ./workspace
# install lmdeploy with extra dependencies
pip install lmdeploy[serve]
lmdeploy serve gradio ./workspace
lmdeploy serve gradio internlm/internlm-chat-7b-v1_1 --model-name internlm-chat-7b
```
![](https://github.com/InternLM/lmdeploy/assets/67539920/08d1e6f2-3767-44d5-8654-c85767cec2ab)
......@@ -165,13 +154,13 @@ Launch inference server by:
# install lmdeploy with extra dependencies
pip install lmdeploy[serve]
lmdeploy serve api_server ./workspace --instance_num 32 --tp 1
lmdeploy serve api_server internlm/internlm-chat-7b-v1_1 --model-name internlm-chat-7b --instance_num 32 --tp 1
```
Then, you can communicate with it by command line,
```shell
# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333
# api_server_url is what is printed by api_server.py, e.g. http://localhost:23333
lmdeploy serve api_client api_server_url
```
......@@ -186,29 +175,6 @@ lmdeploy serve gradio api_server_url --server_name ${gradio_ui_ip} --server_port
Refer to [restful_api.md](docs/en/restful_api.md) for more details.
#### Serving with Triton Inference Server
Launch inference server by:
```shell
bash workspace/service_docker_up.sh
```
Then, you can communicate with the inference server by command line,
```shell
python3 -m pip install tritonclient[grpc]
lmdeploy serve triton_client {server_ip_addresss}:33337
```
or webui,
```shell
lmdeploy serve gradio {server_ip_addresss}:33337
```
For the deployment of other supported models, such as LLaMA, LLaMA-2, Vicuna and so on, you can find the guide [here](docs/en/serving.md)
### Inference with PyTorch
For detailed instructions on Inference pytorch models, see [here](docs/en/pytorch.md).
......
......@@ -20,6 +20,7 @@ ______________________________________________________________________
## News 🎉
- \[2023/11\] TurboMind supports loading Huggingface models directly. Click [here](./docs/en/load_hf.md) for usage details
- \[2023/11\] TurboMind major upgrades, including: Paged Attention, faster attention kernels without sequence length limitation, 2x faster KV8 kernels, Split-K decoding (Flash Decoding), and W4A16 inference for sm_75
- \[2023/09\] TurboMind supports Qwen-14B
- \[2023/09\] TurboMind supports InternLM-20B
......@@ -114,30 +115,18 @@ pip install lmdeploy
### Deploy InternLM
#### Get the InternLM model
To run inference with TurboMind, the model must first be converted into TurboMind format. Both online and offline conversion are currently supported. Online conversion loads a Huggingface model directly, while offline conversion requires saving the converted model before loading it.
```shell
# 1. Download the InternLM model
# Make sure you have git-lfs installed (https://git-lfs.com)
git lfs install
git clone https://huggingface.co/internlm/internlm-chat-7b-v1_1 /path/to/internlm-chat-7b
# if you want to clone without large files – just their pointers
# prepend your git clone with the following env var:
GIT_LFS_SKIP_SMUDGE=1
# 2. Convert to the format required by TurboMind. The default output path is ./workspace
lmdeploy convert internlm-chat-7b /path/to/internlm-chat-7b
```
The following uses [internlm/internlm-chat-7b-v1_1](https://huggingface.co/internlm/internlm-chat-7b-v1_1) as an example to show how to use online conversion. Refer to [load_hf.md](docs/zh_cn/load_hf.md) for other methods.
#### Inference with TurboMind
```shell
lmdeploy chat turbomind ./workspace
lmdeploy chat turbomind internlm/internlm-chat-7b-v1_1 --model-name internlm-chat-7b
```
> **Note**<br /> The internlm/internlm-chat-7b-v1_1 model is downloaded automatically to the `.cache` folder. A local path to an already-downloaded model can also be passed here.
> **Note**<br />
> When inferring the InternLM-7B model with FP16 precision, TurboMind requires at least 15.7 GB of GPU memory. GPUs such as the 3090, V100 and A100 are recommended. <br />
> Disabling the GPU's ECC frees up about 10% of GPU memory. Run `sudo nvidia-smi --ecc-config=0` and reboot the system for it to take effect.
......@@ -151,7 +140,7 @@ lmdeploy chat turbomind ./workspace
# install lmdeploy with extra dependencies
pip install lmdeploy[serve]
lmdeploy serve gradio ./workspace
lmdeploy serve gradio internlm/internlm-chat-7b-v1_1 --model-name internlm-chat-7b
```
![](https://github.com/InternLM/lmdeploy/assets/67539920/08d1e6f2-3767-44d5-8654-c85767cec2ab)
......@@ -164,13 +153,13 @@ lmdeploy serve gradio ./workspace
# install lmdeploy with extra dependencies
pip install lmdeploy[serve]
lmdeploy serve api_server ./workspace --server_name 0.0.0.0 --server_port ${server_port} --instance_num 32 --tp 1
lmdeploy serve api_server internlm/internlm-chat-7b-v1_1 --model-name internlm-chat-7b --instance_num 32 --tp 1
```
You can interact with the inference service from the command line:
```shell
# restful_api_url is what printed in api_server.py, e.g. http://localhost:23333
# api_server_url is what is printed by api_server.py, e.g. http://localhost:23333
lmdeploy serve api_client api_server_url
```
......@@ -185,29 +174,6 @@ lmdeploy serve gradio api_server_url --server_name ${gradio_ui_ip} --server_port
Refer to [restful_api.md](docs/zh_cn/restful_api.md) for more details
#### Deploy the inference service via a container
Launch the inference service with the following command:
```shell
bash workspace/service_docker_up.sh
```
You can interact with the inference service from the command line:
```shell
python3 -m pip install tritonclient[grpc]
lmdeploy serve triton_client {server_ip_addresss}:33337
```
Or chat through the WebUI:
```shell
lmdeploy serve gradio {server_ip_addresss}:33337
```
For deploying other models, such as LLaMA, LLaMA-2, Vicuna and so on, please refer to the guide [here](docs/zh_cn/serving.md)
### Inference with PyTorch
Make sure deepspeed is installed in your environment:
......
# Load huggingface model directly
Starting from v0.1.0, TurboMind can pre-process model parameters on the fly while loading them from Huggingface-style models.
## Supported model types
Currently, TurboMind supports loading three types of models:
1. A lmdeploy-quantized model hosted on huggingface.co, such as [llama2-70b-4bit](https://huggingface.co/lmdeploy/llama2-chat-70b-4bit), [internlm-chat-20b-4bit](https://huggingface.co/internlm/internlm-chat-20b-4bit), etc.
2. Other LM models on huggingface.co, such as Qwen/Qwen-7B-Chat
3. A model converted by `lmdeploy convert` (the legacy format)
## Usage
### 1) A lmdeploy-quantized model
TurboMind can directly load models quantized by `lmdeploy.lite`, such as [llama2-70b-4bit](https://huggingface.co/lmdeploy/llama2-chat-70b-4bit), [internlm-chat-20b-4bit](https://huggingface.co/internlm/internlm-chat-20b-4bit), etc.
```
repo_id=internlm/internlm-chat-20b-4bit
model_name=internlm-chat-20b
# or
# repo_id=/path/to/downloaded_model
# Inference by TurboMind
lmdeploy chat turbomind $repo_id --model-name $model_name
# Serving with gradio
lmdeploy serve gradio $repo_id --model-name $model_name
# Serving with Restful API
lmdeploy serve api_server $repo_id --model-name $model_name --instance_num 32 --tp 1
```
### 2) Other LM models
Other LM models such as Qwen/Qwen-7B-Chat or baichuan-inc/Baichuan2-7B-Chat are also supported. The models supported by LMDeploy can be listed with `lmdeploy list`.
```
repo_id=Qwen/Qwen-7B-Chat
model_name=qwen-7b
# or
# repo_id=/path/to/Qwen-7B-Chat/local_path
# Inference by TurboMind
lmdeploy chat turbomind $repo_id --model-name $model_name
# Serving with gradio
lmdeploy serve gradio $repo_id --model-name $model_name
# Serving with Restful API
lmdeploy serve api_server $repo_id --model-name $model_name --instance_num 32 --tp 1
```
### 3) A model converted by `lmdeploy convert`
The usage is the same as before
```
# Convert a model
lmdeploy convert /path/to/model ./workspace --model-name MODEL_NAME
# Inference by TurboMind
lmdeploy chat turbomind ./workspace
# Serving with gradio
lmdeploy serve gradio ./workspace
# Serving with Restful API
lmdeploy serve api_server ./workspace --instance_num 32 --tp 1
```
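The CLI commands in the three scenarios above all route through the same Python entry point. Below is a minimal, hedged sketch of the equivalent Python API based on the code added in this PR; the repo id, prompt, and generation settings are illustrative placeholders.
```python
# Minimal sketch, assuming the lmdeploy package from this PR is installed.
# The repo id, prompt and generation settings below are placeholders.
from lmdeploy import turbomind as tm

# online conversion: load a Huggingface model id (or a local path) directly
tm_model = tm.TurboMind.from_pretrained('internlm/internlm-chat-7b-v1_1',
                                        model_name='internlm-chat-7b')
generator = tm_model.create_instance()

# build the prompt with the model's chat template and tokenize it
prompt = tm_model.model.get_prompt('Hi, please introduce yourself', True)
input_ids = tm_model.tokenizer.encode(prompt)

# stream tokens for a single-turn request
for outputs in generator.stream_infer(session_id=1,
                                      input_ids=[input_ids],
                                      request_output_len=512,
                                      sequence_start=True,
                                      sequence_end=True):
    res, tokens = outputs[0]

print(tm_model.tokenizer.decode(res.tolist()))
```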
# Load Huggingface models directly
Starting from v0.1.0, TurboMind supports loading Huggingface-format weights directly.
## Supported model types
Currently, TurboMind supports loading three types of models:
1. A model quantized by lmdeploy and hosted on huggingface.co, such as [llama2-70b-4bit](https://huggingface.co/lmdeploy/llama2-chat-70b-4bit), [internlm-chat-20b-4bit](https://huggingface.co/internlm/internlm-chat-20b-4bit)
2. Other LM models on huggingface.co, such as Qwen/Qwen-7B-Chat
3. A model converted by the `lmdeploy convert` command (the legacy format)
## Usage
### 1) A lmdeploy-quantized model
TurboMind can directly load models quantized by `lmdeploy.lite`, such as [llama2-70b-4bit](https://huggingface.co/lmdeploy/llama2-chat-70b-4bit), [internlm-chat-20b-4bit](https://huggingface.co/internlm/internlm-chat-20b-4bit).
```
repo_id=internlm/internlm-chat-20b-4bit
model_name=internlm-chat-20b
# or
# repo_id=/path/to/downloaded_model
# Inference by TurboMind
lmdeploy chat turbomind $repo_id --model-name $model_name
# Serving with gradio
lmdeploy serve gradio $repo_id --model-name $model_name
# Serving with Restful API
lmdeploy serve api_server $repo_id --model-name $model_name --instance_num 32 --tp 1
```
### 2) Other LM models
Other LM models such as Qwen/Qwen-7B-Chat or baichuan-inc/Baichuan2-7B-Chat are also supported. The models supported by LMDeploy can be listed with `lmdeploy list`.
```
repo_id=Qwen/Qwen-7B-Chat
model_name=qwen-7b
# or
# repo_id=/path/to/Qwen-7B-Chat/local_path
# Inference by TurboMind
lmdeploy chat turbomind $repo_id --model-name $model_name
# Serving with gradio
lmdeploy serve gradio $repo_id --model-name $model_name
# Serving with Restful API
lmdeploy serve api_server $repo_id --model-name $model_name --instance_num 32 --tp 1
```
### 3) A model converted by `lmdeploy convert`
The usage is the same as before
```
# Convert a model
lmdeploy convert /path/to/model ./workspace --model-name MODEL_NAME
# Inference by TurboMind
lmdeploy chat turbomind ./workspace
# Serving with gradio
lmdeploy serve gradio ./workspace
# Serving with Restful API
lmdeploy serve api_server ./workspace --instance_num 32 --tp 1
```
......@@ -11,7 +11,7 @@ class SubCliServe(object):
server_port: int = 6006,
batch_size: int = 32,
tp: int = 1,
restful_api: bool = False):
**kwargs):
"""Serve LLMs with web ui using gradio.
Example 1:
......@@ -21,7 +21,6 @@ class SubCliServe(object):
lmdeploy serve gradio http://0.0.0.0:23333
--server_name 0.0.0.0
--server_port 6006
--restful_api True
Example 3:
lmdeploy serve gradio ${triton_server_ip_addresss}:33337
......@@ -30,13 +29,12 @@ class SubCliServe(object):
model_path_or_server (str): the path of the deployed model or the
tritonserver URL or restful api URL. The former is for directly
running service with gradio. The latter is for running with
tritonserver by default. If the input URL is restful api.
Please enable another flag `restful_api`.
tritonserver by default.
server_name (str): the ip address of gradio server
server_port (int): the port of gradio server
batch_size (int): batch size for running Turbomind directly
tp (int): tensor parallel for Turbomind
restful_api (bool): a flag for model_path_or_server
kwargs (dict): extra params used to initialize the engine
"""
from lmdeploy.serve.gradio.app import run
run(model_path_or_server,
......@@ -44,7 +42,7 @@ class SubCliServe(object):
server_port=server_port,
batch_size=batch_size,
tp=tp,
restful_api=restful_api)
**kwargs)
def api_server(self,
model_path: str,
......@@ -55,7 +53,8 @@ class SubCliServe(object):
allow_origins: List[str] = ['*'],
allow_credentials: bool = True,
allow_methods: List[str] = ['*'],
allow_headers: List[str] = ['*']):
allow_headers: List[str] = ['*'],
**kwargs):
"""Serve LLMs with restful api using fastapi.
Args:
......@@ -68,6 +67,7 @@ class SubCliServe(object):
allow_credentials (bool): whether to allow credentials for CORS
allow_methods (List[str]): a list of allowed HTTP methods for CORS
allow_headers (List[str]): a list of allowed HTTP headers for CORS
kwargs (dict): extra params used to initialize the api server
"""
from lmdeploy.serve.openai.api_server import main as run_api_server
......@@ -79,7 +79,8 @@ class SubCliServe(object):
allow_origins=allow_origins,
allow_credentials=allow_credentials,
allow_methods=allow_methods,
allow_headers=allow_headers)
allow_headers=allow_headers,
**kwargs)
def api_client(self, restful_api_url: str, session_id: int = 0):
"""Interact with restful api server in terminal.
......
......@@ -10,6 +10,8 @@ from lmdeploy.lite.quantization.awq import (FC_FCS_MAP, NORM_FCS_MAP,
quant_weights, smooth_layers)
from lmdeploy.lite.utils import collect_target_modules, load_hf_from_pretrained
# from lmdeploy.lite.utils.export_turbomind import export_turbomind_config
LAYER_TYPE_MAP = {
'InternLMForCausalLM': 'InternLMDecoderLayer',
'QWenLMHeadModel': 'QWenBlock',
......@@ -33,6 +35,9 @@ def auto_awq(model: str,
w_group_size: int = 128,
device: str = 'cuda'):
assert model != work_dir, '$WORK_DIR and $HF_MODEL should be different'
model_path = model # noqa
# Load tokenizer and configuration
tokenizer = AutoTokenizer.from_pretrained(model,
use_fast=False,
......@@ -61,6 +66,11 @@ def auto_awq(model: str,
model.save_pretrained(work_dir, max_shard_size='2GB')
tokenizer.save_pretrained(work_dir)
# export_turbomind_config(model_name,
# model_path,
# work_dir,
# group_size=w_group_size)
if __name__ == '__main__':
import fire
......
# Copyright (c) OpenMMLab. All rights reserved.
import os
from pathlib import Path
from typing import Union
......@@ -6,11 +7,28 @@ import numpy as np
import torch
def _export_weight(info: str,
kv_qparams: np.array,
out_path: str,
tm_params: dict = None):
"""Save kv_qparams to disk or copy to tm_params."""
if tm_params is None:
print(info)
kv_qparams.tofile(out_path)
else:
name = os.path.basename(out_path)
src = torch.from_numpy(kv_qparams)
for tm_tensor in tm_params[name]:
tm_tensor.copy_from(src)
tm_params.pop(name)
def _export_sym(key_stats: dict,
value_stats: dict,
bits: int,
out_dir: Union[str, Path],
tp: int = 1) -> None:
tp: int = 1,
tm_params: dict = None) -> None:
"""Export symmetric quantization parameters to specified directory."""
keys_absmax = key_stats['absmax']
values_absmax = value_stats['absmax']
......@@ -31,15 +49,16 @@ def _export_sym(key_stats: dict,
kv_qparams = np.array([k_s, v_s], dtype=np.float32)
out_path = out_dir / f'layers.{layer_idx}.past_kv_scale.{i}.weight' # noqa: E501
kv_qparams.tofile(out_path)
print(f'Layer {layer_idx} MP {i} qparam: {k_s} \t{v_s}')
info = f'Layer {layer_idx} MP {i} qparam: {k_s} \t{v_s}'
_export_weight(info, kv_qparams, out_path, tm_params)
def _export_asym(key_stats: dict,
value_stats: dict,
bits: int,
out_dir: Union[str, Path],
tp: int = 1) -> None:
tp: int = 1,
tm_params: dict = None) -> None:
"""Export asymmetric quantization parameters to specified directory."""
keys_min = key_stats['min']
values_min = value_stats['min']
......@@ -81,16 +100,17 @@ def _export_asym(key_stats: dict,
kv_qparams = np.array([k_scale, k_zp, v_scale, v_zp],
dtype=np.float32)
out_path = out_dir / f'layers.{layer_idx}.past_kv_scale.{i}.weight'
kv_qparams.tofile(out_path)
print(f'Layer {layer_idx} MP {i} qparam: '
f'\t{k_scale} \t{k_zp} \t{v_scale} \t{v_zp}')
info = f'Layer {layer_idx} MP {i} qparam: ' \
f'\t{k_scale} \t{k_zp} \t{v_scale} \t{v_zp}'
_export_weight(info, kv_qparams, out_path, tm_params)
def main(work_dir: str,
turbomind_dir: str,
kv_bits: int = 8,
kv_sym: bool = False,
num_tp: int = 1) -> None:
num_tp: int = 1,
tm_params: dict = None) -> None:
"""Main function to export key and value stats.
Args:
......@@ -102,6 +122,7 @@ def main(work_dir: str,
kv_sym (bool, optional): Whether to use symmetric quantization.
Defaults to False.
num_tp (int, optional): Number of tensor parallelism. Defaults to 1.
tm_params (dict): turbomind model weights.
"""
work_dir = Path(work_dir)
......@@ -113,9 +134,10 @@ def main(work_dir: str,
value_stats = torch.load(work_dir / 'value_stats.pth')
if kv_sym:
_export_sym(key_stats, value_stats, kv_bits, tm_dir, num_tp)
_export_sym(key_stats, value_stats, kv_bits, tm_dir, num_tp, tm_params)
else:
_export_asym(key_stats, value_stats, kv_bits, tm_dir, num_tp)
_export_asym(key_stats, value_stats, kv_bits, tm_dir, num_tp,
tm_params)
if __name__ == '__main__':
......
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os
import shutil
from huggingface_hub import snapshot_download
from lmdeploy.turbomind.utils import get_hf_config_content
def export_turbomind_config(model_name: str,
model_path: str,
work_dir: str,
model_format: str = 'awq',
group_size: int = 128,
tp: int = 1):
"""Export hf lmdeploy model and config.json."""
import lmdeploy
from lmdeploy.model import MODELS
from lmdeploy.turbomind.deploy.converter import get_model_format
from lmdeploy.turbomind.deploy.source_model.base import INPUT_MODELS
from lmdeploy.turbomind.deploy.target_model.base import (
OUTPUT_MODELS, TurbomindModelConfig)
assert model_name in MODELS.module_dict.keys(), \
f"'{model_name}' is not supported. " \
f'The supported models are: {MODELS.module_dict.keys()}'
if not os.path.exists(model_path):
model_path = snapshot_download(model_path, local_files_only=True)
lmdeploy_dir = os.path.split(lmdeploy.__file__)[0]
hf_repo = os.path.join(lmdeploy_dir, 'turbomind', 'hf_repo')
files = os.listdir(hf_repo)
for file in files:
src = os.path.join(hf_repo, file)
dst = os.path.join(work_dir, file)
shutil.copy(src, dst)
cfg = TurbomindModelConfig.from_dict({}, allow_none=True)
cfg.model_name = model_name
cfg.tensor_para_size = tp
cfg.rotary_embedding = cfg.size_per_head
cfg.group_size = group_size
cfg.weight_type = 'int4'
output_format = 'w4'
inferred_model_format = get_model_format(model_name, model_format)
input_model = INPUT_MODELS.get(inferred_model_format)(
model_path=model_path, tokenizer_path=work_dir, ckpt_path=work_dir)
output_model = OUTPUT_MODELS.get(output_format)(input_model=input_model,
cfg=cfg,
to_file=False,
out_dir='')
old_data = get_hf_config_content(model_path)
config = output_model.cfg.__dict__
config_file = os.path.join(work_dir, 'config.json')
with open(config_file) as f:
data = json.load(f)
for k, v in old_data.items():
if k in data:
data[f'__{k}'] = v
else:
data[k] = v
data['turbomind'] = config
from lmdeploy.version import __version__
data['lmdeploy_version'] = __version__
with open(config_file, 'w') as f:
f.write(json.dumps(data, indent=2) + '\n')
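For context, the commented-out call in `auto_awq.py` earlier in this diff hints at how the exporter above is meant to be invoked after quantization. A hedged usage sketch follows; the model name and paths are placeholders, and the import path is taken from that commented-out import.
```python
# Hedged sketch: write the hf_repo stub files and a turbomind-flavoured
# config.json into a quantized model directory. The model name and paths are
# placeholders; the import path mirrors the commented-out import in auto_awq.py.
from lmdeploy.lite.utils.export_turbomind import export_turbomind_config

export_turbomind_config(model_name='internlm-chat-20b',
                        model_path='/path/to/hf_model',
                        work_dir='./quant_workdir',
                        model_format='awq',
                        group_size=128,
                        tp=1)
```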
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import dataclasses
import os.path as osp
import random
from contextlib import contextmanager
from typing import List, Literal, Optional
......@@ -28,15 +27,10 @@ class AsyncEngine:
def __init__(self, model_path, instance_num=32, tp=1, **kwargs) -> None:
from lmdeploy import turbomind as tm
from lmdeploy.tokenizer import Tokenizer
tokenizer_model_path = osp.join(model_path, 'triton_models',
'tokenizer')
tokenizer = Tokenizer(tokenizer_model_path)
self.tm_model = tm.TurboMind(model_path,
eos_id=tokenizer.eos_token_id,
tp=tp,
**kwargs)
self.tokenizer = tokenizer
self.tm_model = tm.TurboMind.from_pretrained(model_path,
tp=tp,
**kwargs)
self.tokenizer = self.tm_model.tokenizer
self.generators = [
self.tm_model.create_instance() for i in range(instance_num)
]
......
......@@ -32,7 +32,7 @@ def run(model_path_or_server: str,
else:
from lmdeploy.serve.gradio.turbomind_coupled import run_local
run_local(model_path_or_server, server_name, server_port, batch_size,
tp)
tp, **kwargs)
if __name__ == '__main__':
......
......@@ -118,7 +118,8 @@ def run_local(model_path: str,
server_name: str = 'localhost',
server_port: int = 6006,
batch_size: int = 4,
tp: int = 1):
tp: int = 1,
**kwargs):
"""chat with AI assistant through web ui.
Args:
......@@ -130,7 +131,8 @@ def run_local(model_path: str,
"""
InterFace.async_engine = AsyncEngine(model_path=model_path,
instance_num=batch_size,
tp=tp)
tp=tp,
**kwargs)
with gr.Blocks(css=CSS, theme=THEME) as demo:
state_chatbot = gr.State([])
......
# Copyright (c) OpenMMLab. All rights reserved.
import dataclasses
import os
import os.path as osp
import random
os.environ['TM_LOG_LEVEL'] = 'ERROR'
from lmdeploy.turbomind.utils import get_gen_param
@dataclasses.dataclass
class GenParam:
top_p: float
top_k: float
temperature: float
repetition_penalty: float
sequence_start: bool = False
sequence_end: bool = False
step: int = 0
request_output_len: int = 512
os.environ['TM_LOG_LEVEL'] = 'ERROR'
def input_prompt(model_name):
......@@ -40,30 +29,6 @@ def valid_str(string, coding='utf-8'):
return ret
def get_gen_param(cap,
sampling_param,
nth_round,
step,
request_output_len=512,
**kwargs):
"""return parameters used by token generation."""
gen_param = GenParam(**dataclasses.asdict(sampling_param),
request_output_len=request_output_len)
# Fix me later. turbomind.py doesn't support None top_k
if gen_param.top_k is None:
gen_param.top_k = 40
if cap == 'chat':
gen_param.sequence_start = (nth_round == 1)
gen_param.sequence_end = False
gen_param.step = step
else:
gen_param.sequence_start = True
gen_param.sequence_end = True
gen_param.step = 0
return gen_param
def main(model_path,
session_id: int = 1,
cap: str = 'chat',
......@@ -84,15 +49,11 @@ def main(model_path,
**kwargs (dict): other arguments for initializing the model's chat template
"""
from lmdeploy import turbomind as tm
from lmdeploy.tokenizer import Tokenizer
tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
tokenizer = Tokenizer(tokenizer_model_path)
tm_model = tm.TurboMind(model_path,
eos_id=tokenizer.eos_token_id,
tp=tp,
capability=cap,
**kwargs)
tm_model = tm.TurboMind.from_pretrained(model_path,
tp=tp,
capability=cap,
**kwargs)
tokenizer = tm_model.tokenizer
generator = tm_model.create_instance()
nth_round = 1
......
......@@ -203,7 +203,7 @@ def main(model_name: str,
if inferred_model_format.find('awq') != -1:
cfg.weight_type = 'int4'
output_format = 'w4'
assert group_size > 0, 'group_size should > 0'
assert group_size > 0, f'group_size: {group_size} should be > 0'
# convert
print('model_name ', model_name)
......
......@@ -64,6 +64,7 @@ class BaseReader(ABC):
for key in self.params:
layer_id = re.findall(self.attn_layer_patten, key)
if len(layer_id) == 0:
# tok, norm, output
to_remove.append(key)
else:
layer_id = int(layer_id[0])
......
......@@ -18,6 +18,9 @@ OUTPUT_MODELS = Registry(
def tprint(*args, **kwargs):
to_file = kwargs.pop('to_file', False)
if not to_file:
return
from io import StringIO
s = StringIO()
print(*args, **kwargs, file=s, end='')
......@@ -90,10 +93,13 @@ class BaseOutputModel(ABC):
out_dir: str = ''):
super().__init__()
self.input_model = input_model
self.cfg = self.get_config(cfg)
self.cfg = cfg
if not cfg.valid:
self.cfg = self.get_config(cfg)
assert self.cfg.valid
self.to_file = to_file
self.out_dir = out_dir
self.tm_params = {}
@abstractmethod
def get_config(self, cfg: TurbomindModelConfig) -> TurbomindModelConfig:
......@@ -136,6 +142,27 @@ class BaseOutputModel(ABC):
tprint(name, param.shape)
param.contiguous().cpu().numpy().tofile(
osp.join(self.out_dir, name))
elif len(self.tm_params) > 0:
tm_params = self.tm_params
weight_type = self.cfg.weight_type
assert weight_type in ['fp16', 'fp32', 'int4']
# currently, the tensor type should be in
# [torch.float, torch.half, torch.int32]
torch_tensor = param.cuda().contiguous()
assert torch_tensor.dtype in [
torch.int32, torch.float, torch.half, torch.bfloat16
]
if torch_tensor.dtype != torch.int32:
if weight_type in ['fp16', 'int4']:
torch_tensor = torch_tensor.half()
else:
torch_tensor = torch_tensor.float()
for tm_tensor in tm_params[name]:
tm_tensor.copy_from(torch_tensor)
tm_params.pop(name)
else:
tprint('skip export', name, param.shape)
def save_split(self,
tensor: torch.Tensor,
......@@ -145,8 +172,10 @@ class BaseOutputModel(ABC):
"""save split."""
tp = self.cfg.tensor_para_size
if split_dim is not None:
tprint(f'*** splitting {name}, shape={tensor.shape}, '
f'split_dim={split_dim}, tp={tp}')
tprint(
f'*** splitting {name}, shape={tensor.shape}, '
f'split_dim={split_dim}, tp={tp}',
to_file=self.to_file)
assert tensor.shape[split_dim] % tp == 0
split_size = tensor.shape[split_dim] // tp
splits = torch.split(tensor, split_size, dim=split_dim)
......@@ -154,7 +183,8 @@ class BaseOutputModel(ABC):
prefix, ext = osp.splitext(name)
self.export_weight(split, f'{prefix}.{i}{ext}')
elif copy:
tprint(f'### copying {name}, shape={tensor.shape}')
tprint(f'### copying {name}, shape={tensor.shape}',
to_file=self.to_file)
copies = [tensor] * tp
for i, copy in enumerate(copies):
prefix, ext = osp.splitext(name)
......@@ -166,7 +196,9 @@ class BaseOutputModel(ABC):
"""Export to turbomind model format."""
num_layer = self.cfg.num_layer
from tqdm import tqdm
pbar = tqdm(total=num_layer, desc='Convert to turbomind format')
pbar = tqdm(total=num_layer,
desc='Convert to turbomind format',
leave=self.to_file)
self.export_config()
for bin in self.input_model.bins():
self.export_misc(bin)
......
{
"architectures": [
"LMDeployForCausalLM"
],
"auto_map": {
"AutoConfig": "configuration_lmdeploy.LMDeployConfig",
"AutoModel": "modeling_lmdeploy.LMDeployForCausalLM",
"AutoModelForCausalLM": "modeling_lmdeploy.LMDeployForCausalLM"
},
"turbomind": {}
}
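This stub `config.json` is copied into exported repos so that `transformers` can resolve the LMDeploy classes through `auto_map`. A hedged sketch of loading such a repo follows; the repo id is illustrative and `trust_remote_code` is required.
```python
# Hedged sketch: with the auto_map above shipped alongside
# configuration_lmdeploy.py and modeling_lmdeploy.py in an exported repo,
# transformers resolves LMDeployForCausalLM from the repo's own modules.
# The repo id is illustrative only.
from transformers import AutoModel

model = AutoModel.from_pretrained('lmdeploy/llama2-chat-70b-4bit',
                                  trust_remote_code=True)
```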
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from transformers import PretrainedConfig
from lmdeploy.turbomind.deploy.target_model.base import TurbomindModelConfig
from lmdeploy.version import __version__ as lm_version
class LMDeployConfig(PretrainedConfig):
"""Lmdeploy config."""
def __init__(self, turbomind: dict = None, **kwargs):
default_tm_cfg = copy.deepcopy(
TurbomindModelConfig.from_dict({}, allow_none=True).__dict__)
if turbomind is not None:
default_tm_cfg.update(turbomind)
self.turbomind = default_tm_cfg
self.lmdeploy_version = lm_version
super().__init__(**kwargs)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
config, kwargs = super().from_pretrained(pretrained_model_name_or_path,
return_unused_kwargs=True,
**kwargs)
for k, v in kwargs.items():
if k in config.turbomind.keys():
config.turbomind[k] = v
if 'tp' in kwargs:
config.turbomind['tensor_para_size'] = kwargs['tp']
if return_unused_kwargs:
return config, kwargs
else:
return config
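A hedged sketch of the `tp` shortcut above: kwargs matching keys of the `turbomind` sub-config, or `tp`, override the values read from `config.json`. The local path is a placeholder, and the bare import assumes the exported repo directory is on `sys.path`.
```python
# Hedged sketch under the assumptions stated above.
from configuration_lmdeploy import LMDeployConfig

cfg = LMDeployConfig.from_pretrained('./exported_repo', tp=2)
print(cfg.turbomind['tensor_para_size'])  # -> 2
print(cfg.lmdeploy_version)
```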
# Copyright (c) OpenMMLab. All rights reserved.
import dataclasses
import os
from contextlib import contextmanager
from dataclasses import dataclass, field
from itertools import count
from queue import Queue
from typing import List, Optional, Tuple, Union
from huggingface_hub import snapshot_download
from transformers import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from lmdeploy.turbomind import TurboMind
from lmdeploy.turbomind.utils import get_gen_param
from .configuration_lmdeploy import LMDeployConfig
logger = logging.get_logger(__name__)
@dataclass
class Session:
_count = count()
_session_id: int = None
_message: List[Tuple[str, str]] = field(default_factory=list)
_step: int = 0
_nth_round: int = 0
_error: int = 0
def __init__(self):
self._session_id = next(Session._count)
self._message = []
self._step = 0
self._nth_round = 0
@property
def session_id(self):
return self._session_id
@property
def message(self):
return self._message
@property
def step(self):
return self._step
@property
def nth_round(self):
return self._nth_round
@property
def error(self):
return self._error
class LMDeployForCausalLM(PreTrainedModel):
config_class = LMDeployConfig
def __init__(self,
config: LMDeployConfig,
*inputs,
model_path: str = None,
**kwargs):
super().__init__(config)
self.tm_model = TurboMind.from_pretrained(model_path, **kwargs)
que = Queue()
for _ in range(config.turbomind['max_batch_size']):
que.put(self.tm_model.create_instance())
self.que = que
@classmethod
def from_pretrained(cls,
pretrained_model_name_or_path,
*model_args,
config: Optional[Union[PretrainedConfig, str,
os.PathLike]] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = 'main',
**kwargs):
"""Instantiate a LM model with turbomind backend."""
resume_download = kwargs.pop('resume_download', True)
proxies = kwargs.pop('proxies', None)
if os.path.isdir(pretrained_model_name_or_path):
local_folder = pretrained_model_name_or_path
else:
local_folder = snapshot_download(
pretrained_model_name_or_path,
revision=revision,
cache_dir=cache_dir,
proxies=proxies,
resume_download=resume_download,
force_download=force_download,
token=token,
local_files_only=local_files_only,
)
if not isinstance(config, PretrainedConfig):
config_path = config if config is not None else local_folder
kwargs.pop('return_unused_kwargs', None)
config, model_kwargs = cls.config_class.from_pretrained(
config_path, return_unused_kwargs=True, **kwargs)
else:
model_kwargs = kwargs
model = cls(config,
*model_args,
model_path=local_folder,
**model_kwargs)
generation_config = model.tm_model.model.sampling_param
for k, v in dataclasses.asdict(generation_config).items():
if hasattr(model.generation_config, k):
base_value = getattr(model.generation_config, k)
setattr(generation_config, k, base_value)
if k in kwargs:
setattr(generation_config, k, v)
model.generation_config = generation_config
return model
@contextmanager
def managed_generator(self, session: Session):
generator = self.que.get()
try:
yield generator
except: # noqa E722
for _ in generator.stream_infer(session.session_id, [0],
request_output_len=0,
sequence_start=False,
sequence_end=False,
stop=True):
pass
session._error = 1
finally:
self.que.put(generator)
def generate(
self,
input_ids: List[int],
session: Session,
**kwargs,
):
"""Generates sequences of token ids for models with a language modeling
head.
Args:
input_ids (List(int)): list of input token ids
session (Session): session information
kwargs (dict): ad hoc parametrization of generation
"""
with self.managed_generator(session) as generator:
for outputs in generator.stream_infer(
session_id=session.session_id,
input_ids=[input_ids],
**kwargs,
):
res, tokens = outputs[0]
yield res, tokens
def chat(
self,
query: str,
session: Optional[Session] = None,
cap: str = 'chat',
request_output_len: int = 512,
stream_output: bool = False,
ignore_eos=False,
random_seed: Optional[int] = None,
**kwargs,
) -> Tuple[str, Session]:
"""chat."""
if session is None:
session = Session()
assert session._error == 0, 'An error occurred before, ' \
'please start a new session.'
session._message.append([query, ''])
prompt = self.tm_model.model.get_prompt(query, session.nth_round == 0)
input_ids = self.tm_model.tokenizer.encode(prompt)
if len(
input_ids
) + session.step + request_output_len >= self.tm_model.session_len:
logger.error(
f'session_length exceeded {self.tm_model.session_len}')
session._error = 1
yield '', session
else:
gen_param = get_gen_param(cap, self.generation_config,
session.nth_round + 1, session.step,
request_output_len, **kwargs)
gen_kwargs = dataclasses.asdict(gen_param)
gen_kwargs.update(
random_seed=random_seed if session.nth_round == 0 else None,
stream_output=stream_output,
ignore_eos=ignore_eos,
**kwargs)
_step = session._step
_nth_round = session._nth_round
response_size = 0
for res, tokens in self.generate(input_ids,
session=session,
**gen_kwargs):
response = self.tm_model.tokenizer.decode(res.tolist(),
offset=response_size)
if response.endswith('�'):
continue
response_size = tokens
session._message[-1][-1] += response
session._nth_round = _nth_round + 1
session._step = _step + response_size
yield response, session
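A hedged multi-turn sketch for the `chat` generator above, loaded through the `auto_map` in `hf_repo/config.json`; the repo id is illustrative. Draining the streamed chunks leaves the full reply accumulated on the session.
```python
# Hedged sketch: multi-turn chat with LMDeployForCausalLM resolved via
# trust_remote_code. The repo id is illustrative only.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained('internlm/internlm-chat-20b-4bit',
                                             trust_remote_code=True)

session = None
for query in ['Hello!', 'Tell me a joke.']:
    # chunks stream in; the session accumulates the full reply
    for _, session in model.chat(query, session):
        pass
    print(session.message[-1][-1])
```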
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import copy
import io
import json
import logging
import os.path as osp
import sys
from configparser import ConfigParser
from contextlib import contextmanager
from queue import Queue
from threading import Thread
from typing import Iterable, List
from typing import Iterable, List, Optional
import numpy as np
import torch
from huggingface_hub import snapshot_download
from torch.nn.utils.rnn import pad_sequence
import lmdeploy
......@@ -17,19 +22,27 @@ from lmdeploy.model import MODELS, BaseModel
from lmdeploy.tokenizer import Tokenizer
from lmdeploy.utils import get_logger
from .deploy.converter import get_model_format, supported_formats
from .deploy.source_model.base import INPUT_MODELS
from .deploy.target_model.base import OUTPUT_MODELS, TurbomindModelConfig
from .utils import (ModelSource, check_tm_model_input, create_hf_download_args,
get_hf_config_content, get_model_source)
# TODO: find another way import _turbomind
lmdeploy_dir = osp.split(lmdeploy.__file__)[0]
sys.path.append(osp.join(lmdeploy_dir, 'lib'))
import _turbomind as _tm # noqa: E402
logger = logging.getLogger(__name__)
def _stop_words(stop_words: List[str], tokenizer: Tokenizer):
"""return list of stop-words to numpy.ndarray."""
if stop_words is None:
return None
assert isinstance(stop_words, List) and \
all(isinstance(elem, str) for elem in stop_words), \
f'stop_words must be a list but got {type(stop_words)}'
all(isinstance(elem, str) for elem in stop_words), \
f'stop_words must be a list but got {type(stop_words)}'
stop_words = [
tokenizer.encode(stop_word, False)[-1] for stop_word in stop_words
]
......@@ -76,77 +89,289 @@ class TurboMind:
Args:
model_path (str): the path of turbomind's model
eos_id (int): eos token id
model_source (ModelSource): the source of the model
model_name (str): needed when model_path is a hf model and not
managed by lmdeploy
model_format (str): needed when model_path is a hf model and not
managed by lmdeploy
group_size (int): needed when model_path is a hf model and not
managed by lmdeploy
tp (int): tensor parallel
"""
def __init__(self,
model_path: str,
eos_id: int = 2,
tp: int = 1,
model_source: ModelSource = ModelSource.WORKSPACE,
model_name: Optional[str] = None,
model_format: Optional[str] = None,
group_size: Optional[int] = None,
tp: Optional[int] = None,
**kwargs):
self.eos_id = eos_id
# TODO: support mpi
node_id = 0
node_num = 1
# read meta from model path
assert ((tp & (tp - 1) == 0) and tp != 0), 'tp should be 2^n'
self.gpu_count = tp
data_type = 'fp16'
ini_path = osp.join(model_path, 'triton_models/weights/config.ini')
with open(ini_path, 'r') as f:
parser = ConfigParser()
parser.read_file(f)
section_name = ''
if 'turbomind' in parser:
section_name = 'turbomind'
elif 'llama' in parser:
section_name = 'llama'
if len(section_name) > 0:
tp_cfg = parser.getint(section_name, 'tensor_para_size')
if tp_cfg != 1 and tp_cfg != tp:
get_logger('turbomind').info(
f'found tp={tp_cfg} in config.ini.')
self.gpu_count = tp_cfg
self.model_name = parser.get(section_name, 'model_name')
data_type = parser.get(section_name, 'weight_type')
if tp is not None:
assert ((tp & (tp - 1) == 0) and tp != 0), 'tp should be 2^n'
self.gpu_count = tp if tp is not None else 1
if model_source == ModelSource.WORKSPACE:
tokenizer_model_path = osp.join(model_path, 'triton_models',
'tokenizer')
self.tokenizer = Tokenizer(tokenizer_model_path)
self.model_comm = self._from_workspace(model_path)
else:
self.tokenizer = Tokenizer(model_path)
self.model_comm = self._from_hf(model_source=model_source,
model_path=model_path,
model_name=model_name,
model_format=model_format,
group_size=group_size,
tp=tp,
**kwargs)
self.eos_id = self.tokenizer.eos_token_id
self.model: BaseModel = MODELS.get(self.model_name)(**kwargs)
self.session_len = self.model.session_len
tokenizer_model_path = osp.join(model_path, 'triton_models',
'tokenizer')
tokenizer = Tokenizer(tokenizer_model_path)
self.stop_words = _stop_words(self.model.stop_words, tokenizer)
self.stop_words = _stop_words(self.model.stop_words, self.tokenizer)
# params
self.node_id = node_id
self.node_num = node_num
self.world_size = self.node_num * self.gpu_count
def _create_weight(self, model_comm):
"""Allocate weight buffer, load params if from_workspace."""
# create model
weight_dir = osp.join(model_path, 'triton_models', 'weights')
model_comm = _tm.AbstractTransformerModel.create_llama_model(
weight_dir, tensor_para_size=self.gpu_count, data_type=data_type)
self.model_comm = model_comm
# TODO: support mpi
self.node_id = 0
self.node_num = 1
self.nccl_params = model_comm.create_nccl_params(self.node_id)
torch.cuda.synchronize()
# create weight
def _create_weight(device_id):
def _create_weight_func(device_id):
with cuda_ctx(device_id):
rank = self.node_id * self.gpu_count + device_id
model_comm.create_shared_weights(device_id, rank)
threads = []
for device_id in range(self.gpu_count):
t = Thread(target=_create_weight, args=(device_id, ))
t = Thread(target=_create_weight_func, args=(device_id, ))
t.start()
threads.append(t)
for t in threads:
t.join()
def _load_kv_qparams(self, model_path, tm_params, **kwargs):
"""Load kv qparams when loading from hf."""
if self.config.quant_policy:
logger.warning('loading kv_cache quant scale')
from lmdeploy.lite.apis.kv_qparams import main as kv_loader
kv_sym = kwargs.get('kv_sym', False)
kv_bits = kwargs.get('kv_bits', 8)
tp = self.config.tensor_para_size
kv_loader(model_path, model_path, kv_bits, kv_sym, tp, tm_params)
else:
for key in list(tm_params.keys()):
if 'past_kv_scale' in key:
tm_params.pop(key)
def _get_model_params(self, model_comm, tm_params):
"""Get turbomind model params when loading from hf."""
def _get_params(device_id, que):
with cuda_ctx(device_id):
rank = self.node_id * self.gpu_count + device_id
out = model_comm.get_params(device_id, rank)
que.put(out)
que = Queue()
threads = []
for device_id in range(self.gpu_count):
t = Thread(target=_get_params, args=(device_id, que))
t.start()
threads.append(t)
for t in threads:
t.join()
for _ in range(self.gpu_count):
tensor_map = que.get()
for k, v in tensor_map.items():
if k not in tm_params:
tm_params[k] = []
tm_params[k].append(v)
def _from_hf(self,
model_source: ModelSource,
model_path: str,
model_name: Optional[str] = None,
model_format: Optional[str] = None,
group_size: Optional[int] = None,
tp: Optional[int] = None,
**kwargs):
"""Load model which is in hf format."""
# get model_name, group_size if is lmdeploy managed.
if model_source == ModelSource.HF_LMDEPLOY:
config = get_hf_config_content(model_path, local_files_only=True)
tm_config = config['turbomind']
tm_config.update(kwargs)
var_should_be_none = dict(model_name=model_name,
model_format=model_format,
group_size=group_size)
for key, value in var_should_be_none.items():
assert value is None, f'{key} should be None when model is '\
f'from {model_source}'
model_name = tm_config['model_name']
group_size = tm_config['group_size']
if tm_config['weight_type'] == 'int4':
model_format = 'awq'
else:
assert model_name is not None, 'please supply model_name when ' \
f'model is from {model_source}'
if osp.exists(osp.join(model_path, 'outputs_stats.pth')):
model_format = 'awq' if model_format is None else model_format
group_size = 128 if group_size is None else group_size
tm_config = kwargs
assert model_name in MODELS.module_dict.keys(), \
f"'{model_name}' is not supported. " \
f'The supported models are: {MODELS.module_dict.keys()}'
assert model_format in supported_formats, 'the model format ' \
f'should be in {supported_formats}'
data_type = 'fp16'
output_format = 'fp16'
inferred_model_format = get_model_format(model_name, model_format)
cfg = TurbomindModelConfig.from_dict(tm_config, allow_none=True)
# overwrite with input params
cfg.model_name = model_name
cfg.tensor_para_size = 1 if tp is None else tp
cfg.rotary_embedding = cfg.size_per_head
cfg.group_size = group_size
if inferred_model_format.find('awq') != -1:
cfg.weight_type = 'int4'
output_format = 'w4'
data_type = 'int4'
assert group_size > 0, f'group_size: {group_size} should be > 0'
self.config = cfg
self.model_name = model_name
self.data_type = data_type
input_model = INPUT_MODELS.get(inferred_model_format)(
model_path=model_path, tokenizer_path=model_path, ckpt_path=None)
output_model = OUTPUT_MODELS.get(output_format)(
input_model=input_model, cfg=cfg, to_file=False, out_dir='')
config = copy.deepcopy(output_model.cfg.__dict__)
logger.warning(f'model_config:\n{json.dumps(config, indent=2)}')
parser = ConfigParser()
parser['llama'] = config
with io.StringIO() as ss:
parser.write(ss)
ss.seek(0)
config = ss.read()
model_comm = _tm.AbstractTransformerModel.create_llama_model(
model_dir='',
config=config,
tensor_para_size=self.gpu_count,
data_type=data_type)
# create empty weight
self._create_weight(model_comm)
# copy hf model weight to turbomind weight
tm_params = output_model.tm_params
self._get_model_params(model_comm, tm_params)
logger.warning(f'get {len(tm_params)} model params')
output_model.export()
# load kv qparams
self._load_kv_qparams(model_path, tm_params, **kwargs)
assert len(tm_params) == 0, f'missing {tm_params.keys()}'
return model_comm
def _from_workspace(self, model_path: str):
"""Load model which is converted by `lmdeploy convert`"""
ini_path = osp.join(model_path, 'triton_models', 'weights',
'config.ini')
with open(ini_path, 'r') as f:
parser = ConfigParser()
parser.read_file(f)
section_name = 'llama'
tp_cfg = parser.getint(section_name, 'tensor_para_size')
if tp_cfg != 1 and tp_cfg != self.gpu_count:
get_logger('turbomind').info(
f'found tp={tp_cfg} in config.ini.')
self.gpu_count = tp_cfg
self.model_name = parser.get(section_name, 'model_name')
self.data_type = parser.get(section_name, 'weight_type')
cfg = parser._sections[section_name]
cfg = TurbomindModelConfig.from_dict(cfg)
self.config = cfg
# create model
weight_dir = osp.join(model_path, 'triton_models', 'weights')
model_comm = _tm.AbstractTransformerModel.create_llama_model(
weight_dir,
tensor_para_size=self.gpu_count,
data_type=self.data_type)
# create weight and load params
self._create_weight(model_comm)
return model_comm
@classmethod
def from_pretrained(cls,
pretrained_model_name_or_path: str,
model_name: Optional[str] = None,
model_format: Optional[str] = None,
group_size: Optional[int] = None,
tp: Optional[int] = None,
**kwargs):
"""LMDeploy's turbomind inference engine.
Args:
pretrained_model_name_or_path (str):
It could be one of the following options:
- i) A local directory path of a turbomind model which is
converted by the `lmdeploy convert` command or downloaded from
ii) and iii)
- ii) The model_id of a lmdeploy-quantized model hosted
inside a model repo on huggingface.co, such as
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "InternLM/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
model_name (str): needed when pretrained_model_name_or_path is case iii)
model_format (str): model format
group_size (int): group size
tp (int): tensor parallel size
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update configuration when initialize the engine.
"""
model_source = get_model_source(pretrained_model_name_or_path)
if model_source == ModelSource.WORKSPACE:
local_path = pretrained_model_name_or_path
else:
check_tm_model_input(pretrained_model_name_or_path,
model_name=model_name,
**kwargs)
if not osp.exists(pretrained_model_name_or_path):
download_kwargs = create_hf_download_args(**kwargs)
local_path = snapshot_download(pretrained_model_name_or_path,
**download_kwargs)
else:
local_path = pretrained_model_name_or_path
logger.warning(f'model_source: {model_source}')
return cls(model_source=model_source,
model_path=local_path,
model_name=model_name,
model_format=model_format,
group_size=group_size,
tp=tp,
**kwargs)
def create_instance(self, cuda_stream_id=0):
"""Create a turbomind instance.
......@@ -336,6 +561,7 @@ class TurboMindInstance:
tm_inputs = _np_dict_to_tm_dict(inputs)
# start forward thread
self.que = Queue()
self._forward_thread(tm_inputs)
seq_start = input_lengths + input_lengths.new_tensor(step)
......