Commit ea79ca42 authored by zhangshao's avatar zhangshao
Browse files

解决cudagraph模式下,小seq大batch PA变慢的bug

parent 82e8ca03
......@@ -19,6 +19,23 @@ typedef __hip_bfloat16 __nv_bfloat16;
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
// Compile-time selection of a type from a boolean flag:
//   AccType<true>::type  is uint16_t
//   AccType<false>::type is float
// The primary template declares no members; both possible bool values are
// covered by the explicit specializations that follow.
// In this file the kernel sets the flag to
// std::is_same<scalar_t, uint16_t>::value and uses the selected type as the
// element type of the `logits` array carved out of dynamic shared memory; the
// launcher uses sizeof() of the same type when computing `logits_size`.
// NOTE(review): an identical trait appears to be defined in a sibling kernel
// source; the two definitions should stay token-identical (or move to a
// shared header) — confirm.
template <bool>
struct AccType {};
// Flag is true: elements are 16-bit (uint16_t).
template <>
struct AccType<true> {
using type = uint16_t;
};
// Flag is false: elements are float.
template <>
struct AccType<false> {
using type = float;
};
// Shorthand for AccType<is_half>::type.
// NOTE(review): identifiers containing a double underscore are reserved for
// the implementation in C++; the name is kept because existing code refers to
// it.
template<bool is_half>
using __acc_type = typename AccType<is_half>::type;
std::string get_device_name()
{
hipDeviceProp_t props{};
......@@ -230,6 +247,7 @@ __global__ void paged_attention_kernel_TC(
if (partition_idx * PARTITION_SIZE >= seq_len) return;
constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value;
constexpr bool is_fp8 = (KV_DTYPE==Fp8KVCacheDataType::kFp8E4M3);
using ACC_TYPE = __acc_type<is_half>;
static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS");
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
......@@ -292,7 +310,7 @@ __global__ void paged_attention_kernel_TC(
}
__syncthreads();
extern __shared__ char shared_mem[];
float* logits = reinterpret_cast<float*>(shared_mem);
ACC_TYPE* logits = reinterpret_cast<ACC_TYPE*>(shared_mem);
// __shared__ float red_smem[2 * NUM_WARPS];
__shared__ float s_max[REUSE_KV_TIMES][NUM_WARPS];
__shared__ float s_logit[NUM_WARPS];
......@@ -350,11 +368,9 @@ __global__ void paged_attention_kernel_TC(
qk_vec[i] += alibi;
}
const bool mask = (token_idx >= seq_len);
if(mask){
logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] = 0.f;
}
if(mask) from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f);
else{
logits[partition_size*reuse_kv_idx+token_idx - start_token_idx]=qk_vec[i];
from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , qk_vec[i]);
qk_max[i] = fmaxf(qk_max[i], qk_vec[i]);
}
}
......@@ -387,15 +403,15 @@ __global__ void paged_attention_kernel_TC(
}
qk_max_tmp = __shfl(qk_max_tmp, 0);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[(reuse_kv_idx * partition_size) + i] - qk_max_tmp);
logits[(reuse_kv_idx * partition_size) + i] = val;
float val = __expf(to_float(logits[(reuse_kv_idx * partition_size) + i]) - qk_max_tmp);
from_float(logits[(reuse_kv_idx * partition_size) + i] , val);
exp_sum += val;
}
exp_sum = block_sum<NUM_WARPS>(s_logit, exp_sum);
// Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
logits[(reuse_kv_idx * partition_size) + i] = logits[(reuse_kv_idx * partition_size) + i]*inv_sum;
from_float(logits[(reuse_kv_idx * partition_size) + i] ,to_float(logits[(reuse_kv_idx * partition_size) + i])*inv_sum);
}
if(USE_PARTITIONING&&thread_idx == 0){
max_out[reuse_kv_idx] = qk_max_tmp;
......@@ -423,10 +439,13 @@ __global__ void paged_attention_kernel_TC(
const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0};
if(rowid<4*q_boundary){
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
if constexpr(is_half) logits_vec=*reinterpret_cast<half4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx);
else{
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
}
}
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
......@@ -526,10 +545,13 @@ __global__ void paged_attention_kernel_TC(
const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0};
if(rowid<q_boundary){
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
if constexpr(is_half) logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
else{
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
}
}
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
......@@ -638,10 +660,13 @@ __global__ void paged_attention_kernel_TC(
const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0};
if(rowid<q_boundary){
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
if constexpr(is_half) logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
else{
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
}
}
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
......@@ -904,7 +929,6 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
{
reusekv=1;
num_thread=256;
PARTITION_SIZE=512;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
if(max_seq_len==8192&&num_blocks==1024){//ali test
if(batchsize==1&&qheads==16&&kvheads==16){num_thread=128;return;}
......@@ -1037,10 +1061,12 @@ void paged_attention_v2_launcher_opt_tc(
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 reduce_grid(num_heads, num_seqs);
constexpr bool is_half = std::is_same<T, uint16_t>::value;
using ACC_TYPE = __acc_type<is_half>;
if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2){
constexpr int HEAD_SIZE=128;
int reusekv, num_thread,max_num_partitions,PARTITION_SIZE;
int reusekv, num_thread,max_num_partitions,PARTITION_SIZE=512;
if(!is_half)PARTITION_SIZE=256;
get_numberthread_and_reuse_kv_v2(num_thread,reusekv,PARTITION_SIZE,max_num_partitions,num_seqs,max_seq_len,num_heads,num_kv_heads,num_blocks);
if(PA_PARTITION_SIZE!=0){
PARTITION_SIZE=PA_PARTITION_SIZE;
......@@ -1055,7 +1081,7 @@ void paged_attention_v2_launcher_opt_tc(
REUSEKV_SWITCH(reusekv,[&] {
NUM_THREADS_SWITCH(num_thread , [&] {
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 4;
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * sizeof(ACC_TYPE);
if(max_num_partitions==1)PARTITION_SIZE=0;
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
dim3 grid;
......
......@@ -19,6 +19,21 @@ typedef __hip_bfloat16 __nv_bfloat16;
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
// Compile-time selection of a type from a boolean flag:
//   AccType<true>::type  is uint16_t
//   AccType<false>::type is float
// The primary template declares no members; both possible bool values are
// covered by the explicit specializations that follow.
// In this file the masked kernel sets the flag to
// std::is_same<scalar_t, uint16_t>::value and uses the selected type as the
// element type of the `logits` array carved out of dynamic shared memory; the
// launcher uses sizeof() of the same type when computing `logits_size`.
// NOTE(review): an identical trait appears to be defined in a sibling kernel
// source; the two definitions should stay token-identical (or move to a
// shared header) — confirm.
template <bool>
struct AccType {};
// Flag is true: elements are 16-bit (uint16_t).
template <>
struct AccType<true> {
using type = uint16_t;
};
// Flag is false: elements are float.
template <>
struct AccType<false> {
using type = float;
};
// Shorthand for AccType<is_half>::type.
// NOTE(review): identifiers containing a double underscore are reserved for
// the implementation in C++; the name is kept because existing code refers to
// it.
template<bool is_half>
using __acc_type = typename AccType<is_half>::type;
std::string get_device_name();
static const std::string device_name=get_device_name();
......@@ -214,6 +229,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
if (partition_idx * PARTITION_SIZE >= seq_len) return;
constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value;
constexpr bool is_fp8 = (KV_DTYPE==Fp8KVCacheDataType::kFp8E4M3);
using ACC_TYPE = __acc_type<is_half>;
static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS");
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
......@@ -276,7 +292,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
}
__syncthreads();
extern __shared__ char shared_mem[];
float* logits = reinterpret_cast<float*>(shared_mem);
ACC_TYPE* logits = reinterpret_cast<ACC_TYPE*>(shared_mem);
// __shared__ float red_smem[2 * NUM_WARPS];
__shared__ float s_max[REUSE_KV_TIMES][NUM_WARPS];
__shared__ float s_logit[NUM_WARPS];
......@@ -341,11 +357,9 @@ __global__ void paged_attention_kernel_TC_with_mask(
}
}
const bool mask = (token_idx >= seq_len);
if(mask){
logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] = 0.f;
}
if(mask) from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f);
else{
logits[partition_size*reuse_kv_idx+token_idx - start_token_idx]=qk_vec[i];
from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , qk_vec[i]);
qk_max[i] = fmaxf(qk_max[i], qk_vec[i]);
}
}
......@@ -378,15 +392,15 @@ __global__ void paged_attention_kernel_TC_with_mask(
}
qk_max_tmp = __shfl(qk_max_tmp, 0);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[(reuse_kv_idx * partition_size) + i] - qk_max_tmp);
logits[(reuse_kv_idx * partition_size) + i] = val;
float val = __expf(to_float(logits[(reuse_kv_idx * partition_size) + i]) - qk_max_tmp);
from_float(logits[(reuse_kv_idx * partition_size) + i] , val);
exp_sum += val;
}
exp_sum = block_sum<NUM_WARPS>(s_logit, exp_sum);
// Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
logits[(reuse_kv_idx * partition_size) + i] = logits[(reuse_kv_idx * partition_size) + i]*inv_sum;
from_float(logits[(reuse_kv_idx * partition_size) + i] ,to_float(logits[(reuse_kv_idx * partition_size) + i])*inv_sum);
}
if(USE_PARTITIONING&&thread_idx == 0){
max_out[reuse_kv_idx] = qk_max_tmp;
......@@ -414,10 +428,13 @@ __global__ void paged_attention_kernel_TC_with_mask(
const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0};
if(rowid<4*q_boundary){
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
if constexpr(is_half) logits_vec=*reinterpret_cast<half4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx);
else{
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
}
}
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
......@@ -517,10 +534,13 @@ __global__ void paged_attention_kernel_TC_with_mask(
const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0};
if(rowid<q_boundary){
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
if constexpr(is_half) logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
else{
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
}
}
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
......@@ -629,10 +649,13 @@ __global__ void paged_attention_kernel_TC_with_mask(
const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0};
if(rowid<q_boundary){
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
if constexpr(is_half) logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
else{
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
}
}
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
......@@ -943,10 +966,12 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 reduce_grid(num_heads, num_seqs);
constexpr bool is_half = std::is_same<T, uint16_t>::value;
using ACC_TYPE = __acc_type<is_half>;
if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2){
constexpr int HEAD_SIZE=128;
int reusekv, num_thread,max_num_partitions,PARTITION_SIZE;
int reusekv, num_thread,max_num_partitions,PARTITION_SIZE=512;
if(!is_half)PARTITION_SIZE=256;
get_numberthread_and_reuse_kv_v2(num_thread,reusekv,PARTITION_SIZE,max_num_partitions,num_seqs,max_seq_len,num_heads,num_kv_heads,num_blocks);
if(PA_PARTITION_SIZE!=0){
PARTITION_SIZE=PA_PARTITION_SIZE;
......@@ -961,7 +986,7 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
REUSEKV_SWITCH(reusekv,[&] {
NUM_THREADS_SWITCH(num_thread , [&] {
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 4;
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * sizeof(ACC_TYPE);
if(max_num_partitions==1)PARTITION_SIZE=0;
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
dim3 grid;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment