Commit b918400d authored by zhuwenwen
Browse files

Merge remote-tracking branch 'origin/vllm-0.8.5-zhangshao' into v0.8.5.post1-dev

parents 8fb5dea5 e02d110d
......@@ -487,7 +487,6 @@ __device__ void paged_attention_kernel_opt(
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kInt8) {
// printf("======xiabo_kvint8\n");
V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
......
......@@ -5,6 +5,7 @@
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "../quantization/int8_kvcache/quant_utils.cuh"
#include <hip/hip_bf16.h>
......@@ -88,6 +89,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
return VLLM_SHFL_SYNC(sum, 0);
}
// Compiler vector-extension aliases used by the tensor-core attention kernels.
// Four uint8_t lanes (4 bytes); holds four quantized KV-cache bytes.
using uint8x4_t = __attribute__( (__vector_size__(4 * sizeof(uint8_t)) )) uint8_t;
// Four _Float16 lanes (8 bytes).
using half4_t = __attribute__( (__vector_size__(4 * sizeof(_Float16)) )) _Float16;
// Four short lanes (8 bytes). NOTE(review): name suggests raw bf16 bit patterns - confirm against users.
using v4bh = __attribute__( (__vector_size__(4 * sizeof(short)) )) short;
// Four float lanes (16 bytes).
using float4_t = __attribute__( (__vector_size__(4 * sizeof(float)) )) float;
......@@ -95,12 +97,62 @@ using float2_t = __attribute__( (__vector_size__(2 * sizeof(float)) )) float;
// Two half4_t vectors stored back to back (8 sixteen-bit lanes in total),
// so that both can be fetched with a single reinterpret-cast load.
struct half4x2{
half4_t data[2];
};
// Four uint8x4_t vectors stored back to back (16 bytes in total); the
// quantized K path loads one of these per lane with a single
// reinterpret-cast load.
struct uint8x4x4{
uint8x4_t data[4];
};
// Plain pair of two scalar_t values stored contiguously.
template<typename scalar_t>
struct vec2data{
scalar_t data[2];
};
/**
 * Decode one 8-bit float (1 sign bit, 4 exponent bits, 3 mantissa bits,
 * exponent bias 7) into an fp32 value.
 *
 * Normal and subnormal inputs are converted exactly. Inputs whose magnitude
 * bits are all zero (0x00 / 0x80) decode to +0.0f / -0.0f.
 *
 * Encodings with every exponent and mantissa bit set (0x7F / 0xFF) are NOT
 * turned into NaN; they decode to the finite values +480 / -480.
 *
 * @param input  raw 8-bit encoding.
 * @return       the decoded value as float.
 */
inline __device__ float uint82float(const uint8_t& input) {
    // Move the byte into the top 8 bits so its sign bit sits on the fp32 sign bit.
    const uint32_t w = (uint32_t)input << 24;
    const uint32_t sign = w & UINT32_C(0x80000000);
    const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);

    // Subnormal inputs have more than 4 leading zeros after the sign bit is
    // cleared; shift them left until the leading one reaches the lowest
    // exponent bit, and compensate in the exponent below.
    uint32_t renorm_shift = __clz(nonsign);
    renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;

    // ">> 4" aligns the 4-bit exponent with the fp32 exponent field;
    // 0x78 (= 127 - 7) rebases the exponent from bias 7 to bias 127.
    const uint32_t magnitude =
        (nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23);

    // With no magnitude bits set, the renormalisation above would fabricate
    // an exponent of 0x78 - 28 (i.e. 2^-35); force an exact signed zero instead.
    const uint32_t result = sign | (nonsign != 0 ? magnitude : UINT32_C(0));
    return c10::detail::fp32_from_bits(result);
}
/**
 * Dequantize four cached 8-bit values into a packed vector of four 16-bit
 * floating-point lanes.
 *
 * @tparam is_half  true: lanes are written as _Float16;
 *                  false: lanes are written as __nv_bfloat16 into the same
 *                  8 bytes of storage.
 * @tparam is_fp8   true: each byte is decoded with uint82float();
 *                  false: each byte is an unsigned integer with zero point
 *                  128, i.e. (byte - 128).
 * @param x      four quantized bytes.
 * @param scale  multiplier applied to every decoded lane.
 * @return       the four dequantized lanes packed in a half4_t.
 */
