Commit b918400d authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'origin/vllm-0.8.5-zhangshao' into v0.8.5.post1-dev

parents 8fb5dea5 e02d110d
...@@ -487,7 +487,6 @@ __device__ void paged_attention_kernel_opt( ...@@ -487,7 +487,6 @@ __device__ void paged_attention_kernel_opt(
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset); v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kInt8) { } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kInt8) {
// printf("======xiabo_kvint8\n");
V_quant_vec v_quant_vec = V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset); *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec. // Vector conversion from V_quant_vec to V_vec.
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "attention_dtypes.h" #include "attention_dtypes.h"
#include "attention_utils.cuh" #include "attention_utils.cuh"
#include "../quantization/int8_kvcache/quant_utils.cuh"
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
...@@ -88,6 +89,7 @@ inline __device__ float block_sum(float* red_smem, float sum) { ...@@ -88,6 +89,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
return VLLM_SHFL_SYNC(sum, 0); return VLLM_SHFL_SYNC(sum, 0);
} }
using uint8x4_t = __attribute__( (__vector_size__(4 * sizeof(uint8_t)) )) uint8_t;
using half4_t = __attribute__( (__vector_size__(4 * sizeof(_Float16)) )) _Float16; using half4_t = __attribute__( (__vector_size__(4 * sizeof(_Float16)) )) _Float16;
using v4bh = __attribute__( (__vector_size__(4 * sizeof(short)) )) short; using v4bh = __attribute__( (__vector_size__(4 * sizeof(short)) )) short;
using float4_t = __attribute__( (__vector_size__(4 * sizeof(float)) )) float; using float4_t = __attribute__( (__vector_size__(4 * sizeof(float)) )) float;
...@@ -95,12 +97,62 @@ using float2_t = __attribute__( (__vector_size__(2 * sizeof(float)) )) float; ...@@ -95,12 +97,62 @@ using float2_t = __attribute__( (__vector_size__(2 * sizeof(float)) )) float;
struct half4x2{ struct half4x2{
half4_t data[2]; half4_t data[2];
}; };
struct uint8x4x4{
uint8x4_t data[4];
};
template<typename scalar_t> template<typename scalar_t>
struct vec2data{ struct vec2data{
scalar_t data[2]; scalar_t data[2];
}; };
inline __device__ float uint82float(const uint8_t& input) {
const uint32_t w = (uint32_t)input << 24;
const uint32_t sign = w & UINT32_C(0x80000000);
const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
uint32_t renorm_shift = __clz(nonsign);
renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
uint32_t result = sign | ((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23));
return c10::detail::fp32_from_bits(result);
}
template<bool is_half,bool is_fp8>
inline __device__ half4_t int8x4_to_half4(uint8x4_t x, const float scale) {
half4_t ret;
if constexpr(is_fp8){
if constexpr(is_half){
#pragma unroll
for(int i=0;i<4;i++){
ret[i]=uint82float(x[i])*scale;
}
}
else{
__nv_bfloat16 *bd= reinterpret_cast<__nv_bfloat16 *>(&ret);
#pragma unroll
for(int i=0;i<4;i++){
bd[i]=__float2bfloat16(uint82float(x[i])*scale);
}
}
}
else{
if constexpr(is_half){
#pragma unroll
for(int i=0;i<4;i++){
ret[i]=(x[i]-128.0f)*scale;
}
}
else{
__nv_bfloat16 *bd= reinterpret_cast<__nv_bfloat16 *>(&ret);
#pragma unroll
for(int i=0;i<4;i++){
bd[i]=__float2bfloat16((x[i]-128.0f)*scale);
}
}
}
return ret;
}
template<bool is_half> template<bool is_half>
inline __device__ void float4_2_half4(half4_t& dst,const float4_t& src) inline __device__ void float4_2_half4(half4_t& dst,const float4_t& src)
{ {
...@@ -165,7 +217,7 @@ __global__ void paged_attention_kernel_TC( ...@@ -165,7 +217,7 @@ __global__ void paged_attention_kernel_TC(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank, const float* k_scale_ptr, const float* v_scale_ptr, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step,int PARTITION_SIZE=0) { const int blocksparse_block_size, const int blocksparse_head_sliding_step,int PARTITION_SIZE=0) {
#if defined(__gfx936__) || defined(__gfx928__) #if defined(__gfx936__) || defined(__gfx928__)
...@@ -177,6 +229,7 @@ __global__ void paged_attention_kernel_TC( ...@@ -177,6 +229,7 @@ __global__ void paged_attention_kernel_TC(
const bool USE_PARTITIONING = PARTITION_SIZE<num_seq_blocks * BLOCK_SIZE && PARTITION_SIZE>0; const bool USE_PARTITIONING = PARTITION_SIZE<num_seq_blocks * BLOCK_SIZE && PARTITION_SIZE>0;
if (partition_idx * PARTITION_SIZE >= seq_len) return; if (partition_idx * PARTITION_SIZE >= seq_len) return;
constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value; constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value;
constexpr bool is_fp8 = (KV_DTYPE==Fp8KVCacheDataType::kFp8E4M3);
static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS"); static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS");
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE; const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
...@@ -193,6 +246,8 @@ __global__ void paged_attention_kernel_TC( ...@@ -193,6 +246,8 @@ __global__ void paged_attention_kernel_TC(
const int lane = thread_idx % WARP_SIZE; const int lane = thread_idx % WARP_SIZE;
const int rowid = lane%16; const int rowid = lane%16;
const int rows = lane/16; const int rows = lane/16;
const float k_scale=*k_scale_ptr;
const float v_scale=*v_scale_ptr;
const int num_queries_per_kv = num_heads / num_kv_heads; const int num_queries_per_kv = num_heads / num_kv_heads;
const int num_blocks_per_kv = ((num_queries_per_kv + REUSE_KV_TIMES -1) / REUSE_KV_TIMES); const int num_blocks_per_kv = ((num_queries_per_kv + REUSE_KV_TIMES -1) / REUSE_KV_TIMES);
...@@ -242,17 +297,19 @@ __global__ void paged_attention_kernel_TC( ...@@ -242,17 +297,19 @@ __global__ void paged_attention_kernel_TC(
__shared__ float s_max[REUSE_KV_TIMES][NUM_WARPS]; __shared__ float s_max[REUSE_KV_TIMES][NUM_WARPS];
__shared__ float s_logit[NUM_WARPS]; __shared__ float s_logit[NUM_WARPS];
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
const cache_t* k_ptr_base = k_cache+kv_head_idx * kv_head_stride+lane*8; const cache_t* k_ptr_base = k_cache+kv_head_idx * kv_head_stride+lane*x;
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;block_idx += NUM_WARPS) { for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;block_idx += NUM_WARPS) {
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]); const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
const cache_t* k_ptr=k_ptr_base + physical_block_number * kv_block_stride; const cache_t* k_ptr=k_ptr_base + physical_block_number * kv_block_stride;
float4_t qk_vec={0,0,0,0}; float4_t qk_vec={0,0,0,0};
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
half4x2 k_vec[2]; half4x2 k_vec[2];
k_vec[0]=*reinterpret_cast<const half4x2*>(k_ptr); k_vec[0]=*reinterpret_cast<const half4x2*>(k_ptr);
#pragma unroll #pragma unroll
for(int i=0;i<3;i++){ for(int i=0;i<3;i++){
if(rowid<q_boundary)q_vec=q_vecs[rowid][i*4+rows]; if(rowid<q_boundary)q_vec=q_vecs[rowid][i*4+rows];
k_vec[1-i%2]=*reinterpret_cast<const half4x2*>(k_ptr+(i+1)*512); k_vec[1-i%2]=*reinterpret_cast<const half4x2*>(k_ptr+(i+1)*WARP_SIZE*x);
builtin_amdgcn_mmac<is_half>(k_vec[i%2].data[0],q_vec.data[0],qk_vec); builtin_amdgcn_mmac<is_half>(k_vec[i%2].data[0],q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(k_vec[i%2].data[1],q_vec.data[1],qk_vec); builtin_amdgcn_mmac<is_half>(k_vec[i%2].data[1],q_vec.data[1],qk_vec);
} }
...@@ -262,6 +319,23 @@ __global__ void paged_attention_kernel_TC( ...@@ -262,6 +319,23 @@ __global__ void paged_attention_kernel_TC(
builtin_amdgcn_mmac<is_half>(k_vec[1].data[0],q_vec.data[0],qk_vec); builtin_amdgcn_mmac<is_half>(k_vec[1].data[0],q_vec.data[0],qk_vec);
v_mmac_f32_16x16x16_f16<is_half>(k_vec[1].data[1],q_vec.data[1],qk_vec); v_mmac_f32_16x16x16_f16<is_half>(k_vec[1].data[1],q_vec.data[1],qk_vec);
} }
}
else{
uint8x4x4 k_quant=*reinterpret_cast<const uint8x4x4*>(k_ptr);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[0],k_scale),q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[1],k_scale),q_vec.data[1],qk_vec);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows+1];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[2],k_scale),q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[3],k_scale),q_vec.data[1],qk_vec);
k_quant=*reinterpret_cast<const uint8x4x4*>(k_ptr+WARP_SIZE*x);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows+8];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[0],k_scale),q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[1],k_scale),q_vec.data[1],qk_vec);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows+9];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[2],k_scale),q_vec.data[0],qk_vec);
v_mmac_f32_16x16x16_f16<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[3],k_scale),q_vec.data[1],qk_vec);
}
#pragma unroll #pragma unroll
for(int i=0;i<reuse_group;i++){ for(int i=0;i<reuse_group;i++){
int reuse_kv_idx=rows+i*4; int reuse_kv_idx=rows+i*4;
...@@ -362,7 +436,14 @@ __global__ void paged_attention_kernel_TC( ...@@ -362,7 +436,14 @@ __global__ void paged_attention_kernel_TC(
#pragma unroll #pragma unroll
for(int k=0;k<4;k++){ for(int k=0;k<4;k++){
int offset=i*1024+k*256; int offset=i*1024+k*256;
half4_t v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset); half4_t v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
}
else {
uint8x4_t quant_v = *reinterpret_cast<const uint8x4_t*>(v_ptr + offset);
v_vec=int8x4_to_half4<is_half,is_fp8>(quant_v,v_scale);
}
if (block_idx == num_seq_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec); scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll #pragma unroll
...@@ -458,7 +539,14 @@ __global__ void paged_attention_kernel_TC( ...@@ -458,7 +539,14 @@ __global__ void paged_attention_kernel_TC(
#pragma unroll #pragma unroll
for(int k=0;k<4;k++){ for(int k=0;k<4;k++){
int offset=i*1024+k*256; int offset=i*1024+k*256;
half4_t v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset); half4_t v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
}
else {
uint8x4_t quant_v = *reinterpret_cast<const uint8x4_t*>(v_ptr + offset);
v_vec=int8x4_to_half4<is_half,is_fp8>(quant_v,v_scale);
}
if (block_idx == num_seq_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec); scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll #pragma unroll
...@@ -563,7 +651,14 @@ __global__ void paged_attention_kernel_TC( ...@@ -563,7 +651,14 @@ __global__ void paged_attention_kernel_TC(
#pragma unroll #pragma unroll
for(int k=0;k<4;k++){ for(int k=0;k<4;k++){
int offset=i*1024+k*256; int offset=i*1024+k*256;
half4_t v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset); half4_t v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
}
else {
uint8x4_t quant_v = *reinterpret_cast<const uint8x4_t*>(v_ptr + offset);
v_vec=int8x4_to_half4<is_half,is_fp8>(quant_v,v_scale);
}
if (block_idx == num_seq_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec); scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll #pragma unroll
...@@ -895,6 +990,7 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO ...@@ -895,6 +990,7 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
if(blocks>=150||batchsize>=16||qheads>=8&&(batchsize>=4||(max_seq_len>=2000&&max_seq_len<3900)))reusekv=4; if(blocks>=150||batchsize>=16||qheads>=8&&(batchsize>=4||(max_seq_len>=2000&&max_seq_len<3900)))reusekv=4;
} }
template <typename T, typename CACHE_T, int BLOCK_SIZE, template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE> vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE>
void paged_attention_v2_launcher_opt_tc( void paged_attention_v2_launcher_opt_tc(
...@@ -920,17 +1016,16 @@ void paged_attention_v2_launcher_opt_tc( ...@@ -920,17 +1016,16 @@ void paged_attention_v2_launcher_opt_tc(
alibi_slopes alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr()); const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr()); const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
// float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
// float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
// T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
static float* exp_sums_ptr = nullptr; static float* exp_sums_ptr = nullptr;
static float* max_logits_ptr = nullptr; static float* max_logits_ptr = nullptr;
static T* tmp_out_ptr = nullptr; static T* tmp_out_ptr = nullptr;
...@@ -943,7 +1038,7 @@ void paged_attention_v2_launcher_opt_tc( ...@@ -943,7 +1038,7 @@ void paged_attention_v2_launcher_opt_tc(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 reduce_grid(num_heads, num_seqs); dim3 reduce_grid(num_heads, num_seqs);
if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2 && KV_DTYPE==vllm::Fp8KVCacheDataType::kAuto){ if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2){
constexpr int HEAD_SIZE=128; constexpr int HEAD_SIZE=128;
int reusekv, num_thread,max_num_partitions,PARTITION_SIZE; int reusekv, num_thread,max_num_partitions,PARTITION_SIZE;
get_numberthread_and_reuse_kv_v2(num_thread,reusekv,PARTITION_SIZE,max_num_partitions,num_seqs,max_seq_len,num_heads,num_kv_heads,num_blocks); get_numberthread_and_reuse_kv_v2(num_thread,reusekv,PARTITION_SIZE,max_num_partitions,num_seqs,max_seq_len,num_heads,num_kv_heads,num_blocks);
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "attention_dtypes.h" #include "attention_dtypes.h"
#include "attention_utils.cuh" #include "attention_utils.cuh"
#include "../quantization/int8_kvcache/quant_utils.cuh"
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
...@@ -71,6 +72,7 @@ inline __device__ float block_sum(float* red_smem, float sum) { ...@@ -71,6 +72,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
return VLLM_SHFL_SYNC(sum, 0); return VLLM_SHFL_SYNC(sum, 0);
} }
using uint8x4_t = __attribute__( (__vector_size__(4 * sizeof(uint8_t)) )) uint8_t;
using half4_t = __attribute__( (__vector_size__(4 * sizeof(_Float16)) )) _Float16; using half4_t = __attribute__( (__vector_size__(4 * sizeof(_Float16)) )) _Float16;
using v4bh = __attribute__( (__vector_size__(4 * sizeof(short)) )) short; using v4bh = __attribute__( (__vector_size__(4 * sizeof(short)) )) short;
using float4_t = __attribute__( (__vector_size__(4 * sizeof(float)) )) float; using float4_t = __attribute__( (__vector_size__(4 * sizeof(float)) )) float;
...@@ -78,12 +80,62 @@ using float2_t = __attribute__( (__vector_size__(2 * sizeof(float)) )) float; ...@@ -78,12 +80,62 @@ using float2_t = __attribute__( (__vector_size__(2 * sizeof(float)) )) float;
struct half4x2{ struct half4x2{
half4_t data[2]; half4_t data[2];
}; };
struct uint8x4x4{
uint8x4_t data[4];
};
template<typename scalar_t> template<typename scalar_t>
struct vec2data{ struct vec2data{
scalar_t data[2]; scalar_t data[2];
}; };
inline __device__ float uint82float(const uint8_t& input) {
const uint32_t w = (uint32_t)input << 24;
const uint32_t sign = w & UINT32_C(0x80000000);
const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
uint32_t renorm_shift = __clz(nonsign);
renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
uint32_t result = sign | ((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23));
return c10::detail::fp32_from_bits(result);
}
template<bool is_half,bool is_fp8>
inline __device__ half4_t int8x4_to_half4(uint8x4_t x, const float scale) {
half4_t ret;
if constexpr(is_fp8){
if constexpr(is_half){
#pragma unroll
for(int i=0;i<4;i++){
ret[i]=uint82float(x[i])*scale;
}
}
else{
__nv_bfloat16 *bd= reinterpret_cast<__nv_bfloat16 *>(&ret);
#pragma unroll
for(int i=0;i<4;i++){
bd[i]=__float2bfloat16(uint82float(x[i])*scale);
}
}
}
else{
if constexpr(is_half){
#pragma unroll
for(int i=0;i<4;i++){
ret[i]=(x[i]-128.0f)*scale;
}
}
else{
__nv_bfloat16 *bd= reinterpret_cast<__nv_bfloat16 *>(&ret);
#pragma unroll
for(int i=0;i<4;i++){
bd[i]=__float2bfloat16((x[i]-128.0f)*scale);
}
}
}
return ret;
}
template<bool is_half> template<bool is_half>
inline __device__ void float4_2_half4(half4_t& dst,const float4_t& src) inline __device__ void float4_2_half4(half4_t& dst,const float4_t& src)
{ {
...@@ -148,7 +200,7 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -148,7 +200,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank, const float* k_scale_ptr, const float* v_scale_ptr, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0,int PARTITION_SIZE=0) { const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0,int PARTITION_SIZE=0) {
...@@ -161,6 +213,7 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -161,6 +213,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
const bool USE_PARTITIONING = PARTITION_SIZE<num_seq_blocks * BLOCK_SIZE && PARTITION_SIZE>0; const bool USE_PARTITIONING = PARTITION_SIZE<num_seq_blocks * BLOCK_SIZE && PARTITION_SIZE>0;
if (partition_idx * PARTITION_SIZE >= seq_len) return; if (partition_idx * PARTITION_SIZE >= seq_len) return;
constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value; constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value;
constexpr bool is_fp8 = (KV_DTYPE==Fp8KVCacheDataType::kFp8E4M3);
static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS"); static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS");
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE; const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
...@@ -177,6 +230,8 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -177,6 +230,8 @@ __global__ void paged_attention_kernel_TC_with_mask(
const int lane = thread_idx % WARP_SIZE; const int lane = thread_idx % WARP_SIZE;
const int rowid = lane%16; const int rowid = lane%16;
const int rows = lane/16; const int rows = lane/16;
const float k_scale=*k_scale_ptr;
const float v_scale=*v_scale_ptr;
const int num_queries_per_kv = num_heads / num_kv_heads; const int num_queries_per_kv = num_heads / num_kv_heads;
const int num_blocks_per_kv = ((num_queries_per_kv + REUSE_KV_TIMES -1) / REUSE_KV_TIMES); const int num_blocks_per_kv = ((num_queries_per_kv + REUSE_KV_TIMES -1) / REUSE_KV_TIMES);
...@@ -226,17 +281,19 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -226,17 +281,19 @@ __global__ void paged_attention_kernel_TC_with_mask(
__shared__ float s_max[REUSE_KV_TIMES][NUM_WARPS]; __shared__ float s_max[REUSE_KV_TIMES][NUM_WARPS];
__shared__ float s_logit[NUM_WARPS]; __shared__ float s_logit[NUM_WARPS];
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
const cache_t* k_ptr_base = k_cache+kv_head_idx * kv_head_stride+lane*8; const cache_t* k_ptr_base = k_cache+kv_head_idx * kv_head_stride+lane*x;
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;block_idx += NUM_WARPS) { for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;block_idx += NUM_WARPS) {
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]); const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
const cache_t* k_ptr=k_ptr_base + physical_block_number * kv_block_stride; const cache_t* k_ptr=k_ptr_base + physical_block_number * kv_block_stride;
float4_t qk_vec={0,0,0,0}; float4_t qk_vec={0,0,0,0};
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
half4x2 k_vec[2]; half4x2 k_vec[2];
k_vec[0]=*reinterpret_cast<const half4x2*>(k_ptr); k_vec[0]=*reinterpret_cast<const half4x2*>(k_ptr);
#pragma unroll #pragma unroll
for(int i=0;i<3;i++){ for(int i=0;i<3;i++){
if(rowid<q_boundary)q_vec=q_vecs[rowid][i*4+rows]; if(rowid<q_boundary)q_vec=q_vecs[rowid][i*4+rows];
k_vec[1-i%2]=*reinterpret_cast<const half4x2*>(k_ptr+(i+1)*512); k_vec[1-i%2]=*reinterpret_cast<const half4x2*>(k_ptr+(i+1)*WARP_SIZE*x);
builtin_amdgcn_mmac<is_half>(k_vec[i%2].data[0],q_vec.data[0],qk_vec); builtin_amdgcn_mmac<is_half>(k_vec[i%2].data[0],q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(k_vec[i%2].data[1],q_vec.data[1],qk_vec); builtin_amdgcn_mmac<is_half>(k_vec[i%2].data[1],q_vec.data[1],qk_vec);
} }
...@@ -246,6 +303,23 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -246,6 +303,23 @@ __global__ void paged_attention_kernel_TC_with_mask(
builtin_amdgcn_mmac<is_half>(k_vec[1].data[0],q_vec.data[0],qk_vec); builtin_amdgcn_mmac<is_half>(k_vec[1].data[0],q_vec.data[0],qk_vec);
v_mmac_f32_16x16x16_f16<is_half>(k_vec[1].data[1],q_vec.data[1],qk_vec); v_mmac_f32_16x16x16_f16<is_half>(k_vec[1].data[1],q_vec.data[1],qk_vec);
} }
}
else{
uint8x4x4 k_quant=*reinterpret_cast<const uint8x4x4*>(k_ptr);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[0],k_scale),q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[1],k_scale),q_vec.data[1],qk_vec);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows+1];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[2],k_scale),q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[3],k_scale),q_vec.data[1],qk_vec);
k_quant=*reinterpret_cast<const uint8x4x4*>(k_ptr+WARP_SIZE*x);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows+8];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[0],k_scale),q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[1],k_scale),q_vec.data[1],qk_vec);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows+9];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[2],k_scale),q_vec.data[0],qk_vec);
v_mmac_f32_16x16x16_f16<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[3],k_scale),q_vec.data[1],qk_vec);
}
#pragma unroll #pragma unroll
for(int i=0;i<reuse_group;i++){ for(int i=0;i<reuse_group;i++){
int reuse_kv_idx=rows+i*4; int reuse_kv_idx=rows+i*4;
...@@ -353,7 +427,14 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -353,7 +427,14 @@ __global__ void paged_attention_kernel_TC_with_mask(
#pragma unroll #pragma unroll
for(int k=0;k<4;k++){ for(int k=0;k<4;k++){
int offset=i*1024+k*256; int offset=i*1024+k*256;
half4_t v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset); half4_t v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
}
else {
uint8x4_t quant_v = *reinterpret_cast<const uint8x4_t*>(v_ptr + offset);
v_vec=int8x4_to_half4<is_half,is_fp8>(quant_v,v_scale);
}
if (block_idx == num_seq_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec); scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll #pragma unroll
...@@ -449,7 +530,14 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -449,7 +530,14 @@ __global__ void paged_attention_kernel_TC_with_mask(
#pragma unroll #pragma unroll
for(int k=0;k<4;k++){ for(int k=0;k<4;k++){
int offset=i*1024+k*256; int offset=i*1024+k*256;
half4_t v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset); half4_t v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
}
else {
uint8x4_t quant_v = *reinterpret_cast<const uint8x4_t*>(v_ptr + offset);
v_vec=int8x4_to_half4<is_half,is_fp8>(quant_v,v_scale);
}
if (block_idx == num_seq_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec); scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll #pragma unroll
...@@ -554,7 +642,14 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -554,7 +642,14 @@ __global__ void paged_attention_kernel_TC_with_mask(
#pragma unroll #pragma unroll
for(int k=0;k<4;k++){ for(int k=0;k<4;k++){
int offset=i*1024+k*256; int offset=i*1024+k*256;
half4_t v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset); half4_t v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
}
else {
uint8x4_t quant_v = *reinterpret_cast<const uint8x4_t*>(v_ptr + offset);
v_vec=int8x4_to_half4<is_half,is_fp8>(quant_v,v_scale);
}
if (block_idx == num_seq_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec); scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll #pragma unroll
...@@ -831,14 +926,12 @@ void paged_attention_v2_launcher_opt_tc_with_mask( ...@@ -831,14 +926,12 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr()); const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr()); const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
// float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
// float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
// T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
static float* exp_sums_ptr = nullptr; static float* exp_sums_ptr = nullptr;
static float* max_logits_ptr = nullptr; static float* max_logits_ptr = nullptr;
static T* tmp_out_ptr = nullptr; static T* tmp_out_ptr = nullptr;
...@@ -851,7 +944,7 @@ void paged_attention_v2_launcher_opt_tc_with_mask( ...@@ -851,7 +944,7 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 reduce_grid(num_heads, num_seqs); dim3 reduce_grid(num_heads, num_seqs);
if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2 && KV_DTYPE==vllm::Fp8KVCacheDataType::kAuto){ if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2){
constexpr int HEAD_SIZE=128; constexpr int HEAD_SIZE=128;
int reusekv, num_thread,max_num_partitions,PARTITION_SIZE; int reusekv, num_thread,max_num_partitions,PARTITION_SIZE;
get_numberthread_and_reuse_kv_v2(num_thread,reusekv,PARTITION_SIZE,max_num_partitions,num_seqs,max_seq_len,num_heads,num_kv_heads,num_blocks); get_numberthread_and_reuse_kv_v2(num_thread,reusekv,PARTITION_SIZE,max_num_partitions,num_seqs,max_seq_len,num_heads,num_kv_heads,num_blocks);
......
...@@ -130,10 +130,7 @@ class PagedAttention: ...@@ -130,10 +130,7 @@ class PagedAttention:
# TODO(woosuk): Tune this heuristic. # TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage. # For context len > 8192, use V2 kernel to avoid shared memory shortage.
kvquant = False if use_tc and head_size==128:
if (kv_cache_dtype == "int8"):
kvquant = True
if use_tc and head_size==128 and not kvquant:
if envs.VLLM_USE_PA_PRINT_PARAM: if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA V1 SIZE:") print("PA V1 SIZE:")
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}") print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
......
...@@ -77,6 +77,12 @@ logger = init_logger(__name__) ...@@ -77,6 +77,12 @@ logger = init_logger(__name__)
gpuname = torch.cuda.get_device_properties(torch.cuda.current_device()).name gpuname = torch.cuda.get_device_properties(torch.cuda.current_device()).name
SUPPORT_TC = gpuname.startswith('K100_AI') or gpuname.startswith('BW') SUPPORT_TC = gpuname.startswith('K100_AI') or gpuname.startswith('BW')
def _generate_random_int8(
tensor: torch.Tensor,
) -> None:
tensor = torch.randint(0, 255, tensor.size(),dtype=torch.uint8,device='cuda')
# Exception strings for non-implemented encoder/decoder scenarios # Exception strings for non-implemented encoder/decoder scenarios
# Reminder: Please update docs/source/features/compatibility_matrix.md # Reminder: Please update docs/source/features/compatibility_matrix.md
...@@ -798,7 +804,7 @@ def create_kv_caches_with_random_flash( ...@@ -798,7 +804,7 @@ def create_kv_caches_with_random_flash(
elif cache_dtype == 'fp8': elif cache_dtype == 'fp8':
_generate_random_fp8(key_value_cache, -scale, scale) _generate_random_fp8(key_value_cache, -scale, scale)
elif cache_dtype == 'int8': elif cache_dtype == 'int8':
_generate_random_int8(value_cache) _generate_random_int8(key_value_cache)
else: else:
raise ValueError( raise ValueError(
f"Does not support key cache of type {cache_dtype}") f"Does not support key cache of type {cache_dtype}")
...@@ -841,7 +847,7 @@ def create_kv_caches_with_random( ...@@ -841,7 +847,7 @@ def create_kv_caches_with_random(
elif cache_dtype == 'fp8': elif cache_dtype == 'fp8':
_generate_random_fp8(key_cache, -scale, scale) _generate_random_fp8(key_cache, -scale, scale)
elif cache_dtype == 'int8': elif cache_dtype == 'int8':
_generate_random_int8(key_value_cache) _generate_random_int8(key_cache)
else: else:
raise ValueError( raise ValueError(
f"Does not support key cache of type {cache_dtype}") f"Does not support key cache of type {cache_dtype}")
...@@ -858,7 +864,7 @@ def create_kv_caches_with_random( ...@@ -858,7 +864,7 @@ def create_kv_caches_with_random(
elif cache_dtype == 'fp8': elif cache_dtype == 'fp8':
_generate_random_fp8(value_cache, -scale, scale) _generate_random_fp8(value_cache, -scale, scale)
elif cache_dtype == 'int8': elif cache_dtype == 'int8':
_generate_random_int8(key_cache) _generate_random_int8(value_cache)
else: else:
raise ValueError( raise ValueError(
f"Does not support value cache of type {cache_dtype}") f"Does not support value cache of type {cache_dtype}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment