Unverified Commit 902a3e16 authored by tpoisonooo, committed by GitHub

feat(quantization): kv cache use asymmetric (#218)

* feat(quantization): kv cache use asymmetric
parent c3290cad
@@ -2,7 +2,7 @@
## Benchmark the Graphics Memory Usage

We take the [internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b) instruction model as the benchmark target. The benchmark process is as follows:

1. Use `deploy.py` to convert the model, set the maximum concurrency in `workspace`, and adjust the number of requests in `llama_config.ini`
2. Compile `bin/llama_triton_example` and record the GPU memory usage of the fp16 model under different batch_size settings
@@ -28,25 +28,30 @@ Note that `kCacheKVInt8` and `WeightInt4` can be used simultaneously, and we wil
## Benchmark the Accuracy

The quantization method is PTQ (post-training quantization), and the related formulas are as follows:

```
zp    = (min + max) / 2
scale = (max - min) / 255
quant:   q = round((f - zp) / scale)
dequant: f = q * scale + zp
```
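To make the arithmetic concrete, here is a minimal NumPy sketch of the same asymmetric scheme; the `asym_quant`/`asym_dequant` helpers and the sample values are illustrative only, not part of the repository:

```python
import numpy as np

def asym_quant(f, f_min, f_max, bits=8):
    # zp = (min + max) / 2, scale = (max - min) / (2**bits - 1)
    zp = (f_min + f_max) / 2
    scale = (f_max - f_min) / (2**bits - 1)
    q = np.clip(np.round((f - zp) / scale), -128, 127).astype(np.int8)
    return q, scale, zp

def asym_dequant(q, scale, zp):
    # f = q * scale + zp
    return q.astype(np.float32) * scale + zp

f = np.array([-1.5, -0.2, 0.0, 0.7, 2.3], dtype=np.float32)
q, scale, zp = asym_quant(f, f.min(), f.max())
print(asym_dequant(q, scale, zp))  # recovers the inputs to within ~scale/2
```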
Here we take the [internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b) instruction model as the benchmark target again. The benchmark process is as follows:

1. Convert the model with `deploy.py` and run the docker service
2. Test the fp16 version accuracy with `client.py` on the dataset
3. Execute the quantization script to get the quantization parameters and put them into the weights directory, then modify the configuration file so that the [kCacheKVInt8](../../src/turbomind/models/llama/llama_utils.h) option takes effect
4. Execute `client.py` again to get the int8 version accuracy

The following table shows the accuracy of the `kCacheKVInt8` method after calibrating on 128 samples randomly selected from the c4 dataset and evaluating with [opencompass](https://github.com/InternLM/opencompass):

| task | dataset | metric | int8 | fp16 | diff |
| :-----------: | :-------------: | :-----------: | :---: | :---: | :---: |
| Language | winogrande | accuracy | 60.77 | 61.48 | -0.71 |
| Knowledge | nq | score | 2.69 | 2.60 | +0.09 |
| Reasoning | gsm8k | accuracy | 33.28 | 34.72 | -1.44 |
| Reasoning | bbh | naive_average | 20.12 | 20.51 | -0.39 |
| Understanding | openbookqa_fact | accuracy | 82.40 | 82.20 | +0.20 |
| Understanding | eprstmt-dev | accuracy | 90.62 | 88.75 | +1.87 |
| Safety | crows_pairs | accuracy | 32.56 | 31.43 | +1.13 |

We have validated the quantization implementation on more datasets and larger models and will keep updating the results.
@@ -2,7 +2,7 @@
## Benchmark the Graphics Memory Usage

The benchmark target is the [internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b) model. The benchmark procedure is as follows:

1. Convert the model with `deploy.py`, set the maximum concurrency in the `workspace` configuration, and adjust the number of requests in `llama_config.ini`
@@ -29,7 +29,16 @@
## Benchmark the Accuracy

The quantization method is PTQ; the related formulas are as follows:

```
zp    = (min + max) / 2
scale = (max - min) / 255
quant:   q = round((f - zp) / scale)
dequant: f = q * scale + zp
```

The benchmark target is the [internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b) instruction model. The benchmark procedure is as follows:

1. Convert the model with `deploy.py` and run the docker service
@@ -37,18 +46,14 @@
3. Run the quantization script to obtain the quantization parameters and place them in the weights directory; modify the configuration file so that the [kCacheKVInt8](../../src/turbomind/models/llama/llama_utils.h) option takes effect
4. Run `client.py` again to read the int8 version accuracy

The following results were obtained with the `kCacheKVInt8` method, using only 128 samples randomly selected from the c4 dataset for PTQ calibration and then evaluating with [opencompass](https://github.com/InternLM/opencompass):

| task | dataset | metric | int8 | fp16 | diff |
| :-----------: | :-------------: | :-----------: | :---: | :---: | :---: |
| Language | winogrande | accuracy | 60.77 | 61.48 | -0.71 |
| Knowledge | nq | score | 2.69 | 2.60 | +0.09 |
| Reasoning | gsm8k | accuracy | 33.28 | 34.72 | -1.44 |
| Reasoning | bbh | naive_average | 20.12 | 20.51 | -0.39 |
| Understanding | openbookqa_fact | accuracy | 82.40 | 82.20 | +0.20 |
| Understanding | eprstmt-dev | accuracy | 90.62 | 88.75 | +1.87 |
| Safety | crows_pairs | accuracy | 32.56 | 31.43 | +1.13 |

We have validated more datasets on larger models and will keep updating the results.
@@ -63,7 +63,9 @@ def _export_asym(key_stats: dict,
    tp_k_max = torch.chunk(k_max, tp)
    tp_v_max = torch.chunk(v_max, tp)

    for i in range(tp):
        # zp = (min+max) / 2
        # scale = (max-min) / 255
        # quant: q = (f-zp) / scale
        # dequant: f = q * scale + zp
        k_min = tp_k_min[i].min()
        v_min = tp_v_min[i].min()
@@ -74,7 +76,10 @@ def _export_asym(key_stats: dict,
        k_scale = (k_max - k_min) / (2**bits - 1)
        v_scale = (v_max - v_min) / (2**bits - 1)
        k_zp = (k_max + k_min) / 2
        v_zp = (v_max + v_min) / 2

        kv_qparams = np.array([k_scale, k_zp, v_scale, v_zp],
                              dtype=np.float32)
        out_path = out_dir / f'layers.{layer_idx}.past_kv_scale.{i}.weight'  # noqa: E501
        kv_qparams.tofile(out_path)
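Each exported `past_kv_scale` file now holds four float32 values per layer and tensor-parallel rank, in the order `[k_scale, k_zp, v_scale, v_zp]`, which is what `loadArrayFromBin({4}, scale_path)` reads back on the C++ side. A quick way to inspect one of these files (the path below is only an example; the real file follows the f-string pattern above):

```python
import numpy as np

# example path; adjust to wherever the weights directory lives
path = 'workspace/triton_models/weights/layers.0.past_kv_scale.0.weight'
k_scale, k_zp, v_scale, v_zp = np.fromfile(path, dtype=np.float32)
print(f'K: scale={k_scale:.6f} zp={k_zp:.6f} | V: scale={v_scale:.6f} zp={v_zp:.6f}')
```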
@@ -116,7 +116,9 @@ struct Multihead_attention_params_base {
    const float* attention_out_scale = nullptr;
    int          int8_mode           = 0;
    float        attention_k_scale   = 0.f;
    float        attention_k_zp      = 0.f;
    float        attention_v_scale   = 0.f;
    float        attention_v_zp      = 0.f;
};
template<typename T>
@@ -1023,49 +1023,49 @@ inline __device__ Float8_ float_from_int8(int64_t u)
}
// clang-format on

inline __device__ int8_t quant(float a, const float scale, const float zp)
{
    int8_t int8;
    int8 = round(max(-128.f, min(127.f, (a - zp) / scale)));
    return int8;
}

inline __device__ short quant(float2 a, const float scale, const float zp)
{
    union {
        int8_t int8[2];
        short  int16;
    };

    int8[0] = round(max(-128.f, min(127.f, (a.x - zp) / scale)));
    int8[1] = round(max(-128.f, min(127.f, (a.y - zp) / scale)));
    return int16;
}

inline __device__ int32_t quant(float4 a, const float scale, const float zp)
{
    union {
        int8_t  int8[4];
        int32_t int32;
    };

    int8[0] = round(max(-128.f, min(127.f, (a.x - zp) / scale)));
    int8[1] = round(max(-128.f, min(127.f, (a.y - zp) / scale)));
    int8[2] = round(max(-128.f, min(127.f, (a.z - zp) / scale)));
    int8[3] = round(max(-128.f, min(127.f, (a.w - zp) / scale)));
    return int32;
}

// float16 to int8
inline __device__ int8_t quant(uint16_t a, const float scale, const float zp)
{
    int8_t int8;
    float  b = half_to_float(a);
    int8     = round(max(-128.f, min(127.f, (b - zp) / scale)));
    return int8;
}

// float16x2 to int8x2
inline __device__ int16_t quant(uint a, const float scale, const float zp)
{
    union {
        int8_t int8[2];
@@ -1073,44 +1073,44 @@ inline __device__ int16_t quant(uint a, const float scale)
    };

    float2 b = half2_to_float2(a);
    int8[0]  = round(max(-128.f, min(127.f, (b.x - zp) / scale)));
    int8[1]  = round(max(-128.f, min(127.f, (b.y - zp) / scale)));
    return int16;
}

// float16x4 to int8x4
inline __device__ int32_t quant(uint2 a, const float scale, const float zp)
{
    union {
        int16_t int16[2];
        int32_t int32;
    };

    int16[0] = quant(a.x, scale, zp);
    int16[1] = quant(a.y, scale, zp);
    return int32;
}

// float16x8 to int8x8
inline __device__ int64_t quant(uint4 a, const float scale, const float zp)
{
    union {
        int16_t int16[4];
        int64_t int64;
    };

    int16[0] = quant(a.x, scale, zp);
    int16[1] = quant(a.y, scale, zp);
    int16[2] = quant(a.z, scale, zp);
    int16[3] = quant(a.w, scale, zp);
    return int64;
}

// int8 to float32, then `vec_conversion` to target format
inline __device__ float dequant(int8_t a, const float scale, const float zp)
{
    float b = a * scale + zp;
    return b;
}

// int8x2 to float32x2
inline __device__ float2 dequant(int16_t a, const float scale, const float zp)
{
    union {
        int8_t int8[2];
@@ -1119,12 +1119,12 @@ inline __device__ float2 dequant(int16_t a, const float scale)
    int16 = a;

    float2 b;
    b.x = int8[0] * scale + zp;
    b.y = int8[1] * scale + zp;
    return b;
}

// int8x4 to float32x4
inline __device__ float4 dequant(int32_t a, const float scale, const float zp)
{
    union {
        int8_t int8[4];
@@ -1133,14 +1133,14 @@ inline __device__ float4 dequant(int32_t a, const float scale)
    int32 = a;

    float4 b;
    b.x = (int8[0] * scale) + zp;
    b.y = (int8[1] * scale) + zp;
    b.z = (int8[2] * scale) + zp;
    b.w = (int8[3] * scale) + zp;
    return b;
}

inline __device__ Float8_ dequant(int64_t a, const float scale, const float zp)
{
    union {
        int16_t int16[4];
@@ -1149,10 +1149,10 @@ inline __device__ Float8_ dequant(int64_t a, const float scale)
    int64 = a;

    Float8_ b;
    b.x = dequant(int16[0], scale, zp);
    b.y = dequant(int16[1], scale, zp);
    b.z = dequant(int16[2], scale, zp);
    b.w = dequant(int16[3], scale, zp);
    return b;
}
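For reference, the device-side round trip can be checked against a host-side NumPy sketch (the helper names below are illustrative and not part of this patch); with the asymmetric parameters, the reconstruction error stays within about half a scale step except where the int8 clamp saturates:

```python
import numpy as np

def quant_ref(f, scale, zp):
    # mirrors the device quant(): clamp (f - zp) / scale to [-128, 127], then round
    return np.round(np.clip((f - zp) / scale, -128.0, 127.0)).astype(np.int8)

def dequant_ref(q, scale, zp):
    # mirrors the device dequant(): q * scale + zp
    return q.astype(np.float32) * scale + zp

rng = np.random.default_rng(0)
k = rng.normal(size=4096).astype(np.float32)
scale = (k.max() - k.min()) / 255
zp = (k.max() + k.min()) / 2
err = np.abs(dequant_ref(quant_ref(k, scale, zp), scale, zp) - k).max()
print(f'max |error| = {err:.6f}, scale/2 = {scale / 2:.6f}')
```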
@@ -1393,7 +1393,9 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
    // past kv quant param
    const float k_scale = params.attention_k_scale;
    const float k_zp    = params.attention_k_zp;
    const float v_scale = params.attention_v_scale;
    const float v_zp    = params.attention_v_zp;

    // Trigger the loads from the Q and K buffers.
    Qk_vec_k q;
@@ -1472,7 +1474,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
        }
        else if (QUANT_POLICY == 4) {
            using Packed_Int8_t  = typename packed_type<int8_t, num_elems<Qk_vec_k>::value>::type;
            Packed_Int8_t k_int8 = quant(k, k_scale, k_zp);

            int8_t* dst_ptr = reinterpret_cast<int8_t*>(params.k_cache);
            *reinterpret_cast<Packed_Int8_t*>(&dst_ptr[offset]) = k_int8;
@@ -1495,7 +1497,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
        }
        else if (QUANT_POLICY == 4) {
            using Packed_Int8_t  = typename packed_type<int8_t, num_elems<Qk_vec_k>::value>::type;
            Packed_Int8_t k_int8 = quant(k, k_scale, k_zp);

            int8_t* dst_ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
            *reinterpret_cast<Packed_Int8_t*>(&dst_ptr[offset]) = k_int8;
@@ -1643,7 +1645,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
                Packed_Int8_t k_vec_m_int8 = *reinterpret_cast<const Packed_Int8_t*>(
                    &k_cache_batch_int8[beam_offset + jj * QK_ELTS_IN_16B]);
                Packed_Float_t k_vec_m_float = dequant(k_vec_m_int8, k_scale, k_zp);
                k[ii] = vec_conversion<K_vec_k, Packed_Float_t>(k_vec_m_float);
            }
@@ -1829,7 +1831,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
            else if (QUANT_POLICY == 4) {
                Packed_Int8_t v_vec_m_int8 =
                    *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti * Dh]);
                Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale, v_zp);
                v = vec_conversion<V_vec_k, Packed_Float_t>(v_vec_m_float);
            }
@@ -1876,7 +1878,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
            else if (QUANT_POLICY == 4) {
                Packed_Int8_t v_vec_m_int8 =
                    *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti_circ * Dh]);
                Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale, v_zp);
                v = vec_conversion<V_vec_k, Packed_Float_t>(v_vec_m_float);
            }
@@ -1934,7 +1936,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
            }
            else if (QUANT_POLICY == 4) {
                using Packed_Int8_t  = typename packed_type<int8_t, num_elems<V_vec_k>::value>::type;
                Packed_Int8_t v_int8 = quant(v, v_scale, v_zp);
                *reinterpret_cast<Packed_Int8_t*>(&v_cache_int8[tlength_circ * Dh]) = v_int8;
            }
        }
@@ -299,7 +299,7 @@ void LlamaDecoderLayerWeight<T>::loadModel(std::string dir_path, FtCudaDataType
    std::ifstream in(scale_path, std::ios::in);
    if (in.is_open()) {
        in.close();
        self_attn_weights.past_kv_scale = loadArrayFromBin({4}, scale_path);
    }
    else {
        self_attn_weights.past_kv_scale = {};
@@ -153,7 +153,9 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
    if (int8_mode & QuantPolicy::kCacheKVInt8) {
        params.attention_k_scale = attention_kv_scale[0];
        params.attention_k_zp    = attention_kv_scale[1];
        params.attention_v_scale = attention_kv_scale[2];
        params.attention_v_zp    = attention_kv_scale[3];
    }

    PUSH_RANGE("scaled dot-product fusion");
@@ -293,13 +293,21 @@ inline __device__ float2 float2div(float a, float2 b)
    return c;
}

inline __device__ float2 float2sub(float zp, float2 val)
{
    float2 ret;
    ret.x = val.x - zp;
    ret.y = val.y - zp;
    return ret;
}

static inline __device__ half4 char4_scale_to_half4(char4 value, const float scale, const float zp)
{
    half4 dst;
    dst.x = __float2half(value.x * scale + zp);
    dst.y = __float2half(value.y * scale + zp);
    dst.z = __float2half(value.z * scale + zp);
    dst.w = __float2half(value.w * scale + zp);
    return dst;
}
@@ -339,7 +347,8 @@ __global__ void extend_value_cache_int8(int8_t** v_dst,
                                        const int*  history_length,
                                        const int   max_q_len,
                                        const int   max_seq_len,
                                        const float v_scale,
                                        const float v_zp)
{
    const int batch_id = blockIdx.y;
    const int head_id  = blockIdx.z;
@@ -373,12 +382,12 @@ __global__ void extend_value_cache_int8(int8_t** v_dst,
        const auto value = val_src[src_idx];
        auto to_ptr      = reinterpret_cast<uint32_t*>(val_dst + dst_idx);

        float2 float2_0 = float2div(v_scale, float2sub(v_zp, mmha::half2_to_float2(value.x)));
        float2 float2_1 = float2div(v_scale, float2sub(v_zp, mmha::half2_to_float2(value.y)));
        to_ptr[0]       = float4_to_char4(float2_0.x, float2_0.y, float2_1.x, float2_1.y);

        float2_0  = float2div(v_scale, float2sub(v_zp, mmha::half2_to_float2(value.z)));
        float2_1  = float2div(v_scale, float2sub(v_zp, mmha::half2_to_float2(value.w)));
        to_ptr[1] = float4_to_char4(float2_0.x, float2_0.y, float2_1.x, float2_1.y);
    }
}
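The kernel above quantizes eight fp16 values at a time and packs each group of four int8 results into one 32-bit word before storing it. A rough NumPy sketch of that pack/unpack step, assuming the first element lands in the lowest byte (little-endian, as on typical GPU/host setups); the helper names are illustrative, not part of this patch:

```python
import numpy as np

def pack_int8x4(q4):
    # four int8 values -> one uint32 word (first element in the lowest byte)
    return q4.astype(np.int8).view(np.uint32)[0]

def unpack_int8x4(word):
    # one uint32 word -> four int8 values
    return np.array([word], dtype=np.uint32).view(np.int8)

q = np.array([-3, 0, 7, -128], dtype=np.int8)
word = pack_int8x4(q)
assert (unpack_int8x4(word) == q).all()
```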
@@ -415,7 +424,8 @@ void invokeExtendKVCache(T** k_dst,
                                                           history_length,
                                                           max_q_len,
                                                           max_seq_len,
                                                           kv_scale[0],
                                                           kv_scale[1]);

        extend_value_cache_int8<<<grid, block_sz, 0, stream>>>(reinterpret_cast<int8_t**>(v_dst),
                                                               dst_offset,
@@ -426,7 +436,8 @@ void invokeExtendKVCache(T** k_dst,
                                                               history_length,
                                                               max_q_len,
                                                               max_seq_len,
                                                               kv_scale[2],
                                                               kv_scale[3]);
    }
    else {
        extend_value_cache<<<grid, block_sz, 0, stream>>>(k_dst,
@@ -535,7 +546,8 @@ __global__ void transpose_value_cache_int8(T* v_dst,  //
                                           const int*  seq_length,
                                           const int   max_kv_len,
                                           const int   max_seq_len,
                                           const float v_scale,
                                           const float v_zp)
{
    const int batch_id = blockIdx.y;
    const int head_id  = blockIdx.z;
@@ -568,8 +580,8 @@ __global__ void transpose_value_cache_int8(T* v_dst,  //
        const auto from_ptr = reinterpret_cast<const char4*>(val_src + src_idx);
        auto to_ptr         = reinterpret_cast<half4*>(val_dst + dst_idx);

        to_ptr[0] = char4_scale_to_half4(from_ptr[0], v_scale, v_zp);
        to_ptr[1] = char4_scale_to_half4(from_ptr[1], v_scale, v_zp);
    }
}
@@ -605,7 +617,8 @@ void invokeTransposeKVCache(T* key_cache_trans,
                                                              key_length,
                                                              max_kv_len,
                                                              max_seq_len,
                                                              kv_scale[0],
                                                              kv_scale[1]);

        transpose_value_cache_int8<<<grid, block_sz, 0, stream>>>(val_cache_trans,
                                                                  reinterpret_cast<const int8_t**>(val_cache),
@@ -616,7 +629,8 @@ void invokeTransposeKVCache(T* key_cache_trans,
                                                                  key_length,
                                                                  max_kv_len,
                                                                  max_seq_len,
                                                                  kv_scale[2],
                                                                  kv_scale[3]);
    }
    else {
        transpose_value_cache<<<grid, block_sz, 0, stream>>>(key_cache_trans,