template<bool is_half,bool is_fp8>
inline __device__ half4_t int8x4_to_half4(uint8x4_t x, const float scale) {
    half4_t packed;
    #pragma unroll
    for (int lane = 0; lane < 4; ++lane) {
        // Step 1: decode the byte to fp32 and apply the scale.
        float value;
        if constexpr (is_fp8) {
            value = uint82float(x[lane]) * scale;
        } else {
            value = (x[lane] - 128.0f) * scale;
        }
        // Step 2: narrow to the requested 16-bit format.
        if constexpr (is_half) {
            packed[lane] = value;
        } else {
            reinterpret_cast<__nv_bfloat16*>(&packed)[lane] = __float2bfloat16(value);
        }
    }
    return packed;
}
template<bool is_half>
inline __device__ void float4_2_half4(half4_t& dst,const float4_t& src)
{
......@@ -165,7 +217,7 @@ __global__ void paged_attention_kernel_TC(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank,
const float* k_scale_ptr, const float* v_scale_ptr, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step,int PARTITION_SIZE=0) {
#if defined(__gfx936__) || defined(__gfx928__)
......@@ -177,6 +229,7 @@ __global__ void paged_attention_kernel_TC(
const bool USE_PARTITIONING = PARTITION_SIZE<num_seq_blocks * BLOCK_SIZE && PARTITION_SIZE>0;
if (partition_idx * PARTITION_SIZE >= seq_len) return;
constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value;
constexpr bool is_fp8 = (KV_DTYPE==Fp8KVCacheDataType::kFp8E4M3);
static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS");
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
......@@ -193,6 +246,8 @@ __global__ void paged_attention_kernel_TC(
const int lane = thread_idx % WARP_SIZE;
const int rowid = lane%16;
const int rows = lane/16;
const float k_scale=*k_scale_ptr;
const float v_scale=*v_scale_ptr;
const int num_queries_per_kv = num_heads / num_kv_heads;
const int num_blocks_per_kv = ((num_queries_per_kv + REUSE_KV_TIMES -1) / REUSE_KV_TIMES);
......@@ -242,17 +297,19 @@ __global__ void paged_attention_kernel_TC(
__shared__ float s_max[REUSE_KV_TIMES][NUM_WARPS];
__shared__ float s_logit[NUM_WARPS];
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
const cache_t* k_ptr_base = k_cache+kv_head_idx * kv_head_stride+lane*8;
const cache_t* k_ptr_base = k_cache+kv_head_idx * kv_head_stride+lane*x;
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;block_idx += NUM_WARPS) {
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
const cache_t* k_ptr=k_ptr_base + physical_block_number * kv_block_stride;
float4_t qk_vec={0,0,0,0};
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
half4x2 k_vec[2];
k_vec[0]=*reinterpret_cast<const half4x2*>(k_ptr);
#pragma unroll
for(int i=0;i<3;i++){
if(rowid<q_boundary)q_vec=q_vecs[rowid][i*4+rows];
k_vec[1-i%2]=*reinterpret_cast<const half4x2*>(k_ptr+(i+1)*512);
k_vec[1-i%2]=*reinterpret_cast<const half4x2*>(k_ptr+(i+1)*WARP_SIZE*x);
builtin_amdgcn_mmac<is_half>(k_vec[i%2].data[0],q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(k_vec[i%2].data[1],q_vec.data[1],qk_vec);
}
......@@ -262,6 +319,23 @@ __global__ void paged_attention_kernel_TC(
builtin_amdgcn_mmac<is_half>(k_vec[1].data[0],q_vec.data[0],qk_vec);
v_mmac_f32_16x16x16_f16<is_half>(k_vec[1].data[1],q_vec.data[1],qk_vec);
}
}
else{
uint8x4x4 k_quant=*reinterpret_cast<const uint8x4x4*>(k_ptr);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[0],k_scale),q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[1],k_scale),q_vec.data[1],qk_vec);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows+1];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[2],k_scale),q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[3],k_scale),q_vec.data[1],qk_vec);
k_quant=*reinterpret_cast<const uint8x4x4*>(k_ptr+WARP_SIZE*x);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows+8];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[0],k_scale),q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[1],k_scale),q_vec.data[1],qk_vec);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows+9];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[2],k_scale),q_vec.data[0],qk_vec);
v_mmac_f32_16x16x16_f16<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[3],k_scale),q_vec.data[1],qk_vec);
}
#pragma unroll
for(int i=0;i<reuse_group;i++){
int reuse_kv_idx=rows+i*4;
......@@ -362,7 +436,14 @@ __global__ void paged_attention_kernel_TC(
#pragma unroll
for(int k=0;k<4;k++){
int offset=i*1024+k*256;
half4_t v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
half4_t v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
}
else {
uint8x4_t quant_v = *reinterpret_cast<const uint8x4_t*>(v_ptr + offset);
v_vec=int8x4_to_half4<is_half,is_fp8>(quant_v,v_scale);
}
if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
......@@ -458,7 +539,14 @@ __global__ void paged_attention_kernel_TC(
#pragma unroll
for(int k=0;k<4;k++){
int offset=i*1024+k*256;
half4_t v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
half4_t v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
}
else {
uint8x4_t quant_v = *reinterpret_cast<const uint8x4_t*>(v_ptr + offset);
v_vec=int8x4_to_half4<is_half,is_fp8>(quant_v,v_scale);
}
if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
......@@ -563,7 +651,14 @@ __global__ void paged_attention_kernel_TC(
#pragma unroll
for(int k=0;k<4;k++){
int offset=i*1024+k*256;
half4_t v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
half4_t v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
}
else {
uint8x4_t quant_v = *reinterpret_cast<const uint8x4_t*>(v_ptr + offset);
v_vec=int8x4_to_half4<is_half,is_fp8>(quant_v,v_scale);
}
if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
......@@ -895,6 +990,7 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
if(blocks>=150||batchsize>=16||qheads>=8&&(batchsize>=4||(max_seq_len>=2000&&max_seq_len<3900)))reusekv=4;
}
template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE>
void paged_attention_v2_launcher_opt_tc(
......@@ -920,17 +1016,16 @@ void paged_attention_v2_launcher_opt_tc(
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
// float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
// float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
// T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
static float* exp_sums_ptr = nullptr;
static float* max_logits_ptr = nullptr;
static T* tmp_out_ptr = nullptr;
......@@ -943,7 +1038,7 @@ void paged_attention_v2_launcher_opt_tc(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 reduce_grid(num_heads, num_seqs);
if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2 && KV_DTYPE==vllm::Fp8KVCacheDataType::kAuto){
if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2){
constexpr int HEAD_SIZE=128;
int reusekv, num_thread,max_num_partitions,PARTITION_SIZE;
get_numberthread_and_reuse_kv_v2(num_thread,reusekv,PARTITION_SIZE,max_num_partitions,num_seqs,max_seq_len,num_heads,num_kv_heads,num_blocks);
......
......@@ -5,6 +5,7 @@
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "../quantization/int8_kvcache/quant_utils.cuh"
#include <hip/hip_bf16.h>
......@@ -71,6 +72,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
return VLLM_SHFL_SYNC(sum, 0);
}
// Compiler vector-extension aliases used by the masked tensor-core attention kernel.
// Four uint8_t lanes (4 bytes); holds four quantized KV-cache bytes.
using uint8x4_t = __attribute__( (__vector_size__(4 * sizeof(uint8_t)) )) uint8_t;
// Four _Float16 lanes (8 bytes).
using half4_t = __attribute__( (__vector_size__(4 * sizeof(_Float16)) )) _Float16;
// Four short lanes (8 bytes). NOTE(review): name suggests raw bf16 bit patterns - confirm against users.
using v4bh = __attribute__( (__vector_size__(4 * sizeof(short)) )) short;
// Four float lanes (16 bytes).
using float4_t = __attribute__( (__vector_size__(4 * sizeof(float)) )) float;
......@@ -78,12 +80,62 @@ using float2_t = __attribute__( (__vector_size__(2 * sizeof(float)) )) float;
// Two half4_t vectors stored back to back (8 sixteen-bit lanes in total),
// so that both can be fetched with a single reinterpret-cast load.
struct half4x2{
half4_t data[2];
};
// Four uint8x4_t vectors stored back to back (16 bytes in total); the
// quantized K path loads one of these per lane with a single
// reinterpret-cast load.
struct uint8x4x4{
uint8x4_t data[4];
};
// Plain pair of two scalar_t values stored contiguously.
template<typename scalar_t>
struct vec2data{
scalar_t data[2];
};
/**
 * Decode one 8-bit float (1 sign bit, 4 exponent bits, 3 mantissa bits,
 * exponent bias 7) into an fp32 value.
 *
 * Normal and subnormal inputs are converted exactly. Inputs whose magnitude
 * bits are all zero (0x00 / 0x80) decode to +0.0f / -0.0f.
 *
 * Encodings with every exponent and mantissa bit set (0x7F / 0xFF) are NOT
 * turned into NaN; they decode to the finite values +480 / -480.
 *
 * @param input  raw 8-bit encoding.
 * @return       the decoded value as float.
 */
inline __device__ float uint82float(const uint8_t& input) {
    // Move the byte into the top 8 bits so its sign bit sits on the fp32 sign bit.
    const uint32_t w = (uint32_t)input << 24;
    const uint32_t sign = w & UINT32_C(0x80000000);
    const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);

    // Subnormal inputs have more than 4 leading zeros after the sign bit is
    // cleared; shift them left until the leading one reaches the lowest
    // exponent bit, and compensate in the exponent below.
    uint32_t renorm_shift = __clz(nonsign);
    renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;

    // ">> 4" aligns the 4-bit exponent with the fp32 exponent field;
    // 0x78 (= 127 - 7) rebases the exponent from bias 7 to bias 127.
    const uint32_t magnitude =
        (nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23);

    // With no magnitude bits set, the renormalisation above would fabricate
    // an exponent of 0x78 - 28 (i.e. 2^-35); force an exact signed zero instead.
    const uint32_t result = sign | (nonsign != 0 ? magnitude : UINT32_C(0));
    return c10::detail::fp32_from_bits(result);
}
/**
 * Dequantize four cached 8-bit values into a packed vector of four 16-bit
 * floating-point lanes.
 *
 * @tparam is_half  true: lanes are written as _Float16;
 *                  false: lanes are written as __nv_bfloat16 into the same
 *                  8 bytes of storage.
 * @tparam is_fp8   true: each byte is decoded with uint82float();
 *                  false: each byte is an unsigned integer with zero point
 *                  128, i.e. (byte - 128).
 * @param x      four quantized bytes.
 * @param scale  multiplier applied to every decoded lane.
 * @return       the four dequantized lanes packed in a half4_t.
 */
template<bool is_half,bool is_fp8>
inline __device__ half4_t int8x4_to_half4(uint8x4_t x, const float scale) {
    half4_t packed;
    #pragma unroll
    for (int lane = 0; lane < 4; ++lane) {
        // Step 1: decode the byte to fp32 and apply the scale.
        float value;
        if constexpr (is_fp8) {
            value = uint82float(x[lane]) * scale;
        } else {
            value = (x[lane] - 128.0f) * scale;
        }
        // Step 2: narrow to the requested 16-bit format.
        if constexpr (is_half) {
            packed[lane] = value;
        } else {
            reinterpret_cast<__nv_bfloat16*>(&packed)[lane] = __float2bfloat16(value);
        }
    }
    return packed;
}
template<bool is_half>
inline __device__ void float4_2_half4(half4_t& dst,const float4_t& src)
{
......@@ -148,7 +200,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank,
const float* k_scale_ptr, const float* v_scale_ptr, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0,int PARTITION_SIZE=0) {
......@@ -161,6 +213,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
const bool USE_PARTITIONING = PARTITION_SIZE<num_seq_blocks * BLOCK_SIZE && PARTITION_SIZE>0;
if (partition_idx * PARTITION_SIZE >= seq_len) return;
constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value;
constexpr bool is_fp8 = (KV_DTYPE==Fp8KVCacheDataType::kFp8E4M3);
static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS");
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
......@@ -177,6 +230,8 @@ __global__ void paged_attention_kernel_TC_with_mask(
const int lane = thread_idx % WARP_SIZE;
const int rowid = lane%16;
const int rows = lane/16;
const float k_scale=*k_scale_ptr;
const float v_scale=*v_scale_ptr;
const int num_queries_per_kv = num_heads / num_kv_heads;
const int num_blocks_per_kv = ((num_queries_per_kv + REUSE_KV_TIMES -1) / REUSE_KV_TIMES);
......@@ -226,17 +281,19 @@ __global__ void paged_attention_kernel_TC_with_mask(
__shared__ float s_max[REUSE_KV_TIMES][NUM_WARPS];
__shared__ float s_logit[NUM_WARPS];
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
const cache_t* k_ptr_base = k_cache+kv_head_idx * kv_head_stride+lane*8;
const cache_t* k_ptr_base = k_cache+kv_head_idx * kv_head_stride+lane*x;
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;block_idx += NUM_WARPS) {
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
const cache_t* k_ptr=k_ptr_base + physical_block_number * kv_block_stride;
float4_t qk_vec={0,0,0,0};
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
half4x2 k_vec[2];
k_vec[0]=*reinterpret_cast<const half4x2*>(k_ptr);
#pragma unroll
for(int i=0;i<3;i++){
if(rowid<q_boundary)q_vec=q_vecs[rowid][i*4+rows];
k_vec[1-i%2]=*reinterpret_cast<const half4x2*>(k_ptr+(i+1)*512);
k_vec[1-i%2]=*reinterpret_cast<const half4x2*>(k_ptr+(i+1)*WARP_SIZE*x);
builtin_amdgcn_mmac<is_half>(k_vec[i%2].data[0],q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(k_vec[i%2].data[1],q_vec.data[1],qk_vec);
}
......@@ -246,6 +303,23 @@ __global__ void paged_attention_kernel_TC_with_mask(
builtin_amdgcn_mmac<is_half>(k_vec[1].data[0],q_vec.data[0],qk_vec);
v_mmac_f32_16x16x16_f16<is_half>(k_vec[1].data[1],q_vec.data[1],qk_vec);
}
}
else{
uint8x4x4 k_quant=*reinterpret_cast<const uint8x4x4*>(k_ptr);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[0],k_scale),q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[1],k_scale),q_vec.data[1],qk_vec);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows+1];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[2],k_scale),q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[3],k_scale),q_vec.data[1],qk_vec);
k_quant=*reinterpret_cast<const uint8x4x4*>(k_ptr+WARP_SIZE*x);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows+8];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[0],k_scale),q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[1],k_scale),q_vec.data[1],qk_vec);
if(rowid<q_boundary)q_vec=q_vecs[rowid][2*rows+9];
builtin_amdgcn_mmac<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[2],k_scale),q_vec.data[0],qk_vec);
v_mmac_f32_16x16x16_f16<is_half>(int8x4_to_half4<is_half,is_fp8>(k_quant.data[3],k_scale),q_vec.data[1],qk_vec);
}
#pragma unroll
for(int i=0;i<reuse_group;i++){
int reuse_kv_idx=rows+i*4;
......@@ -353,7 +427,14 @@ __global__ void paged_attention_kernel_TC_with_mask(
#pragma unroll
for(int k=0;k<4;k++){
int offset=i*1024+k*256;
half4_t v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
half4_t v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
}
else {
uint8x4_t quant_v = *reinterpret_cast<const uint8x4_t*>(v_ptr + offset);
v_vec=int8x4_to_half4<is_half,is_fp8>(quant_v,v_scale);
}
if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
......@@ -449,7 +530,14 @@ __global__ void paged_attention_kernel_TC_with_mask(
#pragma unroll
for(int k=0;k<4;k++){
int offset=i*1024+k*256;
half4_t v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
half4_t v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
}
else {
uint8x4_t quant_v = *reinterpret_cast<const uint8x4_t*>(v_ptr + offset);
v_vec=int8x4_to_half4<is_half,is_fp8>(quant_v,v_scale);
}
if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
......@@ -554,7 +642,14 @@ __global__ void paged_attention_kernel_TC_with_mask(
#pragma unroll
for(int k=0;k<4;k++){
int offset=i*1024+k*256;
half4_t v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
half4_t v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto){
v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
}
else {
uint8x4_t quant_v = *reinterpret_cast<const uint8x4_t*>(v_ptr + offset);
v_vec=int8x4_to_half4<is_half,is_fp8>(quant_v,v_scale);
}
if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
......@@ -831,14 +926,12 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
// float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
// float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
// T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
static float* exp_sums_ptr = nullptr;
static float* max_logits_ptr = nullptr;
static T* tmp_out_ptr = nullptr;
......@@ -851,7 +944,7 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 reduce_grid(num_heads, num_seqs);
if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2 && KV_DTYPE==vllm::Fp8KVCacheDataType::kAuto){
if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2){
constexpr int HEAD_SIZE=128;
int reusekv, num_thread,max_num_partitions,PARTITION_SIZE;
get_numberthread_and_reuse_kv_v2(num_thread,reusekv,PARTITION_SIZE,max_num_partitions,num_seqs,max_seq_len,num_heads,num_kv_heads,num_blocks);
......
......@@ -130,10 +130,7 @@ class PagedAttention:
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
kvquant = False
if (kv_cache_dtype == "int8"):
kvquant = True
if use_tc and head_size==128 and not kvquant:
if use_tc and head_size==128:
if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA V1 SIZE:")
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
......
......@@ -77,6 +77,12 @@ logger = init_logger(__name__)
gpuname = torch.cuda.get_device_properties(torch.cuda.current_device()).name
SUPPORT_TC = gpuname.startswith('K100_AI') or gpuname.startswith('BW')
def _generate_random_int8(
tensor: torch.Tensor,
) -> None:
tensor = torch.randint(0, 255, tensor.size(),dtype=torch.uint8,device='cuda')
# Exception strings for non-implemented encoder/decoder scenarios
# Reminder: Please update docs/source/features/compatibility_matrix.md
......@@ -798,7 +804,7 @@ def create_kv_caches_with_random_flash(
elif cache_dtype == 'fp8':
_generate_random_fp8(key_value_cache, -scale, scale)
elif cache_dtype == 'int8':
_generate_random_int8(value_cache)
_generate_random_int8(key_value_cache)
else:
raise ValueError(
f"Does not support key cache of type {cache_dtype}")
......@@ -841,7 +847,7 @@ def create_kv_caches_with_random(
elif cache_dtype == 'fp8':
_generate_random_fp8(key_cache, -scale, scale)
elif cache_dtype == 'int8':
_generate_random_int8(key_value_cache)
_generate_random_int8(key_cache)
else:
raise ValueError(
f"Does not support key cache of type {cache_dtype}")
......@@ -858,7 +864,7 @@ def create_kv_caches_with_random(
elif cache_dtype == 'fp8':
_generate_random_fp8(value_cache, -scale, scale)
elif cache_dtype == 'int8':
_generate_random_int8(key_cache)
_generate_random_int8(value_cache)
else:
raise ValueError(
f"Does not support value cache of type {cache_dtype}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment