Commit ea79ca42 authored by zhangshao
Browse files

解决cudagraph模式下,小seq大batch PA变慢的bug

parent 82e8ca03
...@@ -19,6 +19,23 @@ typedef __hip_bfloat16 __nv_bfloat16; ...@@ -19,6 +19,23 @@ typedef __hip_bfloat16 __nv_bfloat16;
#define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
// Compile-time map from a boolean flag to an element type: AccType<true>::type
// is uint16_t and AccType<false>::type is float.
// The kernel and launcher in this file alias the result as ACC_TYPE, keyed on
// `is_half` (scalar type == uint16_t), and use it for the `logits` buffer
// pointer and for sizing that buffer.
// NOTE(review): uint16_t is presumably the raw storage for fp16 values --
// confirm against the to_float/from_float overloads.
template <bool IsHalf>
struct AccType {
  using type = std::conditional_t<IsHalf, uint16_t, float>;
};

// Shorthand for AccType<is_half>::type.
template <bool is_half>
using __acc_type = typename AccType<is_half>::type;
std::string get_device_name() std::string get_device_name()
{ {
hipDeviceProp_t props{}; hipDeviceProp_t props{};
...@@ -230,6 +247,7 @@ __global__ void paged_attention_kernel_TC( ...@@ -230,6 +247,7 @@ __global__ void paged_attention_kernel_TC(
if (partition_idx * PARTITION_SIZE >= seq_len) return; if (partition_idx * PARTITION_SIZE >= seq_len) return;
constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value; constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value;
constexpr bool is_fp8 = (KV_DTYPE==Fp8KVCacheDataType::kFp8E4M3); constexpr bool is_fp8 = (KV_DTYPE==Fp8KVCacheDataType::kFp8E4M3);
using ACC_TYPE = __acc_type<is_half>;
static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS"); static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS");
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE; const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
...@@ -292,7 +310,7 @@ __global__ void paged_attention_kernel_TC( ...@@ -292,7 +310,7 @@ __global__ void paged_attention_kernel_TC(
} }
__syncthreads(); __syncthreads();
extern __shared__ char shared_mem[]; extern __shared__ char shared_mem[];
float* logits = reinterpret_cast<float*>(shared_mem); ACC_TYPE* logits = reinterpret_cast<ACC_TYPE*>(shared_mem);
// __shared__ float red_smem[2 * NUM_WARPS]; // __shared__ float red_smem[2 * NUM_WARPS];
__shared__ float s_max[REUSE_KV_TIMES][NUM_WARPS]; __shared__ float s_max[REUSE_KV_TIMES][NUM_WARPS];
__shared__ float s_logit[NUM_WARPS]; __shared__ float s_logit[NUM_WARPS];
...@@ -350,11 +368,9 @@ __global__ void paged_attention_kernel_TC( ...@@ -350,11 +368,9 @@ __global__ void paged_attention_kernel_TC(
qk_vec[i] += alibi; qk_vec[i] += alibi;
} }
const bool mask = (token_idx >= seq_len); const bool mask = (token_idx >= seq_len);
if(mask){ if(mask) from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f);
logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] = 0.f;
}
else{ else{
logits[partition_size*reuse_kv_idx+token_idx - start_token_idx]=qk_vec[i]; from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , qk_vec[i]);
qk_max[i] = fmaxf(qk_max[i], qk_vec[i]); qk_max[i] = fmaxf(qk_max[i], qk_vec[i]);
} }
} }
...@@ -387,15 +403,15 @@ __global__ void paged_attention_kernel_TC( ...@@ -387,15 +403,15 @@ __global__ void paged_attention_kernel_TC(
} }
qk_max_tmp = __shfl(qk_max_tmp, 0); qk_max_tmp = __shfl(qk_max_tmp, 0);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[(reuse_kv_idx * partition_size) + i] - qk_max_tmp); float val = __expf(to_float(logits[(reuse_kv_idx * partition_size) + i]) - qk_max_tmp);
logits[(reuse_kv_idx * partition_size) + i] = val; from_float(logits[(reuse_kv_idx * partition_size) + i] , val);
exp_sum += val; exp_sum += val;
} }
exp_sum = block_sum<NUM_WARPS>(s_logit, exp_sum); exp_sum = block_sum<NUM_WARPS>(s_logit, exp_sum);
// Compute softmax. // Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
logits[(reuse_kv_idx * partition_size) + i] = logits[(reuse_kv_idx * partition_size) + i]*inv_sum; from_float(logits[(reuse_kv_idx * partition_size) + i] ,to_float(logits[(reuse_kv_idx * partition_size) + i])*inv_sum);
} }
if(USE_PARTITIONING&&thread_idx == 0){ if(USE_PARTITIONING&&thread_idx == 0){
max_out[reuse_kv_idx] = qk_max_tmp; max_out[reuse_kv_idx] = qk_max_tmp;
...@@ -423,10 +439,13 @@ __global__ void paged_attention_kernel_TC( ...@@ -423,10 +439,13 @@ __global__ void paged_attention_kernel_TC(
const int token_idx = block_idx * BLOCK_SIZE +rows*4; const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0}; half4_t logits_vec={0,0,0,0};
if(rowid<4*q_boundary){ if(rowid<4*q_boundary){
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx); if constexpr(is_half) logits_vec=*reinterpret_cast<half4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec); else{
for(int i=0;i<4;i++){ auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx);
from_float(p[i],f_logits[i]); scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
} }
} }
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
...@@ -526,10 +545,13 @@ __global__ void paged_attention_kernel_TC( ...@@ -526,10 +545,13 @@ __global__ void paged_attention_kernel_TC(
const int token_idx = block_idx * BLOCK_SIZE +rows*4; const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0}; half4_t logits_vec={0,0,0,0};
if(rowid<q_boundary){ if(rowid<q_boundary){
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx); if constexpr(is_half) logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec); else{
for(int i=0;i<4;i++){ auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
from_float(p[i],f_logits[i]); scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
} }
} }
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
...@@ -638,10 +660,13 @@ __global__ void paged_attention_kernel_TC( ...@@ -638,10 +660,13 @@ __global__ void paged_attention_kernel_TC(
const int token_idx = block_idx * BLOCK_SIZE +rows*4; const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0}; half4_t logits_vec={0,0,0,0};
if(rowid<q_boundary){ if(rowid<q_boundary){
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx); if constexpr(is_half) logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec); else{
for(int i=0;i<4;i++){ auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
from_float(p[i],f_logits[i]); scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
} }
} }
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
...@@ -904,7 +929,6 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO ...@@ -904,7 +929,6 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
{ {
reusekv=1; reusekv=1;
num_thread=256; num_thread=256;
PARTITION_SIZE=512;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
if(max_seq_len==8192&&num_blocks==1024){//ali test if(max_seq_len==8192&&num_blocks==1024){//ali test
if(batchsize==1&&qheads==16&&kvheads==16){num_thread=128;return;} if(batchsize==1&&qheads==16&&kvheads==16){num_thread=128;return;}
...@@ -1037,10 +1061,12 @@ void paged_attention_v2_launcher_opt_tc( ...@@ -1037,10 +1061,12 @@ void paged_attention_v2_launcher_opt_tc(
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 reduce_grid(num_heads, num_seqs); dim3 reduce_grid(num_heads, num_seqs);
constexpr bool is_half = std::is_same<T, uint16_t>::value;
using ACC_TYPE = __acc_type<is_half>;
if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2){ if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2){
constexpr int HEAD_SIZE=128; constexpr int HEAD_SIZE=128;
int reusekv, num_thread,max_num_partitions,PARTITION_SIZE; int reusekv, num_thread,max_num_partitions,PARTITION_SIZE=512;
if(!is_half)PARTITION_SIZE=256;
get_numberthread_and_reuse_kv_v2(num_thread,reusekv,PARTITION_SIZE,max_num_partitions,num_seqs,max_seq_len,num_heads,num_kv_heads,num_blocks); get_numberthread_and_reuse_kv_v2(num_thread,reusekv,PARTITION_SIZE,max_num_partitions,num_seqs,max_seq_len,num_heads,num_kv_heads,num_blocks);
if(PA_PARTITION_SIZE!=0){ if(PA_PARTITION_SIZE!=0){
PARTITION_SIZE=PA_PARTITION_SIZE; PARTITION_SIZE=PA_PARTITION_SIZE;
...@@ -1055,7 +1081,7 @@ void paged_attention_v2_launcher_opt_tc( ...@@ -1055,7 +1081,7 @@ void paged_attention_v2_launcher_opt_tc(
REUSEKV_SWITCH(reusekv,[&] { REUSEKV_SWITCH(reusekv,[&] {
NUM_THREADS_SWITCH(num_thread , [&] { NUM_THREADS_SWITCH(num_thread , [&] {
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 4; int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * sizeof(ACC_TYPE);
if(max_num_partitions==1)PARTITION_SIZE=0; if(max_num_partitions==1)PARTITION_SIZE=0;
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
dim3 grid; dim3 grid;
......
...@@ -19,6 +19,21 @@ typedef __hip_bfloat16 __nv_bfloat16; ...@@ -19,6 +19,21 @@ typedef __hip_bfloat16 __nv_bfloat16;
#define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
// Boolean-indexed type selector: the `true` case yields uint16_t, the `false`
// case yields float.
// In this file the masked kernel and its launcher instantiate it with
// `is_half` (scalar type == uint16_t) and name the result ACC_TYPE, which
// types the `logits` buffer pointer and scales the buffer's byte size.
// NOTE(review): uint16_t presumably carries fp16 bit patterns here -- verify
// against the to_float/from_float helpers.
template <bool UseHalfStorage>
struct AccType {
  using type = std::conditional_t<UseHalfStorage, uint16_t, float>;
};

// Convenience alias so callers can write __acc_type<flag> instead of
// typename AccType<flag>::type.
template <bool is_half>
using __acc_type = typename AccType<is_half>::type;
std::string get_device_name(); std::string get_device_name();
static const std::string device_name=get_device_name(); static const std::string device_name=get_device_name();
...@@ -214,6 +229,7 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -214,6 +229,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
if (partition_idx * PARTITION_SIZE >= seq_len) return; if (partition_idx * PARTITION_SIZE >= seq_len) return;
constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value; constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value;
constexpr bool is_fp8 = (KV_DTYPE==Fp8KVCacheDataType::kFp8E4M3); constexpr bool is_fp8 = (KV_DTYPE==Fp8KVCacheDataType::kFp8E4M3);
using ACC_TYPE = __acc_type<is_half>;
static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS"); static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS");
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE; const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
...@@ -276,7 +292,7 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -276,7 +292,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
} }
__syncthreads(); __syncthreads();
extern __shared__ char shared_mem[]; extern __shared__ char shared_mem[];
float* logits = reinterpret_cast<float*>(shared_mem); ACC_TYPE* logits = reinterpret_cast<ACC_TYPE*>(shared_mem);
// __shared__ float red_smem[2 * NUM_WARPS]; // __shared__ float red_smem[2 * NUM_WARPS];
__shared__ float s_max[REUSE_KV_TIMES][NUM_WARPS]; __shared__ float s_max[REUSE_KV_TIMES][NUM_WARPS];
__shared__ float s_logit[NUM_WARPS]; __shared__ float s_logit[NUM_WARPS];
...@@ -341,11 +357,9 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -341,11 +357,9 @@ __global__ void paged_attention_kernel_TC_with_mask(
} }
} }
const bool mask = (token_idx >= seq_len); const bool mask = (token_idx >= seq_len);
if(mask){ if(mask) from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f);
logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] = 0.f;
}
else{ else{
logits[partition_size*reuse_kv_idx+token_idx - start_token_idx]=qk_vec[i]; from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , qk_vec[i]);
qk_max[i] = fmaxf(qk_max[i], qk_vec[i]); qk_max[i] = fmaxf(qk_max[i], qk_vec[i]);
} }
} }
...@@ -378,15 +392,15 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -378,15 +392,15 @@ __global__ void paged_attention_kernel_TC_with_mask(
} }
qk_max_tmp = __shfl(qk_max_tmp, 0); qk_max_tmp = __shfl(qk_max_tmp, 0);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[(reuse_kv_idx * partition_size) + i] - qk_max_tmp); float val = __expf(to_float(logits[(reuse_kv_idx * partition_size) + i]) - qk_max_tmp);
logits[(reuse_kv_idx * partition_size) + i] = val; from_float(logits[(reuse_kv_idx * partition_size) + i] , val);
exp_sum += val; exp_sum += val;
} }
exp_sum = block_sum<NUM_WARPS>(s_logit, exp_sum); exp_sum = block_sum<NUM_WARPS>(s_logit, exp_sum);
// Compute softmax. // Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
logits[(reuse_kv_idx * partition_size) + i] = logits[(reuse_kv_idx * partition_size) + i]*inv_sum; from_float(logits[(reuse_kv_idx * partition_size) + i] ,to_float(logits[(reuse_kv_idx * partition_size) + i])*inv_sum);
} }
if(USE_PARTITIONING&&thread_idx == 0){ if(USE_PARTITIONING&&thread_idx == 0){
max_out[reuse_kv_idx] = qk_max_tmp; max_out[reuse_kv_idx] = qk_max_tmp;
...@@ -414,10 +428,13 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -414,10 +428,13 @@ __global__ void paged_attention_kernel_TC_with_mask(
const int token_idx = block_idx * BLOCK_SIZE +rows*4; const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0}; half4_t logits_vec={0,0,0,0};
if(rowid<4*q_boundary){ if(rowid<4*q_boundary){
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx); if constexpr(is_half) logits_vec=*reinterpret_cast<half4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec); else{
for(int i=0;i<4;i++){ auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx);
from_float(p[i],f_logits[i]); scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
} }
} }
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
...@@ -517,10 +534,13 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -517,10 +534,13 @@ __global__ void paged_attention_kernel_TC_with_mask(
const int token_idx = block_idx * BLOCK_SIZE +rows*4; const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0}; half4_t logits_vec={0,0,0,0};
if(rowid<q_boundary){ if(rowid<q_boundary){
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx); if constexpr(is_half) logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec); else{
for(int i=0;i<4;i++){ auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
from_float(p[i],f_logits[i]); scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
} }
} }
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
...@@ -629,10 +649,13 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -629,10 +649,13 @@ __global__ void paged_attention_kernel_TC_with_mask(
const int token_idx = block_idx * BLOCK_SIZE +rows*4; const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0}; half4_t logits_vec={0,0,0,0};
if(rowid<q_boundary){ if(rowid<q_boundary){
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx); if constexpr(is_half) logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec); else{
for(int i=0;i<4;i++){ auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
from_float(p[i],f_logits[i]); scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
} }
} }
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
...@@ -943,10 +966,12 @@ void paged_attention_v2_launcher_opt_tc_with_mask( ...@@ -943,10 +966,12 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 reduce_grid(num_heads, num_seqs); dim3 reduce_grid(num_heads, num_seqs);
constexpr bool is_half = std::is_same<T, uint16_t>::value;
using ACC_TYPE = __acc_type<is_half>;
if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2){ if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2){
constexpr int HEAD_SIZE=128; constexpr int HEAD_SIZE=128;
int reusekv, num_thread,max_num_partitions,PARTITION_SIZE; int reusekv, num_thread,max_num_partitions,PARTITION_SIZE=512;
if(!is_half)PARTITION_SIZE=256;
get_numberthread_and_reuse_kv_v2(num_thread,reusekv,PARTITION_SIZE,max_num_partitions,num_seqs,max_seq_len,num_heads,num_kv_heads,num_blocks); get_numberthread_and_reuse_kv_v2(num_thread,reusekv,PARTITION_SIZE,max_num_partitions,num_seqs,max_seq_len,num_heads,num_kv_heads,num_blocks);
if(PA_PARTITION_SIZE!=0){ if(PA_PARTITION_SIZE!=0){
PARTITION_SIZE=PA_PARTITION_SIZE; PARTITION_SIZE=PA_PARTITION_SIZE;
...@@ -961,7 +986,7 @@ void paged_attention_v2_launcher_opt_tc_with_mask( ...@@ -961,7 +986,7 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
REUSEKV_SWITCH(reusekv,[&] { REUSEKV_SWITCH(reusekv,[&] {
NUM_THREADS_SWITCH(num_thread , [&] { NUM_THREADS_SWITCH(num_thread , [&] {
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 4; int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * sizeof(ACC_TYPE);
if(max_num_partitions==1)PARTITION_SIZE=0; if(max_num_partitions==1)PARTITION_SIZE=0;
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
dim3 grid; dim3 grid;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment