Unverified commit 73386e21 authored by Lyu Han, committed by GitHub

Check-in user guide about turbomind config (#680)

* update

* update config guide

* update guide

* update user guide according to review comments
parent 911c0a85
...@@ -20,6 +20,7 @@ ______________________________________________________________________
## News 🎉
- \[2023/11\] TurboMind major upgrades, including: Paged Attention, faster attention kernels without sequence length limitation, 2x faster KV8 kernels, Split-K decoding (Flash Decoding), and W4A16 inference for sm_75
- \[2023/09\] TurboMind supports Qwen-14B
- \[2023/09\] TurboMind supports InternLM-20B
- \[2023/09\] TurboMind supports all features of Code Llama: code completion, infilling, chat / instruct, and python specialist. Click [here](./docs/en/supported_models/codellama.md) for deployment guide
...
...@@ -20,6 +20,7 @@ ______________________________________________________________________
## News 🎉
- \[2023/11\] TurboMind major upgrade, including: Paged Attention, faster attention kernels without sequence length limitation, more than 2x faster KV8 kernels, Split-K decoding (Flash Decoding), and W4A16 inference supporting the sm_75 architecture
- \[2023/09\] TurboMind supports Qwen-14B
- \[2023/09\] TurboMind supports the InternLM-20B model
- \[2023/09\] TurboMind supports all features of Code Llama: code completion, infilling, chat / instruct, and Python specialist. Click [here](./docs/zh_cn/supported_models/codellama.md) for the deployment guide
...
# TurboMind Config
TurboMind is one of LMDeploy's inference engines. Before using it for model inference, you need to convert the input model into a TurboMind model. Besides the weight files, the TurboMind model folder contains several other files, the most important of which is the configuration file `triton_models/weights/config.ini`, since it is closely tied to inference performance.
If you are using an LMDeploy 0.0.x release, please refer to the [turbomind 1.0 config](#turbomind-10-config) section to learn about the configuration. Otherwise, please read [turbomind 2.0 config](#turbomind-20-config) to familiarize yourself with the configuration details.
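Since `config.ini` is a plain INI file, you can inspect it with Python's standard `configparser` before deciding what to tune. Below is a minimal sketch; the `./workspace` path is an assumption and should point at your converted model folder:

```python
from configparser import ConfigParser

# Hypothetical path: replace with the folder produced by model conversion.
config_path = './workspace/triton_models/weights/config.ini'

parser = ConfigParser()
parser.read(config_path)

# TurboMind keeps all parameters under the [llama] section.
for key, value in parser['llama'].items():
    print(f'{key} = {value}')
```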
## TurboMind 2.0 config
Take the `llama-2-7b-chat` model as an example. In TurboMind 2.0, its `config.ini` content is as follows:
```toml
[llama]
model_name = llama2
tensor_para_size = 1
head_num = 32
kv_head_num = 32
vocab_size = 32000
num_layer = 32
inter_size = 11008
norm_eps = 1e-06
attn_bias = 0
start_id = 1
end_id = 2
session_len = 4104
weight_type = fp16
rotary_embedding = 128
rope_theta = 10000.0
size_per_head = 128
group_size = 0
max_batch_size = 64
max_context_token_num = 4
step_length = 1
cache_max_entry_count = 0.5
cache_block_seq_len = 128
cache_chunk_size = 1
use_context_fmha = 1
quant_policy = 0
max_position_embeddings = 2048
rope_scaling_factor = 0.0
use_logn_attn = 0
```
These parameters consist of model attributes and inference parameters. Model attributes, including the number of layers, the number of heads, dimensions, etc., are **not modifiable**.
```toml
model_name = llama2
head_num = 32
kv_head_num = 32
vocab_size = 32000
num_layer = 32
inter_size = 11008
norm_eps = 1e-06
attn_bias = 0
start_id = 1
end_id = 2
rotary_embedding = 128
rope_theta = 10000.0
size_per_head = 128
```
Compared with TurboMind 1.0, the model attributes in the config remain the same, while the inference parameters have changed.
In the following sections, we will focus on introducing the inference parameters.
### data type
The relevant parameters are `weight_type` and `group_size`. **They cannot be modified.**
`weight_type` denotes the data type of the weights; currently, `fp16` and `int4` are supported, where `int4` means 4-bit weights. When `weight_type` is `int4`, `group_size` is the group size used when quantizing weights with `awq`. The LMDeploy prebuilt packages include kernels with `group_size = 128`.
### batch size
The maximum batch size is still set through `max_batch_size`. However, its default value has been changed from 32 to 64, and `max_batch_size` is no longer coupled to `cache_max_entry_count`.
### k/v cache size
k/v cache memory is determined by `cache_block_seq_len` and `cache_max_entry_count`.
TurboMind 2.0 has implemented Paged Attention, managing the k/v cache in blocks.
`cache_block_seq_len` denotes the number of tokens a k/v block can hold, with a default value of 128. TurboMind calculates the memory size of a k/v block according to the following formula:
```
cache_block_seq_len * num_layer * kv_head_num * size_per_head * 2 * sizeof(kv_data_type)
```
For the llama2-7b model, when k/v is stored as the `half` type, the memory of one k/v block is: `128 * 32 * 32 * 128 * 2 * sizeof(half) = 64MB`.
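As a quick check, this arithmetic is easy to reproduce in a few lines of Python, using the model attributes from the config above:

```python
# k/v block bytes = cache_block_seq_len * num_layer * kv_head_num
#                   * size_per_head * 2 (one k and one v) * sizeof(half)
cache_block_seq_len = 128
num_layer = 32
kv_head_num = 32
size_per_head = 128
sizeof_half = 2  # bytes

block_bytes = (cache_block_seq_len * num_layer * kv_head_num
               * size_per_head * 2 * sizeof_half)
print(block_bytes / 2**20, 'MB')  # -> 64.0 MB
```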
The meaning of `cache_max_entry_count` varies depending on its value:
- When it's a decimal between (0, 1), `cache_max_entry_count` represents the percentage of GPU memory used by k/v blocks. For example, if turbomind launches on an A100-80G GPU with `cache_max_entry_count` being `0.5`, the total memory used by the k/v blocks is `80 * 0.5 = 40G`.
- When it's an integer > 0, it represents the total number of k/v blocks
`cache_chunk_size` indicates how many k/v cache blocks to allocate each time new blocks are needed; a short sketch after the list below summarizes both rules. Its value is interpreted as follows:
- When it is an integer > 0, `cache_chunk_size` number of k/v cache blocks are allocated.
- When the value is -1, `cache_max_entry_count` number of k/v cache blocks are allocated.
- When the value is 0, `sqrt(cache_max_entry_count)` number of k/v cache blocks are allocated.
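These two parameters can be summarized in a small helper. This is only an illustrative sketch of the documented rules, not TurboMind's actual implementation:

```python
import math

def resolve_block_count(cache_max_entry_count: float,
                        gpu_mem_bytes: int, block_bytes: int) -> int:
    """Turn cache_max_entry_count into an absolute number of k/v blocks."""
    if 0 < cache_max_entry_count < 1:
        # A fraction: percentage of GPU memory dedicated to k/v blocks.
        return int(cache_max_entry_count * gpu_mem_bytes // block_bytes)
    # An integer: already the total number of k/v blocks.
    return int(cache_max_entry_count)

def resolve_chunk_size(cache_chunk_size: int,
                       cache_max_entry_count: float) -> float:
    """Blocks allocated per chunk, mirroring the three documented cases."""
    if cache_chunk_size > 0:
        return cache_chunk_size
    if cache_chunk_size == -1:
        return cache_max_entry_count
    return math.sqrt(cache_max_entry_count)  # cache_chunk_size == 0

# llama2-7b (64 MB blocks) on an A100-80G with cache_max_entry_count = 0.5:
print(resolve_block_count(0.5, 80 * 2**30, 64 * 2**20))  # -> 640 blocks ≈ 40G
```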
### kv int8 switch
To enable 8-bit k/v inference, set `quant_policy = 4`. Please refer to [kv int8](./kv_int8.md) for a guide.
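For example, the switch can be flipped in place with `configparser`; as before, the `./workspace` path is an assumption:

```python
from configparser import ConfigParser

config_path = './workspace/triton_models/weights/config.ini'  # hypothetical path

parser = ConfigParser()
parser.read(config_path)
parser.set('llama', 'quant_policy', '4')  # turn on 8-bit k/v inference

with open(config_path, 'w') as f:
    parser.write(f)
```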
### long context switch
By setting `rope_scaling_factor = 1.0`, you can enable the Dynamic NTK option of RoPE, which allows the model to handle long-text input and output.
Regarding the principle of Dynamic NTK, please refer to:
1. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases
2. https://kexue.fm/archives/9675
You can also turn on [LogN attention scaling](https://kexue.fm/archives/8823) by setting `use_logn_attn = 1`.
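Continuing the `configparser` sketch from the kv int8 section, both long-context switches can be set in one pass (still a sketch under the same path assumption):

```python
# Enable Dynamic NTK, and optionally LogN attention scaling.
parser.set('llama', 'rope_scaling_factor', '1.0')
parser.set('llama', 'use_logn_attn', '1')

with open(config_path, 'w') as f:
    parser.write(f)
```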
## TurboMind 1.0 config
Taking the `llama-2-7b-chat` model as an example, in TurboMind 1.0, its `config.ini` content is as follows:
```toml
[llama]
model_name = llama2
tensor_para_size = 1
head_num = 32
kv_head_num = 32
vocab_size = 32000
num_layer = 32
inter_size = 11008
norm_eps = 1e-06
attn_bias = 0
start_id = 1
end_id = 2
session_len = 4104
weight_type = fp16
rotary_embedding = 128
rope_theta = 10000.0
size_per_head = 128
group_size = 0
max_batch_size = 32
max_context_token_num = 4
step_length = 1
cache_max_entry_count = 48
cache_chunk_size = 1
use_context_fmha = 1
quant_policy = 0
max_position_embeddings = 2048
use_dynamic_ntk = 0
use_logn_attn = 0
```
These parameters consist of model attributes and inference parameters. Model attributes, including the number of layers, the number of heads, dimensions, etc., are **not modifiable**.
```toml
model_name = llama2
head_num = 32
kv_head_num = 32
vocab_size = 32000
num_layer = 32
inter_size = 11008
norm_eps = 1e-06
attn_bias = 0
start_id = 1
end_id = 2
rotary_embedding = 128
rope_theta = 10000.0
size_per_head = 128
```
In the following sections, we will focus on introducing the inference parameters.
### data type
The relevant parameters are `weight_type` and `group_size`. **They cannot be modified.**
`weight_type` denotes the data type of the weights; currently, `fp16` and `int4` are supported, where `int4` means 4-bit weights. When `weight_type` is `int4`, `group_size` is the group size used when quantizing weights with `awq`. The LMDeploy prebuilt packages include kernels with `group_size = 128`.
### batch size
`max_batch_size` determines the maximum batch size during inference. In general, the larger the batch size, the higher the throughput. But make sure that `max_batch_size <= cache_max_entry_count`.
### k/v cache size
TurboMind allocates k/v cache memory based on `session_len`, `cache_chunk_size`, and `cache_max_entry_count`; a rough footprint estimate follows the list below.
- `session_len` denotes the maximum length of a sequence, i.e., the size of the context window.
- `cache_chunk_size` indicates how many sequences' worth of k/v cache to allocate each time new sequences are added.
- `cache_max_entry_count` signifies the maximum number of k/v sequences that can be cached.
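As a back-of-the-envelope estimate, assuming each cached sequence reserves k/v for the full `session_len` and k/v is stored as `half` (both assumptions for illustration, not guarantees from the source), the per-sequence footprint for llama2-7b is:

```python
# Estimated per-sequence k/v cache bytes:
# session_len * num_layer * kv_head_num * size_per_head * 2 (k and v) * sizeof(half)
session_len = 4104
num_layer = 32
kv_head_num = 32
size_per_head = 128
sizeof_half = 2  # bytes

per_seq_bytes = (session_len * num_layer * kv_head_num
                 * size_per_head * 2 * sizeof_half)
print(per_seq_bytes / 2**30, 'GB')  # -> roughly 2.0 GB per cached sequence
```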
### kv int8 switch
To enable 8-bit k/v inference, set `quant_policy = 4` and `use_context_fmha = 0`. Please refer to [kv int8](./kv_int8.md) for a guide.
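The same `configparser` pattern shown earlier applies; the path remains an assumption:

```python
# Reusing the parser/config_path from the earlier sketch:
parser.set('llama', 'quant_policy', '4')
parser.set('llama', 'use_context_fmha', '0')

with open(config_path, 'w') as f:
    parser.write(f)
```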
### long context switch
By setting `use_dynamic_ntk = 1`, you can enable the Dynamic NTK option of RoPE, which allows the model to handle long-text input and output.
Regarding the principle of Dynamic NTK, please refer to:
1. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases
2. https://kexue.fm/archives/9675
You can also turn on [LogN attention scaling](https://kexue.fm/archives/8823) by setting `use_logn_attn = 1`.
# TurboMind Config
TurboMind is one of LMDeploy's inference engines. Before using it to run inference with an LLM, you need to convert the input model into a TurboMind model. Besides the model weights, the TurboMind model folder contains several other files, the most important of which is the configuration file `triton_models/weights/config.ini`, since it is closely tied to inference performance.
If you are using an LMDeploy 0.0.x release, please refer to the [turbomind 1.0 config](#turbomind-10-配置) section to learn about the configuration. If you are using LMDeploy 0.1.x, please read [turbomind 2.0 config](#turbomind-20-配置) for the configuration details.
## TurboMind 2.0 config
Take the `llama-2-7b-chat` model as an example. In TurboMind 2.0, its `config.ini` content is as follows:
```toml
[llama]
model_name = llama2
tensor_para_size = 1
head_num = 32
kv_head_num = 32
vocab_size = 32000
num_layer = 32
inter_size = 11008
norm_eps = 1e-06
attn_bias = 0
start_id = 1
end_id = 2
session_len = 4104
weight_type = fp16
rotary_embedding = 128
rope_theta = 10000.0
size_per_head = 128
group_size = 0
max_batch_size = 64
max_context_token_num = 4
step_length = 1
cache_max_entry_count = 0.5
cache_block_seq_len = 128
cache_chunk_size = 1
use_context_fmha = 1
quant_policy = 0
max_position_embeddings = 2048
rope_scaling_factor = 0.0
use_logn_attn = 0
```
These parameters consist of model attributes and inference parameters. Model attributes, including the number of layers, the number of heads, dimensions, etc., are **not modifiable**.
```toml
model_name = llama2
head_num = 32
kv_head_num = 32
vocab_size = 32000
num_layer = 32
inter_size = 11008
norm_eps = 1e-06
attn_bias = 0
start_id = 1
end_id = 2
rotary_embedding = 128
rope_theta = 10000.0
size_per_head = 128
```
Compared with the TurboMind 1.0 config, the model attributes in the 2.0 config stay the same, while the inference parameters have changed.
The following sections focus on the inference parameters.
### data type
The parameters related to data type are `weight_type` and `group_size`. **They cannot be modified.**
`weight_type` denotes the data type of the weights; currently, `fp16` and `int4` are supported, where `int4` means 4-bit weights. When `weight_type` is `int4`, `group_size` is the group size used when quantizing weights with `awq`. At present, the LMDeploy prebuilt packages use `group_size = 128`.
### batch size
The maximum batch size is still set through `max_batch_size`. Its default value has been changed from 32 to 64.
In TurboMind 2.0, `max_batch_size` is no longer related to `cache_max_entry_count`.
### k/v cache size
`cache_block_seq_len` and `cache_max_entry_count` are used to tune the memory size of the k/v cache.
TurboMind 2.0 implements Paged Attention and manages the k/v cache in blocks.
`cache_block_seq_len` denotes the token sequence length a k/v block can hold; the default is 128. TurboMind calculates the memory size of a k/v block according to the following formula:
```
cache_block_seq_len * num_layer * kv_head_num * size_per_head * 2 * sizeof(kv_data_type)
```
For the llama2-7b model, when k/v is stored as the `half` type, the memory of one k/v block is: `128 * 32 * 32 * 128 * 2 * sizeof(half) = 64MB`.
The meaning of `cache_max_entry_count` depends on its value:
- When it is a decimal in (0, 1), `cache_max_entry_count` denotes the percentage of GPU memory used by k/v blocks. For example, an A100-80G GPU has 80G of memory; when `cache_max_entry_count` is 0.5, the total memory used by k/v blocks is `80 * 0.5 = 40G`.
- When it is an integer > 1, it denotes the total number of k/v blocks
`cache_chunk_size` indicates how many k/v cache blocks to allocate each time new blocks are needed. Its value is interpreted as follows:
- When it is an integer > 0, `cache_chunk_size` k/v cache blocks are allocated
- When it is -1, `cache_max_entry_count` k/v cache blocks are allocated
- When it is 0, `sqrt(cache_max_entry_count)` k/v cache blocks are allocated
### kv int8 switch
`quant_policy` is the switch for KV-int8 inference. For usage details, please refer to the [kv int8](./kv_int8.md) deployment guide.
### long context switch
The default `rope_scaling_factor = 0` provides no extrapolation capability. Setting it to 1.0 enables the Dynamic NTK feature of RoPE, which supports long-text inference.
Regarding the principle of Dynamic NTK, please refer to:
1. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases
2. https://kexue.fm/archives/9675
Setting `use_logn_attn = 1` enables [LogN attention scaling](https://kexue.fm/archives/8823).
## TurboMind 1.0 config
Take the `llama-2-7b-chat` model as an example. In TurboMind 1.0, its `config.ini` content is as follows:
```toml
[llama]
model_name = llama2
tensor_para_size = 1
head_num = 32
kv_head_num = 32
vocab_size = 32000
num_layer = 32
inter_size = 11008
norm_eps = 1e-06
attn_bias = 0
start_id = 1
end_id = 2
session_len = 4104
weight_type = fp16
rotary_embedding = 128
rope_theta = 10000.0
size_per_head = 128
group_size = 0
max_batch_size = 32
max_context_token_num = 4
step_length = 1
cache_max_entry_count = 48
cache_chunk_size = 1
use_context_fmha = 1
quant_policy = 0
max_position_embeddings = 2048
use_dynamic_ntk = 0
use_logn_attn = 0
```
These parameters consist of model attributes and inference parameters. Model attributes, including the number of layers, the number of heads, dimensions, etc., are **not modifiable**.
```toml
model_name = llama2
head_num = 32
kv_head_num = 32
vocab_size = 32000
num_layer = 32
inter_size = 11008
norm_eps = 1e-06
attn_bias = 0
start_id = 1
end_id = 2
rotary_embedding = 128
rope_theta = 10000.0
size_per_head = 128
```
The following sections focus on the inference parameters.
### data type
The parameters related to data type are `weight_type` and `group_size`. **They cannot be modified.**
`weight_type` denotes the data type of the weights; currently, `fp16` and `int4` are supported, where `int4` means 4-bit weights. When `weight_type` is `int4`, `group_size` is the group size used when quantizing weights with `awq`. At present, the LMDeploy prebuilt packages use `group_size = 128`.
### batch size
`max_batch_size` sets the maximum batch size during inference. In general, the larger the batch, the higher the throughput. But make sure that `max_batch_size <= cache_max_entry_count`.
### k/v cache size
TurboMind allocates k/v cache memory based on `session_len`, `cache_chunk_size`, and `cache_max_entry_count`.
- `session_len` denotes the maximum length of a sequence, i.e., the size of the context window.
- `cache_chunk_size` denotes how many sequences' worth of k/v cache to allocate each time new sequences are added.
- `cache_max_entry_count` denotes the maximum number of sequences whose k/v can be cached.
### kv int8 switch
To enable 8-bit k/v inference, the parameters `quant_policy` and `use_context_fmha` need to be modified. For details, please refer to the [kv int8](./kv_int8.md) deployment guide.
### long context switch
Setting `use_dynamic_ntk = 1` enables the Dynamic NTK option of RoPE, which supports long-text inference.
Regarding the principle of Dynamic NTK, please refer to:
1. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases
2. https://kexue.fm/archives/9675
Setting `use_logn_attn = 1` enables [LogN attention scaling](https://kexue.fm/archives/8823).
...@@ -44,15 +44,16 @@ class TurbomindModelConfig:
     rope_theta: float = 10000.0
     size_per_head: int = 128
     group_size: int = 0
-    max_batch_size: int = 32
+    max_batch_size: int = 64
     max_context_token_num: int = 4
     step_length: int = 1
-    cache_max_entry_count: int = 48
+    cache_max_entry_count: float = 0.5
+    cache_block_seq_len: int = 128
     cache_chunk_size: int = 1
     use_context_fmha: int = 1
     quant_policy: int = 0
     max_position_embeddings: int = 0
-    use_dynamic_ntk: int = 0
+    rope_scaling_factor: float = 0.0
     use_logn_attn: int = 0

     @classmethod
...