Unverified Commit cc93136e authored by tpoisonooo, committed by GitHub

feat(src): add kv cache int8 quantization (#22)

* feat(src): add int8 and compile passed

* feat(kernels): fix

* feat(llama): update kernel

* feat(src): add debug

* fix(kernel): k_cache use int8_t pointer

* style(llama): clean code

* feat(deploy.py): revert to enable fmha

* style(LlamaV2): clean code

* feat(deploy.py): add default quant policy
parent 4d42a781
...@@ -17,18 +17,18 @@ project(FasterTransformer LANGUAGES CXX CUDA) ...@@ -17,18 +17,18 @@ project(FasterTransformer LANGUAGES CXX CUDA)
find_package(CUDA 10.2 REQUIRED) find_package(CUDA 10.2 REQUIRED)
if(${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11") # if(${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11")
add_definitions("-DENABLE_BF16") # add_definitions("-DENABLE_BF16")
message("CUDA_VERSION ${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR} is greater or equal than 11.0, enable -DENABLE_BF16 flag") # message("CUDA_VERSION ${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR} is greater or equal than 11.0, enable -DENABLE_BF16 flag")
endif() # endif()
if((${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11" AND ${CUDA_VERSION_MINOR} VERSION_GREATER_EQUAL "8") OR (${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "12")) # if((${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11" AND ${CUDA_VERSION_MINOR} VERSION_GREATER_EQUAL "8") OR (${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "12"))
add_definitions("-DENABLE_FP8") # add_definitions("-DENABLE_FP8")
option(ENABLE_FP8 "ENABLE_FP8" OFF) # option(ENABLE_FP8 "ENABLE_FP8" OFF)
if(ENABLE_FP8) # if(ENABLE_FP8)
message("CUDA_VERSION ${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR} is greater or equal than 11.8, enable -DENABLE_FP8 flag") # message("CUDA_VERSION ${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR} is greater or equal than 11.8, enable -DENABLE_FP8 flag")
endif() # endif()
endif() # endif()
set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules) set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules)
......
...@@ -158,12 +158,19 @@ bash workspace/service_docker_up.sh --lib-dir $(pwd)/build/install/backends/fast ...@@ -158,12 +158,19 @@ bash workspace/service_docker_up.sh --lib-dir $(pwd)/build/install/backends/fast
python3 llmdeploy/serve/client.py {server_ip_addresss}:33337 1 python3 llmdeploy/serve/client.py {server_ip_addresss}:33337 1
``` ```
## User Guide ## Quantization
In fp16 mode, kv_cache int8 quantization can be enabled so that a single GPU can serve more concurrent users.
First run the quantization script; it stores the quantization parameters in the weight directory produced by `deploy.py`.
Then adjust `config.ini` (see the excerpt below):
* set `use_context_fmha` to 0, which turns it off (fmha is not supported together with the int8 kv cache)
* set `quant_policy` to 4; this parameter defaults to 0, which means kv_cache quantization is disabled
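A minimal sketch of the adjusted entries, assuming they sit in the section of `config.ini` that `deploy.py` generates (only the two keys discussed above are shown; everything else stays as generated):

```ini
; excerpt only - keep the rest of config.ini as produced by deploy.py
; 0 turns fused context attention (fmha) off, required for the int8 kv cache
use_context_fmha = 0
; 4 enables kv_cache int8 quantization; the default 0 leaves it disabled
quant_policy = 4
```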
## Contributing ## Contributing
We appreciate all contributions to LLMDeploy. Please refer to [CONTRIBUTING.md](.github/CONTRIBUTING.md) for the contributing guideline. We appreciate all contributions to LLMDeploy. Please refer to [CONTRIBUTING.md](.github/CONTRIBUTING.md) for the contributing guideline.
## Acknowledgement ## Acknowledgement
- [FasterTransformer](https://github.com/NVIDIA/FasterTransformer) - [FasterTransformer](https://github.com/NVIDIA/FasterTransformer)
......
...@@ -68,7 +68,7 @@ pip install -e . ...@@ -68,7 +68,7 @@ pip install -e .
```shell ```shell
python3 llmdeploy/serve/fastertransformer/deploy.py llama-7B /path/to/llama-7b llama \ python3 llmdeploy/serve/fastertransformer/deploy.py llama-7B /path/to/llama-7b llama \
--tokenizer_path /path/to/tokenizer/model --tokenizer_path /path/to/tokenizer/model
bash workspace/service_docker_up.sh bash workspace/service_docker_up.sh --lib-dir $(pwd)/build/install/backends/fastertransformer
``` ```
</details> </details>
...@@ -79,7 +79,7 @@ bash workspace/service_docker_up.sh ...@@ -79,7 +79,7 @@ bash workspace/service_docker_up.sh
```shell ```shell
python3 llmdeploy/serve/fastertransformer/deploy.py llama-13B /path/to/llama-13b llama \ python3 llmdeploy/serve/fastertransformer/deploy.py llama-13B /path/to/llama-13b llama \
--tokenizer_path /path/to/tokenizer/model --tp 2 --tokenizer_path /path/to/tokenizer/model --tp 2
bash workspace/service_docker_up.sh bash workspace/service_docker_up.sh --lib-dir $(pwd)/build/install/backends/fastertransformer
``` ```
</details> </details>
...@@ -90,7 +90,7 @@ bash workspace/service_docker_up.sh ...@@ -90,7 +90,7 @@ bash workspace/service_docker_up.sh
```shell ```shell
python3 llmdeploy/serve/fastertransformer/deploy.py llama-33B /path/to/llama-33b llama \ python3 llmdeploy/serve/fastertransformer/deploy.py llama-33B /path/to/llama-33b llama \
--tokenizer_path /path/to/tokenizer/model --tp 4 --tokenizer_path /path/to/tokenizer/model --tp 4
bash workspace/service_docker_up.sh bash workspace/service_docker_up.sh --lib-dir $(pwd)/build/install/backends/fastertransformer
``` ```
</details> </details>
...@@ -101,7 +101,7 @@ bash workspace/service_docker_up.sh ...@@ -101,7 +101,7 @@ bash workspace/service_docker_up.sh
```shell ```shell
python3 llmdeploy/serve/fastertransformer/deploy.py llama-65B /path/to/llama-65b llama \ python3 llmdeploy/serve/fastertransformer/deploy.py llama-65B /path/to/llama-65b llama \
--tokenizer_path /path/to/tokenizer/model --tp 8 --tokenizer_path /path/to/tokenizer/model --tp 8
bash workspace/service_docker_up.sh bash workspace/service_docker_up.sh --lib-dir $(pwd)/build/install/backends/fastertransformer
``` ```
</details> </details>
...@@ -119,7 +119,7 @@ python3 -m fastchat.model.apply_delta \ ...@@ -119,7 +119,7 @@ python3 -m fastchat.model.apply_delta \
--delta-path lmsys/vicuna-7b-delta-v1.1 --delta-path lmsys/vicuna-7b-delta-v1.1
python3 llmdeploy/serve/fastertransformer/deploy.py vicuna-7B /path/to/vicuna-7b hf python3 llmdeploy/serve/fastertransformer/deploy.py vicuna-7B /path/to/vicuna-7b hf
bash workspace/service_docker_up.sh bash workspace/service_docker_up.sh --lib-dir $(pwd)/build/install/backends/fastertransformer
``` ```
</details> </details>
...@@ -135,7 +135,7 @@ python3 -m fastchat.model.apply_delta \ ...@@ -135,7 +135,7 @@ python3 -m fastchat.model.apply_delta \
--delta-path lmsys/vicuna-13b-delta-v1.1 --delta-path lmsys/vicuna-13b-delta-v1.1
python3 llmdeploy/serve/fastertransformer/deploy.py vicuna-13B /path/to/vicuna-13b hf python3 llmdeploy/serve/fastertransformer/deploy.py vicuna-13B /path/to/vicuna-13b hf
bash workspace/service_docker_up.sh bash workspace/service_docker_up.sh --lib-dir $(pwd)/build/install/backends/fastertransformer
``` ```
</details> </details>
...@@ -146,6 +146,13 @@ bash workspace/service_docker_up.sh ...@@ -146,6 +146,13 @@ bash workspace/service_docker_up.sh
python3 llmdeploy/serve/client.py {server_ip_addresss}:33337 1 python3 llmdeploy/serve/client.py {server_ip_addresss}:33337 1
``` ```
## Quantized Deployment
In fp16 mode, kv_cache int8 quantization can be enabled so that a single GPU can serve more concurrent users.
First run the quantization script; it stores the quantization parameters in the weight directory produced by `deploy.py`.
Then adjust `config.ini` (see the excerpt below):
* set `use_context_fmha` to 0, which turns it off
* set `quant_policy` to 4; this parameter defaults to 0, which means kv_cache quantization is disabled
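As in the English guide above, a minimal sketch of the adjusted entries (assumption: only these two keys change; the rest of the generated `config.ini` is left untouched):

```ini
; excerpt only - keep the rest of config.ini as produced by deploy.py
; 0 turns fused context attention (fmha) off, required for the int8 kv cache
use_context_fmha = 0
; 4 enables kv_cache int8 quantization; the default 0 leaves it disabled
quant_policy = 4
```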
## Contributing ## Contributing
We appreciate all contributors for their efforts to improve and enhance LLMDeploy. Please refer to the [contributing guide](.github/CONTRIBUTING.md) for guidance on contributing to the project. We appreciate all contributors for their efforts to improve and enhance LLMDeploy. Please refer to the [contributing guide](.github/CONTRIBUTING.md) for guidance on contributing to the project.
......
...@@ -147,7 +147,8 @@ def export(model_name: str, ...@@ -147,7 +147,8 @@ def export(model_name: str,
step_length=1, step_length=1,
cache_max_entry_count=48, cache_max_entry_count=48,
cache_chunk_size=8, cache_chunk_size=8,
use_context_fmha=1)) use_context_fmha=1,
quant_policy=0))
config = configparser.ConfigParser() config = configparser.ConfigParser()
for section, key_values in cfg.items(): for section, key_values in cfg.items():
...@@ -301,6 +302,7 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str, ...@@ -301,6 +302,7 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str,
_files = [file for file in os.listdir(model_path) if file.endswith('.bin')] _files = [file for file in os.listdir(model_path) if file.endswith('.bin')]
_files = sorted(_files) _files = sorted(_files)
print(_files)
_params = {} _params = {}
for _file in _files: for _file in _files:
...@@ -369,7 +371,7 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str, ...@@ -369,7 +371,7 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str,
for ft, hf in other: for ft, hf in other:
model_params[ft] = get_tensor(hf) model_params[ft] = get_tensor(hf)
return export(model_name, i + 1, norm_eps, model_params, tokenizer_path, return export(model_name, num_layer, norm_eps, model_params, tokenizer_path,
triton_models_path, tp) triton_models_path, tp)
......
...@@ -116,6 +116,8 @@ struct Multihead_attention_params_base { ...@@ -116,6 +116,8 @@ struct Multihead_attention_params_base {
const float* qkv_scale_out = nullptr; const float* qkv_scale_out = nullptr;
const float* attention_out_scale = nullptr; const float* attention_out_scale = nullptr;
int int8_mode = 0; int int8_mode = 0;
float attention_k_scale = 0.f;
float attention_v_scale = 0.f;
}; };
template<typename T> template<typename T>
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include "src/fastertransformer/utils/cuda_fp8_utils.h" #include "src/fastertransformer/utils/cuda_fp8_utils.h"
#include "src/fastertransformer/utils/cuda_type_utils.cuh" #include "src/fastertransformer/utils/cuda_type_utils.cuh"
#include "src/fastertransformer/models/llama/llama_utils.h"
#include <assert.h> #include <assert.h>
#include <float.h> #include <float.h>
#include <type_traits> #include <type_traits>
...@@ -548,6 +549,43 @@ __inline__ __device__ Tout vec_conversion(const Tin& x) ...@@ -548,6 +549,43 @@ __inline__ __device__ Tout vec_conversion(const Tin& x)
{ {
return x; return x;
} }
template<>
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
{
union {
half2 float16;
uint32_t uint32;
};
float16 = __float22half2_rn(a);
return uint32;
}
template<>
__inline__ __device__ uint2 vec_conversion<uint2, float4>(const float4& a)
{
uint2 b;
float2 val;
val.x = a.x;
val.y = a.y;
b.x = vec_conversion<uint32_t, float2>(val);
val.x = a.z;
val.y = a.w;
b.y = vec_conversion<uint32_t, float2>(val);
return b;
}
template<>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
{
uint4 b;
b.x = vec_conversion<uint32_t, float2>(a.x);
b.y = vec_conversion<uint32_t, float2>(a.y);
b.z = vec_conversion<uint32_t, float2>(a.z);
b.w = vec_conversion<uint32_t, float2>(a.w);
return b;
}
#ifdef ENABLE_FP8 #ifdef ENABLE_FP8
// fp8_t // fp8_t
template<> template<>
...@@ -980,6 +1018,139 @@ inline __device__ Float8_ float_from_int8(int64_t u) ...@@ -980,6 +1018,139 @@ inline __device__ Float8_ float_from_int8(int64_t u)
} }
// clang-format on // clang-format on
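// The helpers below implement symmetric per-tensor int8 quantization for the kv cache:
//   quantize:   q = round(clamp(x / scale, -128, 127))
//   dequantize: x ~= q * scale
// where `scale` is the per-tensor k/v scale passed in via attention_k_scale / attention_v_scale.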
inline __device__ int8_t quant(float a, const float scale)
{
int8_t int8;
int8 = round(max(-128.f, min(127.f, a / scale)));
return int8;
}
inline __device__ short quant(float2 a, const float scale)
{
union {
int8_t int8[2];
short int16;
};
int8[0] = round(max(-128.f, min(127.f, a.x / scale)));
int8[1] = round(max(-128.f, min(127.f, a.y / scale)));
return int16;
}
inline __device__ int32_t quant(float4 a, const float scale)
{
union {
int8_t int8[4];
int32_t int32;
};
int8[0] = round(max(-128.f, min(127.f, a.x / scale)));
int8[1] = round(max(-128.f, min(127.f, a.y / scale)));
int8[2] = round(max(-128.f, min(127.f, a.z / scale)));
int8[3] = round(max(-128.f, min(127.f, a.w / scale)));
return int32;
}
// float16 to int8
inline __device__ int8_t quant(uint16_t a, const float scale)
{
int8_t int8;
float b = half_to_float(a);
int8 = round(max(-128.f, min(127.f, b / scale)));
return int8;
}
// float16x2 to int8x2
inline __device__ int16_t quant(uint a, const float scale)
{
union {
int8_t int8[2];
short int16;
};
float2 b = half2_to_float2(a);
int8[0] = round(max(-128.f, min(127.f, b.x / scale)));
int8[1] = round(max(-128.f, min(127.f, b.y / scale)));
return int16;
}
// float16x4 to int8x4
inline __device__ int32_t quant(uint2 a, const float scale)
{
union {
int16_t int16[2];
int32_t int32;
};
int16[0] = quant(a.x, scale);
int16[1] = quant(a.y, scale);
return int32;
}
// float16x8 to int8x8
inline __device__ int64_t quant(uint4 a, const float scale)
{
union {
int16_t int16[4];
int64_t int64;
};
int16[0] = quant(a.x, scale);
int16[1] = quant(a.y, scale);
int16[2] = quant(a.z, scale);
int16[3] = quant(a.w, scale);
return int64;
}
// int8 to float32, then `vec_conversion` to target format
inline __device__ float dequant(int8_t a, const float scale)
{
float b = a * scale;
return b;
}
// int8x2 to float32x2
inline __device__ float2 dequant(int16_t a, const float scale)
{
union {
int8_t int8[2];
int16_t int16;
};
int16 = a;
float2 b;
b.x = int8[0] * scale;
b.y = int8[1] * scale;
return b;
}
// int8x4 to float32x4
inline __device__ float4 dequant(int32_t a, const float scale)
{
union {
int8_t int8[4];
int32_t int32;
};
int32 = a;
float4 b;
b.x = (int8[0] * scale);
b.y = (int8[1] * scale);
b.z = (int8[2] * scale);
b.w = (int8[3] * scale);
return b;
}
inline __device__ Float8_ dequant(int64_t a, const float scale)
{
union {
int16_t int16[4];
int64_t int64;
};
int64 = a;
Float8_ b;
b.x = dequant(int16[0], scale);
b.y = dequant(int16[1], scale);
b.z = dequant(int16[2], scale);
b.w = dequant(int16[3], scale);
return b;
}
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ int8_t cast_to_int8(float val) inline __device__ int8_t cast_to_int8(float val)
...@@ -1208,45 +1379,24 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1208,45 +1379,24 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// The offset in the bias buffer. // The offset in the bias buffer.
int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE; int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr; // past kv quant param
const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0; const float k_scale = params.attention_k_scale;
const float v_scale = params.attention_v_scale;
// Trigger the loads from the Q and K buffers. // Trigger the loads from the Q and K buffers.
Qk_vec_k q; Qk_vec_k q;
zero(q); zero(q);
if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) { if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) {
if (params.int8_mode == 2) {
using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_m>::value>::type;
using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec_m>::value>::type;
const auto q_scaling = params.qkv_scale_out[0];
const auto q_quant =
*reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[qk_offset]);
convert_from_float(q, mul<Packed_Float_t, float>(q_scaling, float_from_int8(q_quant)));
}
else {
q = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.q[qk_offset])); q = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.q[qk_offset]));
} }
}
Qk_vec_k k; Qk_vec_k k;
zero(k); zero(k);
{ {
if (params.int8_mode == 2) {
using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_m>::value>::type;
using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec_m>::value>::type;
const auto k_scaling = params.qkv_scale_out[1];
const auto k_quant =
*reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[qk_offset]);
convert_from_float(k, mul<Packed_Float_t, float>(k_scaling, float_from_int8(k_quant)));
}
else {
k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k[qk_offset])) : vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k[qk_offset])) :
k; k;
} }
}
// Trigger the loads from the Q and K bias buffers. // Trigger the loads from the Q and K bias buffers.
Qk_vec_k q_bias; Qk_vec_k q_bias;
...@@ -1270,12 +1420,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1270,12 +1420,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
if (handle_kv) { if (handle_kv) {
k = add(k, k_bias); k = add(k, k_bias);
} }
if (do_ia3 && !is_masked) {
k = mul<Qk_vec_k, Qk_vec_k, Qk_vec_k>(
k,
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(
&params.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE])));
}
// Padded len // Padded len
const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi]; const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
...@@ -1311,8 +1455,17 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1311,8 +1455,17 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B
+ tlength_circ * QK_ELTS_IN_16B + ci; + tlength_circ * QK_ELTS_IN_16B + ci;
if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_k>::value>::type;
Packed_Int8_t k_int8 = quant(k, k_scale);
int8_t* dst_ptr = reinterpret_cast<int8_t*>(params.k_cache);
*reinterpret_cast<Packed_Int8_t*>(&dst_ptr[offset]) = k_int8;
} else {
*reinterpret_cast<Qk_vec_m*>(&params.k_cache[offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k); *reinterpret_cast<Qk_vec_m*>(&params.k_cache[offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);
} }
}
else { else {
int offset; int offset;
if (params.k_cache_interleaved) { if (params.k_cache_interleaved) {
...@@ -1323,9 +1476,19 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1323,9 +1476,19 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
offset = params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + tlength_circ * Dh offset = params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + tlength_circ * Dh
+ co * QK_ELTS_IN_16B + ci; + co * QK_ELTS_IN_16B + ci;
} }
if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_k>::value>::type;
Packed_Int8_t k_int8 = quant(k, k_scale);
int8_t* dst_ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
*reinterpret_cast<Packed_Int8_t*>(&dst_ptr[offset]) = k_int8;
} else {
*reinterpret_cast<Qk_vec_m*>(&params.k_cache_per_sample[bi][offset]) = *reinterpret_cast<Qk_vec_m*>(&params.k_cache_per_sample[bi][offset]) =
vec_conversion<Qk_vec_m, Qk_vec_k>(k); vec_conversion<Qk_vec_m, Qk_vec_k>(k);
} }
}
} }
} }
...@@ -1402,13 +1565,28 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1402,13 +1565,28 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
// The base pointer for the key in the cache buffer. // The base pointer for the key in the cache buffer.
T* k_cache_batch = nullptr;
int8_t* k_cache_batch_int8 = nullptr;
if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
// reinterpret the per-sample k cache as an int8 buffer
if (params.k_cache_per_sample) {
int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
k_cache_batch_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki;
} else {
int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache);
k_cache_batch_int8 = &ptr[bhi * params.memory_max_len * Dh + ki];
}
} else {
T* k_cache = T* k_cache =
params.k_cache_per_sample ? params.k_cache_per_sample ?
(params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki) : (params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki) :
&params.k_cache[bhi * params.memory_max_len * Dh + ki]; &params.k_cache[bhi * params.memory_max_len * Dh + ki];
// Base pointer for the beam's batch, before offsetting with indirection buffer // Base pointer for the beam's batch, before offsetting with indirection buffer
// T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki]; // T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki];
T* k_cache_batch = k_cache; k_cache_batch = k_cache;
}
// Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
// int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
...@@ -1439,14 +1617,21 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1439,14 +1617,21 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
k[ii] = k_vec_zero; k[ii] = k_vec_zero;
} }
else { else {
int beam_offset = 0;
if (HAS_BEAMS) { if (HAS_BEAMS) {
const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh; beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh;
k[ii] = vec_conversion<K_vec_k, K_vec_m>(
(*reinterpret_cast<const K_vec_m*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
} }
else {
k[ii] = vec_conversion<K_vec_k, K_vec_m>( if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
(*reinterpret_cast<const K_vec_m*>(&k_cache_batch[jj * QK_ELTS_IN_16B]))); using Packed_Int8_t = typename packed_type<int8_t, num_elems<K_vec_m>::value>::type;
using Packed_Float_t = typename packed_type<float, num_elems<K_vec_m>::value>::type;
Packed_Int8_t k_vec_m_int8 = *reinterpret_cast<const Packed_Int8_t*>(&k_cache_batch_int8[beam_offset + jj * QK_ELTS_IN_16B]);
Packed_Float_t k_vec_m_float = dequant(k_vec_m_int8, k_scale);
k[ii] = vec_conversion<K_vec_k, Packed_Float_t>(k_vec_m_float);
} else {
k[ii] = vec_conversion<K_vec_k, K_vec_m>((*reinterpret_cast<const K_vec_m*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
} }
} }
} }
...@@ -1562,13 +1747,32 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1562,13 +1747,32 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE;
// The base pointer for the value in the cache buffer. // The base pointer for the value in the cache buffer.
T* v_cache = T* v_cache = nullptr;
T* v_cache_batch = nullptr;
int8_t* v_cache_int8 = nullptr;
int8_t* v_cache_batch_int8 = nullptr;
if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
if (params.v_cache_per_sample) {
int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache_per_sample[bi]);
v_cache_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi;
} else {
int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache);
v_cache_int8 = &ptr[bhi * params.memory_max_len * Dh + vi];
}
v_cache_batch_int8 = v_cache_int8;
} else {
v_cache =
params.v_cache_per_sample ? params.v_cache_per_sample ?
(params.v_cache_per_sample[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi) : (params.v_cache_per_sample[bi] + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi) :
&params.v_cache[bhi * params.memory_max_len * Dh + vi]; &params.v_cache[bhi * params.memory_max_len * Dh + vi];
// Base pointer for the beam's batch, before offsetting with indirection buffer // Base pointer for the beam's batch, before offsetting with indirection buffer
// T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi]; // T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
T* v_cache_batch = v_cache; v_cache_batch = v_cache;
}
// The number of values processed per iteration of the loop. // The number of values processed per iteration of the loop.
constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
...@@ -1606,7 +1810,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1606,7 +1810,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// Loop over the timesteps to compute the partial outputs. // Loop over the timesteps to compute the partial outputs.
// for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) { // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) {
if (Dh == Dh_MAX || vi < Dh) { if (Dh == Dh_MAX || vi < Dh) {
using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_m>::value>::type;
using Packed_Float_t = typename packed_type<float, num_elems<V_vec_m>::value>::type;
// Separate the ti < memory_max_len and ti > memory_max_len // Separate the ti < memory_max_len and ti > memory_max_len
// to prevent ti % memory_len when ti < memory_len, and // to prevent ti % memory_len when ti < memory_len, and
// the compiler cannot optimize the codes automatically. // the compiler cannot optimize the codes automatically.
...@@ -1616,8 +1821,17 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1616,8 +1821,17 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
const int beam_src = HAS_BEAMS ? params.cache_indir[bi_seq_len_offset + ti] : 0; const int beam_src = HAS_BEAMS ? params.cache_indir[bi_seq_len_offset + ti] : 0;
const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0;
// Load the values from the cache. // Load the values from the cache.
V_vec_k v = vec_conversion<V_vec_k, V_vec_m>( V_vec_k v;
*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti * Dh]));
if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
Packed_Int8_t v_vec_m_int8 = *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti * Dh]);
Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale);
v = vec_conversion<V_vec_k, Packed_Float_t>(v_vec_m_float);
} else {
v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti * Dh]));
}
// Load the logits from shared memory. // Load the logits from shared memory.
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
float logit = logits_smem[ti - first_step]; float logit = logits_smem[ti - first_step];
...@@ -1652,8 +1866,17 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1652,8 +1866,17 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
const int beam_src = HAS_BEAMS ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0; const int beam_src = HAS_BEAMS ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0;
const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0;
// Load the values from the cache. // Load the values from the cache.
V_vec_k v = vec_conversion<V_vec_k, V_vec_m>( V_vec_k v;
*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti_circ * Dh]));
if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
Packed_Int8_t v_vec_m_int8 = *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti_circ * Dh]);
Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale);
v = vec_conversion<V_vec_k, Packed_Float_t>(v_vec_m_float);
} else {
v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti_circ * Dh]));
}
// Load the logits from shared memory. // Load the logits from shared memory.
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
float logit = logits_smem[ti - first_step]; float logit = logits_smem[ti - first_step];
...@@ -1687,18 +1910,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1687,18 +1910,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// Trigger the loads from the V buffer. // Trigger the loads from the V buffer.
const auto v_offset = qkv_base_offset + vi; const auto v_offset = qkv_base_offset + vi;
if (params.int8_mode == 2) {
using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_k>::value>::type;
using Packed_Float_t = typename packed_type<float, num_elems<V_vec_k>::value>::type;
const auto v_scaling = params.qkv_scale_out[2];
const auto v_quant =
*reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.v)[v_offset]);
convert_from_float(v, mul<Packed_Float_t, float>(v_scaling, float_from_int8(v_quant)));
}
else {
v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&params.v[v_offset])); v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&params.v[v_offset]));
}
// Trigger the loads from the V bias buffer. // Trigger the loads from the V bias buffer.
// V_vec v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi*Dh + vi]); // V_vec v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi*Dh + vi]);
...@@ -1706,17 +1918,17 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1706,17 +1918,17 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
if (handle_kv) { if (handle_kv) {
v = add(v, v_bias); v = add(v, v_bias);
if (do_ia3) {
v = mul<V_vec_k, V_vec_k, V_vec_k>(
v,
*reinterpret_cast<const V_vec_k*>(
&params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
}
// Store the values with bias back to global memory in the cache for V. // Store the values with bias back to global memory in the cache for V.
//*reinterpret_cast<V_vec_k*>(&v_cache[params.timestep*Dh]) = v; //*reinterpret_cast<V_vec_k*>(&v_cache[params.timestep*Dh]) = v;
if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_k>::value>::type;
Packed_Int8_t v_int8 = quant(v, v_scale);
*reinterpret_cast<Packed_Int8_t*>(&v_cache_int8[tlength_circ * Dh]) = v_int8;
} else {
*reinterpret_cast<V_vec_m*>(&v_cache[tlength_circ * Dh]) = vec_conversion<V_vec_m, V_vec_k>(v); *reinterpret_cast<V_vec_m*>(&v_cache[tlength_circ * Dh]) = vec_conversion<V_vec_m, V_vec_k>(v);
} }
}
// Initialize the output value with the current timestep. // Initialize the output value with the current timestep.
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
...@@ -1782,14 +1994,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1782,14 +1994,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
convert_from_float(*reinterpret_cast<V_vec_m*>(&params.out[bhi * Dh + vi]), convert_from_float(*reinterpret_cast<V_vec_m*>(&params.out[bhi * Dh + vi]),
mul<V_vec_acum, float, V_vec_acum>(result_scale, out)); mul<V_vec_acum, float, V_vec_acum>(result_scale, out));
#endif // FP8_MHA #endif // FP8_MHA
} } else {
else if (params.int8_mode == 2) {
using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_acum>::value>::type;
out = mul<V_vec_acum, float>(*params.attention_out_scale, out);
*reinterpret_cast<Packed_Int8_t*>(&(reinterpret_cast<int8_t*>(params.out)[bhi * Dh + vi])) =
cast_to_int8(out);
}
else {
convert_from_float(*reinterpret_cast<V_vec_m*>(&params.out[bhi * Dh + vi]), out); convert_from_float(*reinterpret_cast<V_vec_m*>(&params.out[bhi * Dh + vi]), out);
} }
#else // MMHA_USE_FP32_ACUM_FOR_OUT #else // MMHA_USE_FP32_ACUM_FOR_OUT
......
...@@ -195,9 +195,12 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap* ...@@ -195,9 +195,12 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
max_seq_len, max_seq_len,
size_per_head_, size_per_head_,
local_head_num_, local_head_num_,
stream_); stream_,
sync_check_cuda_error(); quant_policy_,
weights->past_kv_scale.data());
sync_check_cuda_error();
if (use_fmha_) { if (use_fmha_) {
fusedMultiHeadAttention(k_cache_ptrs, fusedMultiHeadAttention(k_cache_ptrs,
v_cache_ptrs, v_cache_ptrs,
...@@ -220,7 +223,11 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap* ...@@ -220,7 +223,11 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
num_token, num_token,
max_q_len, max_q_len,
max_k_len, max_k_len,
max_seq_len); max_seq_len,
quant_policy_,
weights->past_kv_scale.data());
} }
////////////////////////////////////////////// //////////////////////////////////////////////
...@@ -233,6 +240,8 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap* ...@@ -233,6 +240,8 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
sync_check_cuda_error(); sync_check_cuda_error();
} }
if (is_free_buffer_after_forward_ == true) { if (is_free_buffer_after_forward_ == true) {
freeBuffer(); freeBuffer();
} }
...@@ -303,7 +312,9 @@ void LlamaContextAttentionLayer<T>::unfusedMultiHeadAttention(T** key_cac ...@@ -303,7 +312,9 @@ void LlamaContextAttentionLayer<T>::unfusedMultiHeadAttention(T** key_cac
int num_token, int num_token,
int max_q_len, int max_q_len,
int max_k_len, int max_k_len,
int max_seq_len) int max_seq_len,
int quant,
const float* kv_scale)
{ {
// key_cache [B, H, S[:t+s], D/x, x] -> [B, H, t+s, D] // key_cache [B, H, S[:t+s], D/x, x] -> [B, H, t+s, D]
// val_cache [B, H, S[:t+s], D/x, x] -> [B, H, t+s, D] // val_cache [B, H, S[:t+s], D/x, x] -> [B, H, t+s, D]
...@@ -318,7 +329,9 @@ void LlamaContextAttentionLayer<T>::unfusedMultiHeadAttention(T** key_cac ...@@ -318,7 +329,9 @@ void LlamaContextAttentionLayer<T>::unfusedMultiHeadAttention(T** key_cac
max_seq_len, max_seq_len,
size_per_head_, size_per_head_,
local_head_num_, local_head_num_,
stream_); stream_,
quant,
kv_scale);
sync_check_cuda_error(); sync_check_cuda_error();
const T qk_scale = static_cast<T>(1.f / sqrtf(size_per_head_ * 1.f)); const T qk_scale = static_cast<T>(1.f / sqrtf(size_per_head_ * 1.f));
......
...@@ -42,7 +42,8 @@ public: ...@@ -42,7 +42,8 @@ public:
cublasMMWrapper* cublas_wrapper, cublasMMWrapper* cublas_wrapper,
IAllocator* allocator, IAllocator* allocator,
bool is_free_buffer_after_forward, bool is_free_buffer_after_forward,
bool use_fmha): bool use_fmha,
int quant_policy):
head_num_(head_num), head_num_(head_num),
size_per_head_(size_per_head), size_per_head_(size_per_head),
hidden_units_(head_num * size_per_head), hidden_units_(head_num * size_per_head),
...@@ -56,7 +57,8 @@ public: ...@@ -56,7 +57,8 @@ public:
linear_(cublas_wrapper, stream), linear_(cublas_wrapper, stream),
allocator_(allocator), allocator_(allocator),
is_free_buffer_after_forward_(is_free_buffer_after_forward), is_free_buffer_after_forward_(is_free_buffer_after_forward),
use_fmha_(use_fmha) use_fmha_(use_fmha),
quant_policy_(quant_policy)
{ {
} }
...@@ -82,7 +84,9 @@ public: ...@@ -82,7 +84,9 @@ public:
int num_token, int num_token,
int max_q_len, int max_q_len,
int max_k_len, int max_k_len,
int max_seq_len); int max_seq_len,
int quant_policy,
const float* kv_scale);
private: private:
const size_t head_num_; const size_t head_num_;
...@@ -96,6 +100,7 @@ private: ...@@ -96,6 +100,7 @@ private:
const bool neox_rotary_style_; const bool neox_rotary_style_;
const bool use_fmha_; const bool use_fmha_;
const int quant_policy_;
NcclParam tensor_para_; NcclParam tensor_para_;
......
...@@ -62,7 +62,7 @@ void LlamaContextDecoder<T>::freeBuffer() ...@@ -62,7 +62,7 @@ void LlamaContextDecoder<T>::freeBuffer()
} }
template<typename T> template<typename T>
void LlamaContextDecoder<T>::initialize(bool use_fmha) void LlamaContextDecoder<T>::initialize(bool use_fmha, int quant_policy)
{ {
h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true); h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);
...@@ -75,7 +75,8 @@ void LlamaContextDecoder<T>::initialize(bool use_fmha) ...@@ -75,7 +75,8 @@ void LlamaContextDecoder<T>::initialize(bool use_fmha)
cublas_wrapper_, cublas_wrapper_,
allocator_, allocator_,
is_free_buffer_after_forward_, is_free_buffer_after_forward_,
use_fmha); use_fmha,
quant_policy);
silu_ffn_layer_ = new LlamaFfnLayer<T>(head_num_, silu_ffn_layer_ = new LlamaFfnLayer<T>(head_num_,
size_per_head_, size_per_head_,
...@@ -133,7 +134,8 @@ LlamaContextDecoder<T>::LlamaContextDecoder(size_t head_num, ...@@ -133,7 +134,8 @@ LlamaContextDecoder<T>::LlamaContextDecoder(size_t head_num,
cublasMMWrapper* cublas_wrapper, cublasMMWrapper* cublas_wrapper,
IAllocator* allocator, IAllocator* allocator,
bool is_free_buffer_after_forward, bool is_free_buffer_after_forward,
bool use_fmha): bool use_fmha,
int quant_policy):
BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward),
head_num_(head_num), head_num_(head_num),
size_per_head_(size_per_head), size_per_head_(size_per_head),
...@@ -145,7 +147,7 @@ LlamaContextDecoder<T>::LlamaContextDecoder(size_t head_num, ...@@ -145,7 +147,7 @@ LlamaContextDecoder<T>::LlamaContextDecoder(size_t head_num,
tensor_para_(tensor_para), tensor_para_(tensor_para),
data_type_(getTensorType<T>()) data_type_(getTensorType<T>())
{ {
initialize(use_fmha); initialize(use_fmha, quant_policy);
} }
template<typename T> template<typename T>
......
...@@ -42,7 +42,7 @@ protected: ...@@ -42,7 +42,7 @@ protected:
void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len); void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len);
void freeBuffer() override; void freeBuffer() override;
void initialize(bool use_fmha); void initialize(bool use_fmha, int quant_policy);
size_t head_num_; size_t head_num_;
size_t size_per_head_; size_t size_per_head_;
...@@ -97,7 +97,8 @@ public: ...@@ -97,7 +97,8 @@ public:
cublasMMWrapper* cublas_wrapper, cublasMMWrapper* cublas_wrapper,
IAllocator* allocator, IAllocator* allocator,
bool is_free_buffer_after_forward, bool is_free_buffer_after_forward,
bool use_fmha); bool use_fmha,
int quant_policy);
~LlamaContextDecoder() override; ~LlamaContextDecoder() override;
......
...@@ -37,7 +37,8 @@ LlamaDecoder<T>::LlamaDecoder(size_t head_num, ...@@ -37,7 +37,8 @@ LlamaDecoder<T>::LlamaDecoder(size_t head_num,
cudaStream_t stream, cudaStream_t stream,
cublasMMWrapper* cublas_wrapper, cublasMMWrapper* cublas_wrapper,
IAllocator* allocator, IAllocator* allocator,
bool is_free_buffer_after_forward): bool is_free_buffer_after_forward,
int quant_policy):
BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward), BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward),
head_num_(head_num), head_num_(head_num),
size_per_head_(size_per_head), size_per_head_(size_per_head),
...@@ -50,7 +51,7 @@ LlamaDecoder<T>::LlamaDecoder(size_t head_num, ...@@ -50,7 +51,7 @@ LlamaDecoder<T>::LlamaDecoder(size_t head_num,
data_type_(getTensorType<T>()) data_type_(getTensorType<T>())
{ {
FT_LOG_DEBUG(__PRETTY_FUNCTION__); FT_LOG_DEBUG(__PRETTY_FUNCTION__);
initialize(); initialize(quant_policy);
} }
template<typename T> template<typename T>
...@@ -62,7 +63,7 @@ LlamaDecoder<T>::~LlamaDecoder() ...@@ -62,7 +63,7 @@ LlamaDecoder<T>::~LlamaDecoder()
} }
template<typename T> template<typename T>
void LlamaDecoder<T>::initialize() void LlamaDecoder<T>::initialize(int quant_policy)
{ {
FT_LOG_DEBUG(__PRETTY_FUNCTION__); FT_LOG_DEBUG(__PRETTY_FUNCTION__);
...@@ -74,7 +75,8 @@ void LlamaDecoder<T>::initialize() ...@@ -74,7 +75,8 @@ void LlamaDecoder<T>::initialize()
stream_, stream_,
cublas_wrapper_, cublas_wrapper_,
allocator_, allocator_,
is_free_buffer_after_forward_); is_free_buffer_after_forward_,
quant_policy);
silu_ffn_layer_ = new LlamaFfnLayer<T>(head_num_, silu_ffn_layer_ = new LlamaFfnLayer<T>(head_num_,
size_per_head_, size_per_head_,
......
...@@ -35,7 +35,7 @@ protected: ...@@ -35,7 +35,7 @@ protected:
void allocateBuffer() override; // deprecated void allocateBuffer() override; // deprecated
void allocateBuffer(size_t batch_size); void allocateBuffer(size_t batch_size);
void freeBuffer() override; void freeBuffer() override;
void initialize(); void initialize(int quant_policy);
size_t head_num_; size_t head_num_;
size_t size_per_head_; size_t size_per_head_;
...@@ -79,7 +79,9 @@ public: ...@@ -79,7 +79,9 @@ public:
cudaStream_t stream, cudaStream_t stream,
cublasMMWrapper* cublas_wrapper, cublasMMWrapper* cublas_wrapper,
IAllocator* allocator, IAllocator* allocator,
bool is_free_buffer_after_forward); bool is_free_buffer_after_forward,
int quant_policy);
~LlamaDecoder() override; ~LlamaDecoder() override;
......
...@@ -158,6 +158,17 @@ void LlamaDecoderLayerWeight<T>::loadModel(std::string dir_path, FtCudaDataType ...@@ -158,6 +158,17 @@ void LlamaDecoderLayerWeight<T>::loadModel(std::string dir_path, FtCudaDataType
loadWeights(ffn_weights.gating, dir_path + ".feed_forward.w1", tensor_para_rank_, type); loadWeights(ffn_weights.gating, dir_path + ".feed_forward.w1", tensor_para_rank_, type);
loadWeights(ffn_weights.intermediate, dir_path + ".feed_forward.w3", tensor_para_rank_, type); loadWeights(ffn_weights.intermediate, dir_path + ".feed_forward.w3", tensor_para_rank_, type);
loadWeights(ffn_weights.output, dir_path + ".feed_forward.w2", tensor_para_rank_, type); loadWeights(ffn_weights.output, dir_path + ".feed_forward.w2", tensor_para_rank_, type);
// load kv_cache quant scale
// if the file does not exist, leave the scale vector empty
std::string scale_path = dir_path + ".past_kv_scale." + rank_spec + ".weight";
std::ifstream in(scale_path, std::ios::in);
if (in.is_open()) {
in.close();
self_attn_weights.past_kv_scale = loadArrayFromBin({2}, scale_path);
} else {
self_attn_weights.past_kv_scale = {};
}
} }
template struct LlamaDecoderLayerWeight<float>; template struct LlamaDecoderLayerWeight<float>;
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
// Modified from // Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc
#include "src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.h" #include "src/fastertransformer/models/llama/LlamaDecoderSelfAttentionLayer.h"
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h"
#include "src/fastertransformer/models/llama/LlamaNcclGuard.h" #include "src/fastertransformer/models/llama/LlamaNcclGuard.h"
...@@ -75,6 +74,7 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf, ...@@ -75,6 +74,7 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
const float* qkv_scale_out, const float* qkv_scale_out,
const float* attention_out_scale, const float* attention_out_scale,
const int int8_mode, const int int8_mode,
const float* attention_kv_scale,
cudaStream_t stream) cudaStream_t stream)
{ {
using DataType = typename SATypeConverter<T>::Type; using DataType = typename SATypeConverter<T>::Type;
...@@ -98,14 +98,9 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf, ...@@ -98,14 +98,9 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
// Set the input buffers. // Set the input buffers.
params.q = reinterpret_cast<const DataType*>(qkv_buf); params.q = reinterpret_cast<const DataType*>(qkv_buf);
if (int8_mode != 2) {
params.k = reinterpret_cast<const DataType*>(qkv_buf) + hidden_units; params.k = reinterpret_cast<const DataType*>(qkv_buf) + hidden_units;
params.v = reinterpret_cast<const DataType*>(qkv_buf) + 2 * hidden_units; params.v = reinterpret_cast<const DataType*>(qkv_buf) + 2 * hidden_units;
}
else {
params.k = reinterpret_cast<const DataType*>(reinterpret_cast<const int8_t*>(qkv_buf) + hidden_units);
params.v = reinterpret_cast<const DataType*>(reinterpret_cast<const int8_t*>(qkv_buf) + 2 * hidden_units);
}
params.stride = 3 * hidden_units; params.stride = 3 * hidden_units;
params.finished = const_cast<bool*>(finished); params.finished = const_cast<bool*>(finished);
...@@ -148,9 +143,10 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf, ...@@ -148,9 +143,10 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
params.ia3_value_weights = reinterpret_cast<const DataType*>(ia3_value_weights); params.ia3_value_weights = reinterpret_cast<const DataType*>(ia3_value_weights);
params.int8_mode = int8_mode; params.int8_mode = int8_mode;
if (int8_mode == 2) {
params.qkv_scale_out = qkv_scale_out; if (int8_mode & QuantPolicy::kCacheKVInt8) {
params.attention_out_scale = attention_out_scale; params.attention_k_scale = attention_kv_scale[0];
params.attention_v_scale = attention_kv_scale[1];
} }
PUSH_RANGE("scaled dot-product fusion"); PUSH_RANGE("scaled dot-product fusion");
...@@ -269,7 +265,8 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o ...@@ -269,7 +265,8 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o
nullptr, // ia3_value_weights nullptr, // ia3_value_weights
nullptr, // qkv_scale_out nullptr, // qkv_scale_out
nullptr, // attention_out_scale nullptr, // attention_out_scale
0, // int8_mode quant_policy_, // int8_mode
weights->past_kv_scale.data(), // attention kv scale
stream_); stream_);
sync_check_cuda_error(); sync_check_cuda_error();
......
...@@ -40,7 +40,8 @@ public: ...@@ -40,7 +40,8 @@ public:
cudaStream_t stream, cudaStream_t stream,
cublasMMWrapper* cublas_wrapper, cublasMMWrapper* cublas_wrapper,
IAllocator* allocator, IAllocator* allocator,
bool is_free_buffer_after_forward): bool is_free_buffer_after_forward,
int quant_policy):
head_num_(head_num), head_num_(head_num),
size_per_head_(size_per_head), size_per_head_(size_per_head),
hidden_units_(head_num * size_per_head), hidden_units_(head_num * size_per_head),
...@@ -52,7 +53,8 @@ public: ...@@ -52,7 +53,8 @@ public:
stream_(stream), stream_(stream),
linear_(cublas_wrapper, stream), linear_(cublas_wrapper, stream),
allocator_(allocator), allocator_(allocator),
is_free_buffer_after_forward_(is_free_buffer_after_forward) is_free_buffer_after_forward_(is_free_buffer_after_forward),
quant_policy_(quant_policy)
{ {
} }
...@@ -71,6 +73,7 @@ private: ...@@ -71,6 +73,7 @@ private:
const size_t local_hidden_units_; const size_t local_hidden_units_;
const size_t rotary_embedding_dim_; const size_t rotary_embedding_dim_;
const bool is_free_buffer_after_forward_; const bool is_free_buffer_after_forward_;
const int quant_policy_;
const bool neox_rotary_style_; const bool neox_rotary_style_;
......
...@@ -66,6 +66,7 @@ template<typename T> ...@@ -66,6 +66,7 @@ template<typename T>
struct LlamaAttentionWeight { struct LlamaAttentionWeight {
LlamaDenseWeight<T> qkv; LlamaDenseWeight<T> qkv;
LlamaDenseWeight<T> output; LlamaDenseWeight<T> output;
std::vector<float> past_kv_scale;
}; };
template<typename T> template<typename T>
......
...@@ -52,6 +52,7 @@ LlamaV2<T>::LlamaV2(size_t head_num, ...@@ -52,6 +52,7 @@ LlamaV2<T>::LlamaV2(size_t head_num,
int end_id, int end_id,
int cache_max_entry_count, int cache_max_entry_count,
int cache_chunk_size, int cache_chunk_size,
int quant_policy,
bool use_context_fmha, bool use_context_fmha,
std::shared_ptr<SharedState> shared_state, std::shared_ptr<SharedState> shared_state,
LlamaWeight<T>* weights, LlamaWeight<T>* weights,
...@@ -89,16 +90,26 @@ LlamaV2<T>::LlamaV2(size_t head_num, ...@@ -89,16 +90,26 @@ LlamaV2<T>::LlamaV2(size_t head_num,
FT_CHECK(vocab_size_ % tensor_para_.world_size_ == 0); FT_CHECK(vocab_size_ % tensor_para_.world_size_ == 0);
FT_LOG_INFO("NCCL group_id = %d", tensor_para_.group_id_); FT_LOG_INFO("NCCL group_id = %d", tensor_para_.group_id_);
size_t elem_bits = 0;
if (quant_policy & QuantPolicy::kCacheKVInt8) {
elem_bits = sizeof(int8_t) * 8;
if (use_context_fmha) {
FT_LOG_ERROR("use_context_fmha not support int8");
assert(0);
}
} else {
elem_bits = sizeof(T) * 8;
}
kv_cache_mgr_ = std::make_unique<LlamaCacheManager>(num_layer_, kv_cache_mgr_ = std::make_unique<LlamaCacheManager>(num_layer_,
local_head_num_, local_head_num_,
size_per_head_, size_per_head_,
session_len, session_len,
sizeof(T) * 8, elem_bits,
cache_max_entry_count, cache_max_entry_count,
cache_chunk_size, cache_chunk_size,
tensor_para.rank_, tensor_para.rank_,
allocator); allocator);
initialize(use_context_fmha); initialize(use_context_fmha, quant_policy);
start(); start();
} }
...@@ -113,7 +124,7 @@ LlamaV2<T>::~LlamaV2() ...@@ -113,7 +124,7 @@ LlamaV2<T>::~LlamaV2()
} }
template<typename T> template<typename T>
void LlamaV2<T>::initialize(bool use_context_fmha) void LlamaV2<T>::initialize(bool use_context_fmha, int quant_policy)
{ {
FT_LOG_DEBUG(__PRETTY_FUNCTION__); FT_LOG_DEBUG(__PRETTY_FUNCTION__);
...@@ -128,7 +139,8 @@ void LlamaV2<T>::initialize(bool use_context_fmha) ...@@ -128,7 +139,8 @@ void LlamaV2<T>::initialize(bool use_context_fmha)
cublas_wrapper_, cublas_wrapper_,
allocator_, allocator_,
is_free_buffer_after_forward_, is_free_buffer_after_forward_,
use_context_fmha); use_context_fmha,
quant_policy);
decoder_ = new LlamaDecoder<T>(head_num_, decoder_ = new LlamaDecoder<T>(head_num_,
size_per_head_, size_per_head_,
...@@ -140,7 +152,8 @@ void LlamaV2<T>::initialize(bool use_context_fmha) ...@@ -140,7 +152,8 @@ void LlamaV2<T>::initialize(bool use_context_fmha)
stream_, stream_,
cublas_wrapper_, cublas_wrapper_,
allocator_, allocator_,
is_free_buffer_after_forward_); is_free_buffer_after_forward_,
quant_policy);
dynamic_decode_layer_ = new DynamicDecodeLayer<float>(vocab_size_, dynamic_decode_layer_ = new DynamicDecodeLayer<float>(vocab_size_,
vocab_size_, // vocab_size_padded, vocab_size_, // vocab_size_padded,
......
...@@ -62,6 +62,7 @@ public: ...@@ -62,6 +62,7 @@ public:
int end_id, int end_id,
int cache_max_entry_count, int cache_max_entry_count,
int cache_chunk_size, int cache_chunk_size,
int quant_policy,
bool use_context_fmha, bool use_context_fmha,
std::shared_ptr<SharedState> shared_state, std::shared_ptr<SharedState> shared_state,
LlamaWeight<T>* weights, LlamaWeight<T>* weights,
...@@ -88,7 +89,7 @@ private: ...@@ -88,7 +89,7 @@ private:
void internalThreadEntry(int device_id); void internalThreadEntry(int device_id);
void initialize(bool use_context_fmha); void initialize(bool use_context_fmha, int quant_policy);
void embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step); void embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step);
......
...@@ -2,7 +2,9 @@ ...@@ -2,7 +2,9 @@
#include "src/fastertransformer/kernels/reduce_kernel_utils.cuh" #include "src/fastertransformer/kernels/reduce_kernel_utils.cuh"
#include "src/fastertransformer/models/llama/llama_kernels.h" #include "src/fastertransformer/models/llama/llama_kernels.h"
#include "src/fastertransformer/models/llama/llama_utils.h"
#include "src/fastertransformer/utils/cuda_type_utils.cuh" #include "src/fastertransformer/utils/cuda_type_utils.cuh"
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h"
namespace fastertransformer { namespace fastertransformer {
...@@ -283,6 +285,102 @@ __global__ void extend_value_cache(T** v_dst, ...@@ -283,6 +285,102 @@ __global__ void extend_value_cache(T** v_dst,
} }
} }
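// Note: float2div(a, b) divides each component of b by the scalar a, i.e. it returns b / a.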
inline __device__ float2 float2div(float a, float2 b)
{
float2 c;
c.x = b.x / a;
c.y = b.y / a;
return c;
}
static inline __device__ half4 char4_scale_to_half4(char4 value, const float scale) {
half4 dst;
dst.x = __float2half(value.x * scale);
dst.y = __float2half(value.y * scale);
dst.z = __float2half(value.z * scale);
dst.w = __float2half(value.w * scale);
return dst;
}
static inline __device__ uint32_t float4_to_char4(float x,
float y,
float z,
float w) {
uint32_t dst;
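// sm_72+ path: cvt.rni rounds each float to the nearest integer and cvt.pack.sat.s8
// saturates and packs the four results into one 32-bit word; the generic fallback
// below relies on a plain float->char conversion (no explicit rounding or saturation).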
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 720
uint32_t a; asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(a) : "f"(x));
uint32_t b; asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(b) : "f"(y));
uint32_t c; asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(c) : "f"(z));
uint32_t d; asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(d) : "f"(w));
asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;\n" : "=r"(dst) : "r"(d), "r"(c));
asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, %0;\n" : "+r"(dst) : "r"(b), "r"(a));
#else
char4 tmp;
tmp.x = x;
tmp.y = y;
tmp.z = z;
tmp.w = w;
dst = reinterpret_cast<const uint32_t&>(tmp);
#endif
return dst;
}
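// Same layout transform as extend_value_cache above, but divides each fp16 value by
// v_scale and packs it to int8 before writing to the cache. invokeExtendKVCache below
// launches it twice, once for the K cache and once for the V cache.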
template<typename T>
__global__ void extend_value_cache_int8(int8_t** v_dst,
const size_t dst_offset,
const T* v_src,
const int head_num,
const int size_per_head,
const int* query_length,
const int* history_length,
const int max_q_len,
const int max_seq_len,
const float v_scale)
{
const int batch_id = blockIdx.y;
const int head_id = blockIdx.z;
constexpr int X_ELEMS = (sizeof(T) == 4) ? 4 : 8;
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
int size_per_head_div_x = size_per_head / X_ELEMS;
// x dim is now handled by uint4 type
const auto val_src = reinterpret_cast<const uint4*>(v_src);
const auto val_dst = reinterpret_cast<uint2*>(v_dst[batch_id] + dst_offset);
const auto seq_len = query_length[batch_id];
const auto t_offset = history_length[batch_id];
const int v_head_size_id = idx % size_per_head_div_x;
const int v_seq_len_id = idx / size_per_head_div_x;
if (v_seq_len_id < seq_len) {
// [B, H, s, D/x] -> [H, S[t:t+s], D/x]
const int64_t dst_idx = head_id * size_per_head_div_x * max_seq_len + // H
(v_seq_len_id + t_offset) * size_per_head_div_x + // s + offset
v_head_size_id; // D/x
const int64_t src_idx = batch_id * head_num * size_per_head_div_x * max_q_len + // B
head_id * size_per_head_div_x * max_q_len + // H
v_seq_len_id * size_per_head_div_x + // s
v_head_size_id; // D/x
// scale to int8 and write
const auto value = val_src[src_idx];
auto to_ptr = reinterpret_cast<uint32_t*>(val_dst + dst_idx);
float2 float2_0 = float2div(v_scale, mmha::half2_to_float2(value.x));
float2 float2_1 = float2div(v_scale, mmha::half2_to_float2(value.y));
to_ptr[0] = float4_to_char4(float2_0.x, float2_0.y, float2_1.x, float2_1.y);
float2_0 = float2div(v_scale, mmha::half2_to_float2(value.z));
float2_1 = float2div(v_scale, mmha::half2_to_float2(value.w));
to_ptr[1] = float4_to_char4(float2_0.x, float2_0.y, float2_1.x, float2_1.y);
}
}
template<typename T>
void invokeExtendKVCache(T** k_dst,
                         T** v_dst,
@@ -296,18 +394,29 @@ void invokeExtendKVCache(T** k_dst,
                         int          max_seq_len,
                         int          size_per_head,
                         int          local_head_num,
                         cudaStream_t stream,
                         int          quant,
                         const float* kv_scale)
{
    constexpr int block_sz = 128;
    constexpr int x        = (sizeof(T) == 4) ? 4 : 8;

    dim3 grid((max_q_len * size_per_head / x + block_sz - 1) / block_sz, local_batch_size, local_head_num);

    if (quant & QuantPolicy::kCacheKVInt8) {
        // int8 KV cache: kv_scale[0] scales keys, kv_scale[1] scales values
        extend_value_cache_int8<<<grid, block_sz, 0, stream>>>(
            reinterpret_cast<int8_t**>(k_dst), dst_offset, k_src, local_head_num, size_per_head, query_length, history_length, max_q_len, max_seq_len, kv_scale[0]);

        extend_value_cache_int8<<<grid, block_sz, 0, stream>>>(
            reinterpret_cast<int8_t**>(v_dst), dst_offset, v_src, local_head_num, size_per_head, query_length, history_length, max_q_len, max_seq_len, kv_scale[1]);
    } else {
        extend_value_cache<<<grid, block_sz, 0, stream>>>(
            k_dst, dst_offset, k_src, local_head_num, size_per_head, query_length, history_length, max_q_len, max_seq_len);

        extend_value_cache<<<grid, block_sz, 0, stream>>>(
            v_dst, dst_offset, v_src, local_head_num, size_per_head, query_length, history_length, max_q_len, max_seq_len);
    }
}
template void invokeExtendKVCache(float**,
@@ -322,7 +431,9 @@ template void invokeExtendKVCache(float**,
                                  int,
                                  int,
                                  int,
                                  cudaStream_t stream,
                                  int,
                                  const float*);

template void invokeExtendKVCache(half**,
                                  half**,
@@ -336,11 +447,13 @@ template void invokeExtendKVCache(half**,
                                  int,
                                  int,
                                  int,
                                  cudaStream_t stream,
                                  int,
                                  const float*);
// Gather cached K or V ([H, S, D/x] per sequence) back into the batched
// [B, H, s, D/x] layout. This kernel now serves both K and V; the former
// transpose_key_cache kernel, which assumed a [H, D/x, S] key layout, is
// dropped in this change.
template<typename T>
__global__ void transpose_value_cache(T* v_dst,  //
                                      const T**    v_src,
                                      const size_t src_offset,
                                      const int    head_num,
                                      const int    size_per_head,
@@ -356,39 +469,40 @@ __global__ void transpose_key_cache(T* k_dst,
    int size_per_head_div_x = size_per_head / X_ELEMS;

    // x dim is now handled by uint4 type
    const auto val_src = reinterpret_cast<const uint4*>(v_src[batch_id] + src_offset);
    const auto val_dst = reinterpret_cast<uint4*>(v_dst);

    const auto seq_len = seq_length[batch_id];

    const int v_head_size_id = idx % size_per_head_div_x;
    const int v_seq_len_id   = idx / size_per_head_div_x;

    if (v_seq_len_id < seq_len) {
        // [B, H, s, D/x] <- [B, H, S[:s], D/x]
        const int64_t src_idx = head_id * size_per_head_div_x * max_seq_len +  // H
                                v_seq_len_id * size_per_head_div_x +           // s
                                v_head_size_id;                                // D/x

        const int64_t dst_idx = batch_id * head_num * size_per_head_div_x * max_kv_len +  // B
                                head_id * size_per_head_div_x * max_kv_len +              // H
                                v_seq_len_id * size_per_head_div_x +                      // s
                                v_head_size_id;                                           // D/x

        val_dst[dst_idx] = val_src[src_idx];
    }
}
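Both the extend and the transpose kernels shuffle data between the same two layouts. Written out as plain host helpers for clarity (names and signatures are for exposition only; indices are in units of one x-element vector, i.e. uint4 for fp16 and uint2 for int8):

```cpp
#include <cstdint>

// Per-sequence cache layout: [head, max_seq_len, size_per_head / x]
inline int64_t cache_index(int head, int t, int d, int max_seq_len, int d_div_x)
{
    return (int64_t)head * max_seq_len * d_div_x + (int64_t)t * d_div_x + d;
}

// Batched activation layout: [batch, head, max_len, size_per_head / x],
// where max_len is max_q_len on the extend path and max_kv_len on the transpose path.
inline int64_t batched_index(int b, int head, int s, int d, int head_num, int max_len, int d_div_x)
{
    return ((int64_t)b * head_num + head) * max_len * d_div_x + (int64_t)s * d_div_x + d;
}
```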
// Dequantize int8 cache entries back to fp16 while gathering them into the
// batched [B, H, s, D/x] layout; the int8 path of invokeTransposeKVCache
// below launches this kernel for both K and V.
template<typename T>
__global__ void transpose_value_cache_int8(T* v_dst,  //
                                           const int8_t** v_src,
                                           const size_t   src_offset,
                                           const int      head_num,
                                           const int      size_per_head,
                                           const int*     seq_length,
                                           const int      max_kv_len,
                                           const int      max_seq_len,
                                           const float    v_scale)
{
    const int batch_id = blockIdx.y;
    const int head_id  = blockIdx.z;
@@ -398,7 +512,7 @@ __global__ void transpose_value_cache(T* v_dst, //
    int size_per_head_div_x = size_per_head / X_ELEMS;

    // x dim is now handled by uint4 type
    const auto val_src = reinterpret_cast<const uint2*>(v_src[batch_id] + src_offset);
    const auto val_dst = reinterpret_cast<uint4*>(v_dst);

    const auto seq_len = seq_length[batch_id];
@@ -417,7 +531,12 @@ __global__ void transpose_value_cache(T* v_dst, //
                                v_seq_len_id * size_per_head_div_x +  // s
                                v_head_size_id;                       // D/x

        // int8x8 -> fp16x8
        const auto from_ptr = reinterpret_cast<const char4*>(val_src + src_idx);
        auto       to_ptr   = reinterpret_cast<half4*>(val_dst + dst_idx);

        to_ptr[0] = char4_scale_to_half4(from_ptr[0], v_scale);
        to_ptr[1] = char4_scale_to_half4(from_ptr[1], v_scale);
    }
}
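The load path is the mirror image of the store path: each stored char4 is widened back to fp16 by multiplying with the same scale. A scalar sketch under the same assumptions as the quantization helper above:

```cpp
#include <cstdint>

// Scalar view of char4_scale_to_half4: stored int8 times the calibration scale.
inline float dequantize_int8(int8_t q, float scale)
{
    return static_cast<float>(q) * scale;
}

// Round trip with the earlier quantize_to_int8 sketch: values inside the
// representable range come back within scale / 2 of the original, while
// values beyond 127 * scale in magnitude are clipped by the saturation.
```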
@@ -433,24 +552,35 @@ void invokeTransposeKVCache(T* key_cache_trans,
                            int          max_seq_len,
                            int          size_per_head,
                            int          head_num,
                            cudaStream_t stream,
                            int          quant,
                            const float* kv_scale)
{
    constexpr int block_sz = 128;
    constexpr int x        = (sizeof(T) == 4) ? 4 : 8;

    dim3 grid((max_kv_len * size_per_head / x + block_sz - 1) / block_sz, batch_size, head_num);

    if (quant & QuantPolicy::kCacheKVInt8) {
        transpose_value_cache_int8<<<grid, block_sz, 0, stream>>>(
            key_cache_trans, reinterpret_cast<const int8_t**>(key_cache), src_offset, head_num, size_per_head, key_length, max_kv_len, max_seq_len, kv_scale[0]);

        transpose_value_cache_int8<<<grid, block_sz, 0, stream>>>(
            val_cache_trans, reinterpret_cast<const int8_t**>(val_cache), src_offset, head_num, size_per_head, key_length, max_kv_len, max_seq_len, kv_scale[1]);
    } else {
        transpose_value_cache<<<grid, block_sz, 0, stream>>>(
            key_cache_trans, key_cache, src_offset, head_num, size_per_head, key_length, max_kv_len, max_seq_len);

        transpose_value_cache<<<grid, block_sz, 0, stream>>>(
            val_cache_trans, val_cache, src_offset, head_num, size_per_head, key_length, max_kv_len, max_seq_len);
    }
}

template void invokeTransposeKVCache(
    float*, float*, const float**, const float**, size_t, int, const int*, int, int, int, int, cudaStream_t stream, int, const float*);
template void invokeTransposeKVCache(
    half*, half*, const half**, const half**, size_t, int, const int*, int, int, int, int, cudaStream_t stream, int, const float*);
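Both dispatchers treat `quant` as a bit mask. The actual `QuantPolicy` definition is not part of this diff; a minimal sketch of the pattern it implies (the enumerator value is an assumption):

```cpp
// Hypothetical sketch only; the real QuantPolicy lives elsewhere in the codebase.
struct QuantPolicy {
    enum : int {
        kNone        = 0,
        kCacheKVInt8 = 1 << 2,  // assumed bit enabling the int8 KV-cache kernels
    };
};

// Dispatch shape used by invokeExtendKVCache / invokeTransposeKVCache:
//   if (quant & QuantPolicy::kCacheKVInt8) { /* int8 kernels */ }
//   else                                   { /* fp16 kernels   */ }
```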
__global__ void gatherOutput(int* output_ids,
                             const int* ids,
@@ -46,7 +46,9 @@ void invokeExtendKVCache(T** k_dst,
                         int          max_seq_len,
                         int          size_per_head,
                         int          local_head_num,
                         cudaStream_t stream,
                         int          quant,
                         const float* kv_scale);

template<typename T>
void invokeTransposeKVCache(T* key_cache_trans,
@@ -60,7 +62,9 @@ void invokeTransposeKVCache(T* key_cache_trans,
                            int          max_seq_len,
                            int          size_per_head,
                            int          head_num,
                            cudaStream_t stream,
                            int          quant_policy,
                            const float* kv_scale);
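Both declarations gain the same trailing pair: an `int` quantization policy and a `const float*` carrying two scales, index 0 for the key cache and index 1 for the value cache, as the dispatchers above use them. A sketch of preparing that argument (the numbers are invented; real scales come from calibration):

```cpp
#include <array>

// Placeholder values for illustration only.
std::array<float, 2> kv_scale = {0.12f, 0.09f};  // {key scale, value scale}

// Note: kv_scale[0] / kv_scale[1] are dereferenced on the host inside the
// dispatchers when the int8 path is taken, so this must be host-accessible
// memory rather than a device-only pointer.
```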
void invokeGatherOutput(int* output_ids,
                        const int* ids,
@@ -151,25 +155,6 @@ struct TempBuffer {
    T* data;
};

// Removed in this change: the host-side transpose helpers below
// (transpose_key_cache / transpose_value_cache), which staged the cache
// through a static temporary buffer.
template<typename T>
inline T*
transpose_key_cache(T* key_cache, size_t head_num, size_t size_per_head_by_x, size_t mem_len, size_t x, cudaStream_t st)
{
static TempBuffer<T> buf(8192 * 8192);
// from: H Dx, S, x
// to : S, H Dx, x
invokeTransposeAxis01(buf.data, key_cache, head_num * size_per_head_by_x, mem_len, x, st);
return buf.data;
}
template<typename T>
inline T* transpose_value_cache(T* value_cache, size_t head_num, size_t mem_len, size_t size_per_head, cudaStream_t st)
{
static TempBuffer<T> buf(8192 * 8192);
invokeTransposeAxis01(buf.data, value_cache, head_num, mem_len, size_per_head, st);
return buf.data;
}
inline void dump_sequence_len(int* d_seq_len, int step, int tp_rank, cudaStream_t st)
{
    int h_seq_len = -1;