Commit 45273722 authored by xiabo's avatar xiabo
Browse files

add kvint8

parent 17e4dd25
...@@ -33,6 +33,8 @@ typedef __hip_bfloat16 __nv_bfloat16; ...@@ -33,6 +33,8 @@ typedef __hip_bfloat16 __nv_bfloat16;
#include "../quantization/fp8/nvidia/quant_utils.cuh" #include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif #endif
#include "../quantization/int8_kvcache/quant_utils.cuh"
#ifndef USE_ROCM #ifndef USE_ROCM
#define WARP_SIZE 32 #define WARP_SIZE 32
#else #else
...@@ -280,6 +282,12 @@ __device__ void paged_attention_kernel( ...@@ -280,6 +282,12 @@ __device__ void paged_attention_kernel(
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
k_vecs[j] = *reinterpret_cast<const K_vec*>( k_vecs[j] = *reinterpret_cast<const K_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_ptr + offset1 * BLOCK_SIZE * x + offset2);
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kInt8) {
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = int8::scaled_vec_conversion_int8<K_vec, Quant_vec>(
k_vec_quant,
*k_scale);
} else { } else {
// Vector conversion from Quant_vec to K_vec. // Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>( Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
...@@ -410,6 +418,12 @@ __device__ void paged_attention_kernel( ...@@ -410,6 +418,12 @@ __device__ void paged_attention_kernel(
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset); v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kInt8) {
V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = int8::scaled_vec_conversion_int8<V_vec, V_quant_vec>(v_quant_vec,
*v_scale);
} else { } else {
V_quant_vec v_quant_vec = V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset); *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
......
...@@ -14,6 +14,8 @@ typedef __hip_bfloat16 __nv_bfloat16; ...@@ -14,6 +14,8 @@ typedef __hip_bfloat16 __nv_bfloat16;
#include "../quantization/fp8/nvidia/quant_utils.cuh" #include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif #endif
#include "../quantization/int8_kvcache/quant_utils.cuh"
#ifndef USE_ROCM #ifndef USE_ROCM
#define WARP_SIZE 32 #define WARP_SIZE 32
#else #else
...@@ -311,6 +313,12 @@ __device__ void paged_attention_kernel_opt( ...@@ -311,6 +313,12 @@ __device__ void paged_attention_kernel_opt(
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
k_vecs[j] = *reinterpret_cast<const K_vec*>( k_vecs[j] = *reinterpret_cast<const K_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_ptr + offset1 * BLOCK_SIZE * x + offset2);
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kInt8) {
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = int8::scaled_vec_conversion_int8<K_vec, Quant_vec>(
k_vec_quant,
*k_scale_ptr);
} else { } else {
// Vector conversion from Quant_vec to K_vec. // Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>( Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
...@@ -478,6 +486,13 @@ __device__ void paged_attention_kernel_opt( ...@@ -478,6 +486,13 @@ __device__ void paged_attention_kernel_opt(
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset); v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kInt8) {
// printf("======xiabo_kvint8\n");
V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = int8::scaled_vec_conversion_int8<V_vec, V_quant_vec>(v_quant_vec,
*v_scale_ptr);
} else { } else {
V_quant_vec v_quant_vec = V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset); *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
......
...@@ -15,6 +15,7 @@ enum class Fp8KVCacheDataType { ...@@ -15,6 +15,7 @@ enum class Fp8KVCacheDataType {
kAuto = 0, kAuto = 0,
kFp8E4M3 = 1, kFp8E4M3 = 1,
kFp8E5M2 = 2, kFp8E5M2 = 2,
kInt8 = 3,
}; };
// fp8 vector types for quantization of kv cache // fp8 vector types for quantization of kv cache
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
#include "quantization/fp8/nvidia/quant_utils.cuh" #include "quantization/fp8/nvidia/quant_utils.cuh"
#endif #endif
#include "quantization/int8_kvcache/quant_utils.cuh"
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <map> #include <map>
...@@ -252,6 +254,13 @@ __global__ void reshape_and_cache_kernel( ...@@ -252,6 +254,13 @@ __global__ void reshape_and_cache_kernel(
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
key_cache[tgt_key_idx] = tgt_key; key_cache[tgt_key_idx] = tgt_key;
value_cache[tgt_value_idx] = tgt_value; value_cache[tgt_value_idx] = tgt_value;
} else if constexpr (kv_dt == Fp8KVCacheDataType::kInt8) {
key_cache[tgt_key_idx] =
int8::scaled_vec_conversion_int8<cache_t, scalar_t>(tgt_key,
*k_scale);
value_cache[tgt_value_idx] =
int8::scaled_vec_conversion_int8<cache_t, scalar_t>(tgt_value,
*v_scale);
} else { } else {
key_cache[tgt_key_idx] = key_cache[tgt_key_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale); fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
...@@ -296,6 +305,13 @@ __global__ void reshape_and_cache_flash_kernel( ...@@ -296,6 +305,13 @@ __global__ void reshape_and_cache_flash_kernel(
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
key_cache[tgt_key_value_idx] = tgt_key; key_cache[tgt_key_value_idx] = tgt_key;
value_cache[tgt_key_value_idx] = tgt_value; value_cache[tgt_key_value_idx] = tgt_value;
} else if constexpr (kv_dt == Fp8KVCacheDataType::kInt8) {
key_cache[tgt_key_value_idx] =
int8::scaled_vec_conversion_int8<cache_t, scalar_t>(tgt_key,
*k_scale);
value_cache[tgt_key_value_idx] =
int8::scaled_vec_conversion_int8<cache_t, scalar_t>(tgt_value,
*v_scale);
} else { } else {
key_cache[tgt_key_value_idx] = key_cache[tgt_key_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale); fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
......
...@@ -653,6 +653,16 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { ...@@ -653,6 +653,16 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
} else { \ } else { \
TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \ } \
} else if (KV_DTYPE == "int8") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kInt8); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kInt8); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kInt8); \
} else { \
TORCH_CHECK(false,"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \ } else { \
if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
if (SRC_DTYPE == at::ScalarType::Float) { \ if (SRC_DTYPE == at::ScalarType::Float) { \
......
// Adapted from FasterTransformer, https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
#pragma once
#include <assert.h>
#include <stdint.h>
#include <float.h>
#include <type_traits>
#include "../../attention/attention_dtypes.h"
#include <stdio.h>
namespace vllm {
namespace int8 {
// KV-CACHE int8
// Dequantizes a single cache byte.
//
// The byte is stored with a +128 bias, so the bias is removed first to
// recover the signed value in [-128, 127]; that value is then multiplied by
// `scale`.
static inline __device__ float int8_to_float(uint8_t x, const float scale) {
  const int8_t centered = x - 128;
  return centered * scale;
}
// Quantizes a float into a single cache byte.
//
// `x / scale` is clamped to [-128, 127], rounded with roundf(), and stored
// with a +128 bias so that the result fits an unsigned byte.
static inline __device__ uint8_t float_to_int8(float x, const float scale) {
  const float scaled = x / scale;
  const float clamped = max(-128.f, min(127.f, scaled));
  const int8_t quantized = roundf(clamped);
  return quantized + 128;
}
// Primary template for int8 KV-cache conversions.
//
// Used only when no specialization matches <Tout, Tin>: the input is passed
// through an implicit conversion and `scale` is ignored. Every supported
// quantize/dequantize pairing must therefore have an explicit specialization,
// otherwise the caller silently receives unscaled data.
template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion_int8(const Tin& x,
                                                      const float scale) {
  return x;
}

// int8 -> half
//
// Dequantizes one biased cache byte and returns the fp16 bit pattern as a
// uint16_t. Without this specialization a <uint16_t, uint8_t> request would
// fall through to the primary template above and return the raw byte value
// instead of fp16 bits.
template <>
__inline__ __device__ uint16_t scaled_vec_conversion_int8<uint16_t, uint8_t>(
    const uint8_t& a, const float scale) {
  // Remove the +128 storage bias, then apply the scale in fp32.
  const float dequantized = (static_cast<int>(a) - 128) * scale;
  // Round to fp16 and hand back the raw 16-bit representation.
  return __half_as_ushort(__float2half_rn(dequantized));
}
// int8x2 -> half2
//
// `a` packs two biased cache bytes; the returned uint32_t holds the bit
// pattern of the resulting half2. The low byte of `a` is element 0 (same
// element order as the byte-wise view on a little-endian device).
template <>
__inline__ __device__ uint32_t scaled_vec_conversion_int8<uint32_t, uint16_t>(
    const uint16_t& a, const float scale) {
  // Remove the +128 bias from each byte and scale in fp32.
  float2 vals;
  vals.x = (static_cast<uint8_t>(a) - 128) * scale;
  vals.y = (static_cast<uint8_t>(a >> 8U) - 128) * scale;

  // Round both lanes to fp16 and reinterpret the pair as 32 raw bits.
  union {
    half2 as_half2;
    uint32_t as_bits;
  };
  as_half2 = __float22half2_rn(vals);
  return as_bits;
}
// Primary template for unscaled vector conversions: relies on the implicit
// conversion from Tin to Tout. Specializations below handle the packed
// float -> half cases.
template <typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x) {
  Tout converted = x;
  return converted;
}
// float2 -> half2, returned as the raw 32-bit pattern of the half2.
template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(
    const float2& a) {
  // Overlay the half2 result with a uint32_t to read back its bits.
  union {
    half2 as_half2;
    uint32_t as_bits;
  };
  as_half2 = __float22half2_rn(a);
  return as_bits;
}
// Float4_ -> two packed half2 values (uint2).
//
// a.x becomes the half2 stored in the result's .x, a.y the half2 in .y.
template <>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
  uint2 packed;
  packed.x = vec_conversion<uint32_t, float2>(make_float2(a.x.x, a.x.y));
  packed.y = vec_conversion<uint32_t, float2>(make_float2(a.y.x, a.y.y));
  return packed;
}
// int8x4 -> half2x2
//
// `a` packs four biased cache bytes (lowest byte is element 0 on a
// little-endian device). Each byte has the +128 bias removed and is scaled in
// fp32; the four results are then rounded to fp16 and packed into a uint2.
template <>
__inline__ __device__ uint2 scaled_vec_conversion_int8<uint2, uint32_t>(
    const uint32_t& a, const float scale) {
  Float4_ vals;
  vals.x.x = (static_cast<uint8_t>(a) - 128) * scale;
  vals.x.y = (static_cast<uint8_t>(a >> 8U) - 128) * scale;
  vals.y.x = (static_cast<uint8_t>(a >> 16U) - 128) * scale;
  vals.y.y = (static_cast<uint8_t>(a >> 24U) - 128) * scale;
  return vec_conversion<uint2, Float4_>(vals);
}
// Dequantizes two biased cache bytes packed in `a` into a float2.
//
// The low byte of `a` produces .x and the high byte produces .y (matching the
// byte-wise view on a little-endian device); each has the +128 bias removed
// before being multiplied by `scale`.
inline __device__ float2 dequant(uint16_t a, const float scale) {
  const int low = static_cast<int>(a & 0xFFu) - 128;
  const int high = static_cast<int>(a >> 8U) - 128;
  return make_float2(low * scale, high * scale);
}
// int8x8 -> half2x4
//
// `a` packs eight biased cache bytes as two 32-bit words. Each 16-bit slice
// is dequantized to a float2 and rounded to a half2; the four half2 values
// are returned in order in the uint4 (a.x low half first, a.y high half
// last, assuming the little-endian layout of the device).
template <>
__inline__ __device__ uint4
scaled_vec_conversion_int8<uint4, uint2>(const uint2& a, const float scale) {
  uint4 packed;
  packed.x = vec_conversion<uint32_t, float2>(
      dequant(static_cast<uint16_t>(a.x), scale));
  packed.y = vec_conversion<uint32_t, float2>(
      dequant(static_cast<uint16_t>(a.x >> 16U), scale));
  packed.z = vec_conversion<uint32_t, float2>(
      dequant(static_cast<uint16_t>(a.y), scale));
  packed.w = vec_conversion<uint32_t, float2>(
      dequant(static_cast<uint16_t>(a.y >> 16U), scale));
  return packed;
}
// int8 -> __nv_bfloat16
//
// There is no direct int8 -> bf16 conversion, so the byte is dequantized to
// fp32 first and then narrowed to bf16.
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion_int8<__nv_bfloat16, uint8_t>(const uint8_t& a,
                                                   const float scale) {
  return __float2bfloat16(int8_to_float(a, scale));
}
// int8x2 -> __nv_bfloat162
//
// The low byte of `a` becomes .x and the high byte becomes .y; each is
// converted with the scalar int8 -> bf16 specialization.
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion_int8<__nv_bfloat162, uint16_t>(const uint16_t& a,
                                                     const float scale) {
  const uint8_t low_byte = (uint8_t)a;
  const uint8_t high_byte = (uint8_t)(a >> 8U);

  __nv_bfloat162 pair;
  pair.x = scaled_vec_conversion_int8<__nv_bfloat16, uint8_t>(low_byte, scale);
  pair.y = scaled_vec_conversion_int8<__nv_bfloat16, uint8_t>(high_byte, scale);
  return pair;
}
// int8x4 -> bf16_4_t
//
// The low 16 bits of `a` produce .x and the high 16 bits produce .y; each
// half is converted with the int8x2 -> __nv_bfloat162 specialization.
template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion_int8<bf16_4_t, uint32_t>(
    const uint32_t& a, const float scale) {
  const uint16_t low_half = (uint16_t)a;
  const uint16_t high_half = (uint16_t)(a >> 16U);

  bf16_4_t quad;
  quad.x =
      scaled_vec_conversion_int8<__nv_bfloat162, uint16_t>(low_half, scale);
  quad.y =
      scaled_vec_conversion_int8<__nv_bfloat162, uint16_t>(high_half, scale);
  return quad;
}
// int8x8 -> bf16_8_t
//
// `a` packs eight biased cache bytes as two 32-bit words. a.x is converted
// into the first two bf16 pairs (.x, .y) and a.y into the last two (.z, .w),
// each through the int8x4 -> bf16_4_t specialization.
//
// The conversion must actually be performed here: returning a
// default-constructed bf16_8_t would hand uninitialized data to the caller.
template <>
__inline__ __device__ bf16_8_t
scaled_vec_conversion_int8<bf16_8_t, uint2>(const uint2& a, const float scale) {
  const bf16_4_t first_four =
      scaled_vec_conversion_int8<bf16_4_t, uint32_t>(a.x, scale);
  const bf16_4_t last_four =
      scaled_vec_conversion_int8<bf16_4_t, uint32_t>(a.y, scale);

  bf16_8_t res;
  res.x = first_four.x;
  res.y = first_four.y;
  res.z = last_four.x;
  res.w = last_four.y;
  return res;
}
// int8 -> float
//
// Thin wrapper over int8_to_float(): removes the storage bias and applies
// `scale`.
template <>
__inline__ __device__ float scaled_vec_conversion_int8<float, uint8_t>(
    const uint8_t& a, const float scale) {
  return int8_to_float(a, scale);
}
// int8x2 -> float2
//
// The low byte of `a` becomes .x and the high byte becomes .y. Both lanes are
// dequantized directly in fp32 with int8_to_float(), so the result matches
// the scalar <float, uint8_t> conversion exactly. Going through an
// intermediate half2 would round each value to fp16 and overflow to inf for
// magnitudes outside the fp16 range.
template <>
__inline__ __device__ float2 scaled_vec_conversion_int8<float2, uint16_t>(
    const uint16_t& a, const float scale) {
  float2 res;
  res.x = int8_to_float(static_cast<uint8_t>(a), scale);
  res.y = int8_to_float(static_cast<uint8_t>(a >> 8U), scale);
  return res;
}
// int8x4 -> Float4_
//
// The low 16 bits of `a` produce .x and the high 16 bits produce .y; each
// half is converted with the int8x2 -> float2 specialization.
template <>
__inline__ __device__ Float4_ scaled_vec_conversion_int8<Float4_, uint32_t>(
    const uint32_t& a, const float scale) {
  const uint16_t low_half = (uint16_t)a;
  const uint16_t high_half = (uint16_t)(a >> 16U);

  Float4_ quad;
  quad.x = scaled_vec_conversion_int8<float2, uint16_t>(low_half, scale);
  quad.y = scaled_vec_conversion_int8<float2, uint16_t>(high_half, scale);
  return quad;
}
// int8x8 -> Float8_
//
// `a` packs eight biased cache bytes in a 64-bit word. The low 32 bits are
// converted into the first two float2 lanes (.x, .y) and the high 32 bits
// into the last two (.z, .w), each through the int8x4 -> Float4_
// specialization. This mirrors the a.x / a.y split used by the uint2-based
// int8x8 conversions in this file.
//
// The conversion must actually be performed here: returning a
// default-constructed Float8_ would hand uninitialized data to the caller.
template <>
__inline__ __device__ Float8_
scaled_vec_conversion_int8<Float8_, uint64_t>(const uint64_t& a, const float scale) {
  const Float4_ first_four = scaled_vec_conversion_int8<Float4_, uint32_t>(
      static_cast<uint32_t>(a), scale);
  const Float4_ last_four = scaled_vec_conversion_int8<Float4_, uint32_t>(
      static_cast<uint32_t>(a >> 32U), scale);

  Float8_ res;
  res.x = first_four.x;
  res.y = first_four.y;
  res.z = last_four.x;
  res.w = last_four.y;
  return res;
}
// half -> int8
//
// `a` carries fp16 bits; it is widened to fp32 with half_to_float() and then
// quantized to a biased cache byte by float_to_int8().
template <>
__inline__ __device__ uint8_t scaled_vec_conversion_int8<uint8_t, uint16_t>(
    const uint16_t& a, const float scale) {
  const float widened = half_to_float(a);
  return float_to_int8(widened, scale);
}
// bf16 -> int8
//
// The bf16 input is widened to fp32 and then quantized to a biased cache byte
// by float_to_int8().
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion_int8<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a,
                                                   const float scale) {
  const float widened = __bfloat162float(a);
  return float_to_int8(widened, scale);
}
// float -> int8
//
// Thin wrapper over float_to_int8(): divides by `scale`, clamps, rounds and
// applies the storage bias.
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion_int8<uint8_t, float>(const float& a, const float scale) {
  return float_to_int8(a, scale);
}
// int8x4 -> float4
//
// Converts through the Float4_ specialization and flattens its two float2
// lanes into a built-in float4 in element order.
template <>
__inline__ __device__ float4 scaled_vec_conversion_int8<float4, uint32_t>(
    const uint32_t& a, const float scale) {
  const Float4_ pairs = scaled_vec_conversion_int8<Float4_, uint32_t>(a, scale);
  return make_float4(pairs.x.x, pairs.x.y, pairs.y.x, pairs.y.y);
}
} // namespace int8
} // namespace vllm
...@@ -131,7 +131,8 @@ class PagedAttention: ...@@ -131,7 +131,8 @@ class PagedAttention:
# TODO(woosuk): Tune this heuristic. # TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage. # For context len > 8192, use V2 kernel to avoid shared memory shortage.
if (kv_cache_dtype == "int8"):
use_tc = False
if use_tc and head_size==128: if use_tc and head_size==128:
if envs.VLLM_USE_PA_PRINT_PARAM: if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA V1 SIZE:") print("PA V1 SIZE:")
......
...@@ -1288,7 +1288,7 @@ class ModelConfig: ...@@ -1288,7 +1288,7 @@ class ModelConfig:
BlockSize = Literal[1, 8, 16, 32, 64, 128] BlockSize = Literal[1, 8, 16, 32, 64, 128]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"] CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "int8"]
PrefixCachingHashAlgo = Literal["builtin", "sha256"] PrefixCachingHashAlgo = Literal["builtin", "sha256"]
......
...@@ -1383,6 +1383,11 @@ class EngineArgs: ...@@ -1383,6 +1383,11 @@ class EngineArgs:
from vllm.attention.utils.fa_utils import ( from vllm.attention.utils.fa_utils import (
flash_attn_supports_fp8) flash_attn_supports_fp8)
supported = flash_attn_supports_fp8() supported = flash_attn_supports_fp8()
int8_attention = self.kv_cache_dtype.startswith("int8")
if int8_attention:
supported = True
if not supported: if not supported:
_raise_or_fallback(feature_name="--kv-cache-dtype", _raise_or_fallback(feature_name="--kv-cache-dtype",
recommend_to_remove=False) recommend_to_remove=False)
......
...@@ -747,6 +747,8 @@ def get_kv_cache_torch_dtype( ...@@ -747,6 +747,8 @@ def get_kv_cache_torch_dtype(
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
elif cache_dtype == "fp8": elif cache_dtype == "fp8":
torch_dtype = torch.uint8 torch_dtype = torch.uint8
elif cache_dtype == "int8":
torch_dtype = torch.uint8
else: else:
raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
elif isinstance(cache_dtype, torch.dtype): elif isinstance(cache_dtype, torch.dtype):
...@@ -792,6 +794,8 @@ def create_kv_caches_with_random_flash( ...@@ -792,6 +794,8 @@ def create_kv_caches_with_random_flash(
key_value_cache.uniform_(-scale, scale) key_value_cache.uniform_(-scale, scale)
elif cache_dtype == 'fp8': elif cache_dtype == 'fp8':
_generate_random_fp8(key_value_cache, -scale, scale) _generate_random_fp8(key_value_cache, -scale, scale)
elif cache_dtype == 'int8':
_generate_random_int8(value_cache)
else: else:
raise ValueError( raise ValueError(
f"Does not support key cache of type {cache_dtype}") f"Does not support key cache of type {cache_dtype}")
...@@ -833,6 +837,8 @@ def create_kv_caches_with_random( ...@@ -833,6 +837,8 @@ def create_kv_caches_with_random(
key_cache.uniform_(-scale, scale) key_cache.uniform_(-scale, scale)
elif cache_dtype == 'fp8': elif cache_dtype == 'fp8':
_generate_random_fp8(key_cache, -scale, scale) _generate_random_fp8(key_cache, -scale, scale)
elif cache_dtype == 'int8':
_generate_random_int8(key_value_cache)
else: else:
raise ValueError( raise ValueError(
f"Does not support key cache of type {cache_dtype}") f"Does not support key cache of type {cache_dtype}")
...@@ -848,6 +854,8 @@ def create_kv_caches_with_random( ...@@ -848,6 +854,8 @@ def create_kv_caches_with_random(
value_cache.uniform_(-scale, scale) value_cache.uniform_(-scale, scale)
elif cache_dtype == 'fp8': elif cache_dtype == 'fp8':
_generate_random_fp8(value_cache, -scale, scale) _generate_random_fp8(value_cache, -scale, scale)
elif cache_dtype == 'int8':
_generate_random_int8(key_cache)
else: else:
raise ValueError( raise ValueError(
f"Does not support value cache of type {cache_dtype}") f"Does not support value cache of type {cache_dtype}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment