Unverified Commit 902a3e16 authored by tpoisonooo, committed by GitHub

feat(quantization): kv cache use asymmetric (#218)

* feat(quantization): kv cache use asymmetric
parent c3290cad
......@@ -2,7 +2,7 @@
## Benchmark the Graphics Memory Usage
We take the [Chinese-LLaMa-Alpaca 7B](https://github.com/ymcui/Chinese-LLaMA-Alpaca) instruction model as the benchmark target. The benchmark process is as follows:
We take the [internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b) instruction model as the benchmark target. The benchmark process is as follows:
1. Use `deploy.py` to convert the model, modify the maximum concurrency in the `workspace` config, and adjust the number of requests in `llama_config.ini`
2. Compile `bin/llama_triton_example` and record the graphics memory usage of the fp16 model under different batch_size settings
......@@ -28,25 +28,30 @@ Note that `kCacheKVInt8` and `WeightInt4` can be used simultaneously, and we wil
## Benchmark the Accuracy
Here we take the [Chinese-LLaMa-Alpaca 7B](https://github.com/ymcui/Chinese-LLaMA-Alpaca) instruction model as the benchmark target again. The benchmark process is as follows:
The quantization method is PTQ (post-training quantization), and the related formulas are as follows:
```
zp = (min + max) / 2
scale = (max - min) / 255
quant: q = round( (f - zp) / scale )
dequant: f = q * scale + zp
```
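As a quick illustration of these formulas, here is a minimal numpy sketch of the quantize/dequantize round trip. The helper names are ours, not the project's API; like the kernels changed in this commit, it clamps to [-128, 127] before storing int8.
```python
import numpy as np

def quantize_kv(f, f_min, f_max, bits=8):
    """Asymmetric PTQ: the zero-point is the mid-point of the observed range."""
    zp = (f_min + f_max) / 2
    scale = (f_max - f_min) / (2**bits - 1)
    q = np.clip(np.round((f - zp) / scale), -128, 127).astype(np.int8)
    return q, scale, zp

def dequantize_kv(q, scale, zp):
    return q.astype(np.float32) * scale + zp

f = np.random.randn(1024).astype(np.float32)   # stand-in for k/v cache activations
q, scale, zp = quantize_kv(f, f.min(), f.max())
f_hat = dequantize_kv(q, scale, zp)
print('max reconstruction error:', np.abs(f - f_hat).max(), '(about scale/2 =', scale / 2, ')')
```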
Here we take the [internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b) instruction model as the benchmark target again. The benchmark process is as follows:
1. Convert the model with `deploy.py` and run the docker service
2. Test the accuracy of the fp16 version on the dataset with `client.py`
3. Execute the quantization script to get the quantization parameters and put them into the `weights` directory. Then modify the configuration file to enable the [kCacheKVInt8](../../src/turbomind/models/llama/llama_utils.h) option (a sketch of this step follows the list)
4. Run `client.py` again to get the accuracy of the int8 version
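For step 3, a hedged sketch of switching on the int8 KV cache in the generated config. The config path and key name here are assumptions; in the kernel changes below, `kCacheKVInt8` corresponds to quant policy value 4.
```python
import configparser

# Assumption: the config generated by deploy.py exposes a `quant_policy` key and
# value 4 selects the int8 KV cache path (see `QUANT_POLICY == 4` in the kernel
# changes below). The path is hypothetical.
cfg_path = 'workspace/triton_models/weights/config.ini'
cfg = configparser.ConfigParser()
cfg.read(cfg_path)
for section in cfg.sections():
    if 'quant_policy' in cfg[section]:
        cfg[section]['quant_policy'] = '4'   # kCacheKVInt8
with open(cfg_path, 'w') as f:
    cfg.write(f)
```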
The following table shows the precision of the `kCacheKVInt8` method, calibrated with 128 samples randomly selected from the c4 dataset and tested on the mmlu-social-science dataset, which has a total of 3065 multiple-choice questions:
| task | dataset | metric | fp16 | int8 | diff |
| :--: | :-----------------: | :----: | :---: | :---: | :---: |
| Exam | mmlu-social-science | score | 31.81 | 32.00 | +0.19 |
We noticed that there is a slight improvement in the precision, and the differences are as follows:
| Type | Number |
| :--------------------------------------------: | :----: |
| fp16 version wrong, int8 version right | 72 |
| fp16 version right, int8 version wrong | 66 |
| both versions wrong, with different answers | 118 |
We have validated the quantization implementation on more datasets and larger models and will keep updating the results.
The following table shows the precision of the `kCacheKVInt8` method, calibrated with 128 samples randomly selected from the c4 dataset and evaluated with [opencompass](https://github.com/InternLM/opencompass).
| task | dataset | metric | int8 | fp16 | diff |
| :-----------: | :-------------: | :-----------: | :---: | :---: | :---: |
| Language | winogrande | accuracy | 60.77 | 61.48 | -0.71 |
| Knowledge | nq | score | 2.69 | 2.60 | +0.09 |
| Reasoning | gsm8k | accuracy | 33.28 | 34.72 | -1.44 |
| Reasoning | bbh | naive_average | 20.12 | 20.51 | -0.39 |
| Understanding | openbookqa_fact | accuracy | 82.40 | 82.20 | +0.20 |
| Understanding | eprstmt-dev | accuracy | 90.62 | 88.75 | +1.87 |
| Safety | crows_pairs | accuracy | 32.56 | 31.43 | +1.13 |
......@@ -2,7 +2,7 @@
## Benchmark the Graphics Memory Usage
The benchmark target is the [Chinese-LLaMa-Alpaca 7B](https://github.com/ymcui/Chinese-LLaMA-Alpaca) instruction model.
The benchmark target is the [internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b) model.
The benchmark process is as follows:
1. Convert the model with `deploy.py`, modify the maximum concurrency in the `workspace` config, and adjust the number of requests in `llama_config.ini`
......@@ -29,7 +29,16 @@
## Benchmark the Accuracy
The benchmark target is the [Chinese-LLaMa-Alpaca 7B](https://github.com/ymcui/Chinese-LLaMA-Alpaca) instruction model.
The quantization method is PTQ (post-training quantization), and the related formulas are as follows:
```
zp = (min + max) / 2
scale = (max - min) / 255
quant: q = round( (f - zp) / scale )
dequant: f = q * scale + zp
```
The benchmark target is the [internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b) instruction model.
The benchmark process is as follows:
1. Convert the model with `deploy.py` and run the docker service
......@@ -37,18 +46,14 @@
3. Execute the quantization script to get the quantization parameters and put them into the `weights` directory; then modify the configuration file to enable the [kCacheKVInt8](../../src/turbomind/models/llama/llama_utils.h) option
4. Run `client.py` again to get the accuracy of the int8 version
The following shows the accuracy loss of the `kCacheKVInt8` method on the mmlu-social-science dataset, calibrated with only 128 samples randomly selected from the c4 dataset.
| task | dataset | metric | fp16 | int8 | diff |
| :--: | :-----------------: | :----: | :---: | :---: | :---: |
| Exam | mmlu-social-science | score | 31.81 | 32.00 | +0.19 |
We noticed a slight improvement in accuracy. mmlu-social-science has 3065 multiple-choice questions in total; the differences are as follows:
| Type | Number |
| :--: | :----: |
| fp16 version wrong, int8 version right | 72 |
| fp16 version right, int8 version wrong | 66 |
| both versions wrong, with different answers | 118 |
We have validated more datasets on larger models and will keep updating the results.
The following shows the accuracy of the `kCacheKVInt8` method after PTQ calibration with only 128 samples randomly selected from the c4 dataset, evaluated with [opencompass](https://github.com/InternLM/opencompass).
| task | dataset | metric | int8 | fp16 | diff |
| :-----------: | :-------------: | :-----------: | :---: | :---: | :---: |
| Language | winogrande | accuracy | 60.77 | 61.48 | -0.71 |
| Knowledge | nq | score | 2.69 | 2.60 | +0.09 |
| Reasoning | gsm8k | accuracy | 33.28 | 34.72 | -1.44 |
| Reasoning | bbh | naive_average | 20.12 | 20.51 | -0.39 |
| Understanding | openbookqa_fact | accuracy | 82.40 | 82.20 | +0.20 |
| Understanding | eprstmt-dev | accuracy | 90.62 | 88.75 | +1.87 |
| Safety | crows_pairs | accuracy | 32.56 | 31.43 | +1.13 |
......@@ -63,7 +63,9 @@ def _export_asym(key_stats: dict,
tp_k_max = torch.chunk(k_max, tp)
tp_v_max = torch.chunk(v_max, tp)
for i in range(tp):
# quant: q = (f - zp) / scale
# zp = (min+max) / 2
# scale = (max-min) / 255
# quant: q = (f-zp) / scale
# dequant: f = q * scale + zp
k_min = tp_k_min[i].min()
v_min = tp_v_min[i].min()
......@@ -74,7 +76,10 @@ def _export_asym(key_stats: dict,
k_scale = (k_max - k_min) / (2**bits - 1)
v_scale = (v_max - v_min) / (2**bits - 1)
kv_qparams = np.array([k_scale, k_min, v_scale, v_min],
k_zp = (k_max + k_min) / 2
v_zp = (v_max + v_min) / 2
kv_qparams = np.array([k_scale, k_zp, v_scale, v_zp],
dtype=np.float32)
out_path = out_dir / f'layers.{layer_idx}.past_kv_scale.{i}.weight' # noqa: E501
kv_qparams.tofile(out_path)
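For reference, a small numpy sketch of how one of the exported files could be inspected. The workspace path is hypothetical; each file stores exactly the four float32 values written above, in the order `[k_scale, k_zp, v_scale, v_zp]`, which is what the C++ loader below reads via `loadArrayFromBin({4}, ...)`.
```python
import numpy as np

# Hypothetical workspace location; each exported file holds exactly four float32
# values in the order written above: [k_scale, k_zp, v_scale, v_zp].
path = 'workspace/triton_models/weights/layers.0.past_kv_scale.0.weight'
k_scale, k_zp, v_scale, v_zp = np.fromfile(path, dtype=np.float32)
print('k:', k_scale, k_zp, ' v:', v_scale, v_zp)
```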
......
......@@ -116,7 +116,9 @@ struct Multihead_attention_params_base {
const float* attention_out_scale = nullptr;
int int8_mode = 0;
float attention_k_scale = 0.f;
float attention_k_zp = 0.f;
float attention_v_scale = 0.f;
float attention_v_zp = 0.f;
};
template<typename T>
......
......@@ -1023,49 +1023,49 @@ inline __device__ Float8_ float_from_int8(int64_t u)
}
// clang-format on
inline __device__ int8_t quant(float a, const float scale)
inline __device__ int8_t quant(float a, const float scale, const float zp)
{
int8_t int8;
int8 = round(max(-128.f, min(127.f, a / scale)));
int8 = round(max(-128.f, min(127.f, (a - zp) / scale)));
return int8;
}
inline __device__ short quant(float2 a, const float scale)
inline __device__ short quant(float2 a, const float scale, const float zp)
{
union {
int8_t int8[2];
short int16;
};
int8[0] = round(max(-128.f, min(127.f, a.x / scale)));
int8[1] = round(max(-128.f, min(127.f, a.y / scale)));
int8[0] = round(max(-128.f, min(127.f, (a.x - zp) / scale)));
int8[1] = round(max(-128.f, min(127.f, (a.y - zp) / scale)));
return int16;
}
inline __device__ int32_t quant(float4 a, const float scale)
inline __device__ int32_t quant(float4 a, const float scale, const float zp)
{
union {
int8_t int8[4];
int32_t int32;
};
int8[0] = round(max(-128.f, min(127.f, a.x / scale)));
int8[1] = round(max(-128.f, min(127.f, a.y / scale)));
int8[2] = round(max(-128.f, min(127.f, a.z / scale)));
int8[3] = round(max(-128.f, min(127.f, a.w / scale)));
int8[0] = round(max(-128.f, min(127.f, (a.x - zp) / scale)));
int8[1] = round(max(-128.f, min(127.f, (a.y - zp) / scale)));
int8[2] = round(max(-128.f, min(127.f, (a.z - zp) / scale)));
int8[3] = round(max(-128.f, min(127.f, (a.w - zp) / scale)));
return int32;
}
// float16 to int8
inline __device__ int8_t quant(uint16_t a, const float scale)
inline __device__ int8_t quant(uint16_t a, const float scale, const float zp)
{
int8_t int8;
float b = half_to_float(a);
int8 = round(max(-128.f, min(127.f, b / scale)));
int8 = round(max(-128.f, min(127.f, (b - zp) / scale)));
return int8;
}
// float16x2 to int8x2
inline __device__ int16_t quant(uint a, const float scale)
inline __device__ int16_t quant(uint a, const float scale, const float zp)
{
union {
int8_t int8[2];
......@@ -1073,44 +1073,44 @@ inline __device__ int16_t quant(uint a, const float scale)
};
float2 b = half2_to_float2(a);
int8[0] = round(max(-128.f, min(127.f, b.x / scale)));
int8[1] = round(max(-128.f, min(127.f, b.y / scale)));
int8[0] = round(max(-128.f, min(127.f, (b.x - zp) / scale)));
int8[1] = round(max(-128.f, min(127.f, (b.y - zp) / scale)));
return int16;
}
// float16x4 to int8x4
inline __device__ int32_t quant(uint2 a, const float scale)
inline __device__ int32_t quant(uint2 a, const float scale, const float zp)
{
union {
int16_t int16[2];
int32_t int32;
};
int16[0] = quant(a.x, scale);
int16[1] = quant(a.y, scale);
int16[0] = quant(a.x, scale, zp);
int16[1] = quant(a.y, scale, zp);
return int32;
}
// float16x8 to int8x8
inline __device__ int64_t quant(uint4 a, const float scale)
inline __device__ int64_t quant(uint4 a, const float scale, const float zp)
{
union {
int16_t int16[4];
int64_t int64;
};
int16[0] = quant(a.x, scale);
int16[1] = quant(a.y, scale);
int16[2] = quant(a.z, scale);
int16[3] = quant(a.w, scale);
int16[0] = quant(a.x, scale, zp);
int16[1] = quant(a.y, scale, zp);
int16[2] = quant(a.z, scale, zp);
int16[3] = quant(a.w, scale, zp);
return int64;
}
// int8 to float32, then `vec_conversion` to target format
inline __device__ float dequant(int8_t a, const float scale)
inline __device__ float dequant(int8_t a, const float scale, const float zp)
{
float b = a * scale;
float b = a * scale + zp;
return b;
}
// int8x2 to float32x2
inline __device__ float2 dequant(int16_t a, const float scale)
inline __device__ float2 dequant(int16_t a, const float scale, const float zp)
{
union {
int8_t int8[2];
......@@ -1119,12 +1119,12 @@ inline __device__ float2 dequant(int16_t a, const float scale)
int16 = a;
float2 b;
b.x = int8[0] * scale;
b.y = int8[1] * scale;
b.x = int8[0] * scale + zp;
b.y = int8[1] * scale + zp;
return b;
}
// int8x4 to float32x4
inline __device__ float4 dequant(int32_t a, const float scale)
inline __device__ float4 dequant(int32_t a, const float scale, const float zp)
{
union {
int8_t int8[4];
......@@ -1133,14 +1133,14 @@ inline __device__ float4 dequant(int32_t a, const float scale)
int32 = a;
float4 b;
b.x = (int8[0] * scale);
b.y = (int8[1] * scale);
b.z = (int8[2] * scale);
b.w = (int8[3] * scale);
b.x = (int8[0] * scale) + zp;
b.y = (int8[1] * scale) + zp;
b.z = (int8[2] * scale) + zp;
b.w = (int8[3] * scale) + zp;
return b;
}
inline __device__ Float8_ dequant(int64_t a, const float scale)
inline __device__ Float8_ dequant(int64_t a, const float scale, const float zp)
{
union {
int16_t int16[4];
......@@ -1149,10 +1149,10 @@ inline __device__ Float8_ dequant(int64_t a, const float scale)
int64 = a;
Float8_ b;
b.x = dequant(int16[0], scale);
b.y = dequant(int16[1], scale);
b.z = dequant(int16[2], scale);
b.w = dequant(int16[3], scale);
b.x = dequant(int16[0], scale, zp);
b.y = dequant(int16[1], scale, zp);
b.z = dequant(int16[2], scale, zp);
b.w = dequant(int16[3], scale, zp);
return b;
}
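A rough host-side analogue in numpy of what these device helpers do, illustrative only: four float lanes are rounded and clamped to int8, packed into one 32-bit word (mirroring the `int8[4]`/`int32` union, assuming little-endian byte order), then unpacked and rescaled on the way back. The qparams and values below are made up.
```python
import numpy as np

# Example per-tensor qparams and values; all numbers here are made up.
scale, zp = 0.05, 0.8
x = np.array([0.71, 0.93, -0.12, 1.40], dtype=np.float32)

# quant: round/clamp each lane to int8, then pack int8x4 into one int32 word
# (the little-endian equivalent of the int8[4]/int32 union above).
q = np.clip(np.round((x - zp) / scale), -128, 127).astype(np.int8)
packed = q.view(np.int32)[0]

# dequant: unpack the word back into four int8 lanes and rescale.
lanes = np.frombuffer(packed.tobytes(), dtype=np.int8)
x_hat = lanes.astype(np.float32) * scale + zp
print(q, int(packed), x_hat)
```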
......@@ -1393,7 +1393,9 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// past kv quant param
const float k_scale = params.attention_k_scale;
const float k_zp = params.attention_k_zp;
const float v_scale = params.attention_v_scale;
const float v_zp = params.attention_v_zp;
// Trigger the loads from the Q and K buffers.
Qk_vec_k q;
......@@ -1472,7 +1474,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
}
else if (QUANT_POLICY == 4) {
using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_k>::value>::type;
Packed_Int8_t k_int8 = quant(k, k_scale);
Packed_Int8_t k_int8 = quant(k, k_scale, k_zp);
int8_t* dst_ptr = reinterpret_cast<int8_t*>(params.k_cache);
*reinterpret_cast<Packed_Int8_t*>(&dst_ptr[offset]) = k_int8;
......@@ -1495,7 +1497,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
}
else if (QUANT_POLICY == 4) {
using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_k>::value>::type;
Packed_Int8_t k_int8 = quant(k, k_scale);
Packed_Int8_t k_int8 = quant(k, k_scale, k_zp);
int8_t* dst_ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
*reinterpret_cast<Packed_Int8_t*>(&dst_ptr[offset]) = k_int8;
......@@ -1643,7 +1645,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
Packed_Int8_t k_vec_m_int8 = *reinterpret_cast<const Packed_Int8_t*>(
&k_cache_batch_int8[beam_offset + jj * QK_ELTS_IN_16B]);
Packed_Float_t k_vec_m_float = dequant(k_vec_m_int8, k_scale);
Packed_Float_t k_vec_m_float = dequant(k_vec_m_int8, k_scale, k_zp);
k[ii] = vec_conversion<K_vec_k, Packed_Float_t>(k_vec_m_float);
}
......@@ -1829,7 +1831,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
else if (QUANT_POLICY == 4) {
Packed_Int8_t v_vec_m_int8 =
*reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti * Dh]);
Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale);
Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale, v_zp);
v = vec_conversion<V_vec_k, Packed_Float_t>(v_vec_m_float);
}
......@@ -1876,7 +1878,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
else if (QUANT_POLICY == 4) {
Packed_Int8_t v_vec_m_int8 =
*reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti_circ * Dh]);
Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale);
Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale, v_zp);
v = vec_conversion<V_vec_k, Packed_Float_t>(v_vec_m_float);
}
......@@ -1934,7 +1936,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
}
else if (QUANT_POLICY == 4) {
using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_k>::value>::type;
Packed_Int8_t v_int8 = quant(v, v_scale);
Packed_Int8_t v_int8 = quant(v, v_scale, v_zp);
*reinterpret_cast<Packed_Int8_t*>(&v_cache_int8[tlength_circ * Dh]) = v_int8;
}
}
......
......@@ -299,7 +299,7 @@ void LlamaDecoderLayerWeight<T>::loadModel(std::string dir_path, FtCudaDataType
std::ifstream in(scale_path, std::ios::in);
if (in.is_open()) {
in.close();
self_attn_weights.past_kv_scale = loadArrayFromBin({2}, scale_path);
self_attn_weights.past_kv_scale = loadArrayFromBin({4}, scale_path);
}
else {
self_attn_weights.past_kv_scale = {};
......
......@@ -153,7 +153,9 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
if (int8_mode & QuantPolicy::kCacheKVInt8) {
params.attention_k_scale = attention_kv_scale[0];
params.attention_v_scale = attention_kv_scale[1];
params.attention_k_zp = attention_kv_scale[1];
params.attention_v_scale = attention_kv_scale[2];
params.attention_v_zp = attention_kv_scale[3];
}
PUSH_RANGE("scaled dot-product fusion");
......
......@@ -293,13 +293,21 @@ inline __device__ float2 float2div(float a, float2 b)
return c;
}
static inline __device__ half4 char4_scale_to_half4(char4 value, const float scale)
inline __device__ float2 float2sub(float zp, float2 val)
{
float2 ret;
ret.x = val.x - zp;
ret.y = val.y - zp;
return ret;
}
static inline __device__ half4 char4_scale_to_half4(char4 value, const float scale, const float zp)
{
half4 dst;
dst.x = __float2half(value.x * scale);
dst.y = __float2half(value.y * scale);
dst.z = __float2half(value.z * scale);
dst.w = __float2half(value.w * scale);
dst.x = __float2half(value.x * scale + zp);
dst.y = __float2half(value.y * scale + zp);
dst.z = __float2half(value.z * scale + zp);
dst.w = __float2half(value.w * scale + zp);
return dst;
}
......@@ -339,7 +347,8 @@ __global__ void extend_value_cache_int8(int8_t** v_dst,
const int* history_length,
const int max_q_len,
const int max_seq_len,
const float v_scale)
const float v_scale,
const float v_zp)
{
const int batch_id = blockIdx.y;
const int head_id = blockIdx.z;
......@@ -373,12 +382,12 @@ __global__ void extend_value_cache_int8(int8_t** v_dst,
const auto value = val_src[src_idx];
auto to_ptr = reinterpret_cast<uint32_t*>(val_dst + dst_idx);
float2 float2_0 = float2div(v_scale, mmha::half2_to_float2(value.x));
float2 float2_1 = float2div(v_scale, mmha::half2_to_float2(value.y));
float2 float2_0 = float2div(v_scale, float2sub(v_zp, mmha::half2_to_float2(value.x)));
float2 float2_1 = float2div(v_scale, float2sub(v_zp, mmha::half2_to_float2(value.y)));
to_ptr[0] = float4_to_char4(float2_0.x, float2_0.y, float2_1.x, float2_1.y);
float2_0 = float2div(v_scale, mmha::half2_to_float2(value.z));
float2_1 = float2div(v_scale, mmha::half2_to_float2(value.w));
float2_0 = float2div(v_scale, float2sub(v_zp, mmha::half2_to_float2(value.z)));
float2_1 = float2div(v_scale, float2sub(v_zp, mmha::half2_to_float2(value.w)));
to_ptr[1] = float4_to_char4(float2_0.x, float2_0.y, float2_1.x, float2_1.y);
}
}
......@@ -415,7 +424,8 @@ void invokeExtendKVCache(T** k_dst,
history_length,
max_q_len,
max_seq_len,
kv_scale[0]);
kv_scale[0],
kv_scale[1]);
extend_value_cache_int8<<<grid, block_sz, 0, stream>>>(reinterpret_cast<int8_t**>(v_dst),
dst_offset,
......@@ -426,7 +436,8 @@ void invokeExtendKVCache(T** k_dst,
history_length,
max_q_len,
max_seq_len,
kv_scale[1]);
kv_scale[2],
kv_scale[3]);
}
else {
extend_value_cache<<<grid, block_sz, 0, stream>>>(k_dst,
......@@ -535,7 +546,8 @@ __global__ void transpose_value_cache_int8(T* v_dst, //
const int* seq_length,
const int max_kv_len,
const int max_seq_len,
const float v_scale)
const float v_scale,
const float v_zp)
{
const int batch_id = blockIdx.y;
const int head_id = blockIdx.z;
......@@ -568,8 +580,8 @@ __global__ void transpose_value_cache_int8(T* v_dst, //
const auto from_ptr = reinterpret_cast<const char4*>(val_src + src_idx);
auto to_ptr = reinterpret_cast<half4*>(val_dst + dst_idx);
to_ptr[0] = char4_scale_to_half4(from_ptr[0], v_scale);
to_ptr[1] = char4_scale_to_half4(from_ptr[1], v_scale);
to_ptr[0] = char4_scale_to_half4(from_ptr[0], v_scale, v_zp);
to_ptr[1] = char4_scale_to_half4(from_ptr[1], v_scale, v_zp);
}
}
......@@ -605,7 +617,8 @@ void invokeTransposeKVCache(T* key_cache_trans,
key_length,
max_kv_len,
max_seq_len,
kv_scale[0]);
kv_scale[0],
kv_scale[1]);
transpose_value_cache_int8<<<grid, block_sz, 0, stream>>>(val_cache_trans,
reinterpret_cast<const int8_t**>(val_cache),
......@@ -616,7 +629,8 @@ void invokeTransposeKVCache(T* key_cache_trans,
key_length,
max_kv_len,
max_seq_len,
kv_scale[1]);
kv_scale[2],
kv_scale[3]);
}
else {
transpose_value_cache<<<grid, block_sz, 0, stream>>>(key_cache_trans,
......