Commit 903593d3 authored by zhuwenwen
Browse files

解决pa v1 tc 部分 size bug (Fix a bug in the paged-attention v1 TC path that affected some sizes)

parent 2009d4a1
......@@ -302,6 +302,7 @@ __device__ void paged_attention_kernel_TC(
}
}
}
// if(blockIdx.x==0)printf("%d,qkmax=%f\n",threadIdx.x,qk_max[0]);
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
......@@ -401,33 +402,30 @@ __device__ void paged_attention_kernel_TC(
}
}
__syncthreads();
using floatV_t = __attribute__( (__vector_size__(NUM_ROWS_PER_THREAD * sizeof(float)) )) float;
// Perform reduction across warps.
for(int reuse_kv_idx=0; reuse_kv_idx<q_boundary; reuse_kv_idx++) {
float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE)+(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane + i * WARP_SIZE;
dst[row_idx] = accs[reuse_kv_idx][i];
if constexpr (NUM_THREADS>64){
floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
out_smem[(warp_idx - mid) * 64+lane]=*(floatV_t*)(accs[reuse_kv_idx]);
}
}
__syncthreads();
// Lower warps update the output.
if (warp_idx < mid) {
const float* src = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE)+warp_idx * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane + i * WARP_SIZE;
accs[reuse_kv_idx][i] += src[row_idx];
__syncthreads();
// Lower warps update the output.
if (warp_idx < mid) {
floatV_t tmp=out_smem[thread_idx];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[reuse_kv_idx][i] += tmp[i];
}
}
__syncthreads();
}
__syncthreads();
}
// Write the final output.
if (warp_idx == 0) {
......@@ -487,37 +485,35 @@ __device__ void paged_attention_kernel_TC(
}
}
}
}
__syncthreads();
// Perform reduction across warps.
for(int reuse_kv_idx=0; reuse_kv_idx<GROUPS; reuse_kv_idx++) {
float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE)+(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane + i * WARP_SIZE;
dst[row_idx] = accs[reuse_kv_idx][i];
}
if constexpr (NUM_THREADS>64){
__syncthreads();
using floatV_t = __attribute__( (__vector_size__(NUM_ROWS_PER_THREAD * sizeof(float)) )) float;
// Perform reduction across warps.
for(int reuse_kv_idx=0; reuse_kv_idx<GROUPS; reuse_kv_idx++) {
floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
out_smem[(warp_idx - mid) * 64+lane]=*(floatV_t*)(accs[reuse_kv_idx]);
}
}
__syncthreads();
if (warp_idx < mid) {
const float* src = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE)+warp_idx * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane + i * WARP_SIZE;
accs[reuse_kv_idx][i] += src[row_idx];
__syncthreads();
// Lower warps update the output.
if (warp_idx < mid) {
floatV_t tmp=out_smem[thread_idx];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[reuse_kv_idx][i] += tmp[i];
}
}
__syncthreads();
}
__syncthreads();
}
// Write the final output.
}
if (warp_idx == 0) {
for(int g=0;g<reuse_group;g++){
int reusekvid=g*4+rows;
......@@ -842,16 +838,10 @@ void paged_attention_v1_launcher_opt_tc(
NUM_THREADS_SWITCH(num_thread , [&] {
//constexpr int NUM_THREADS = WARP_SIZE * REUSE_KV_TIMES;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES * padded_max_seq_len * 2;
int outputs_size = REUSE_KV_TIMES * (NUM_WARPS / 2) * head_size * sizeof(float);
if(REUSE_KV_TIMES==1)outputs_size=0;
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here!
int logits_size = REUSE_KV_TIMES * padded_max_seq_len * 2;
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
if (NUM_WARPS==64)outputs_size=0;
int shared_mem_size = ::max(logits_size, outputs_size);
if(num_heads == num_kv_heads) shared_mem_size = ::max(12 * 1024, shared_mem_size);
// int shared_mem_size = ::max(31*1024, ::max(logits_size, outputs_size));
// std::cout<<"shared_mem_size = "<<shared_mem_size<<std::endl;
// printf("REUSE_KV_TIMES=%d,use_vmac=%d\n",REUSE_KV_TIMES,(int)use_vmac);
dim3 grid((num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads, 1,num_seqs);
dim3 block(NUM_THREADS);
if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d\n",
......@@ -1054,7 +1044,7 @@ void paged_attention_v2_launcher_opt_tc(
NUM_THREADS_SWITCH(num_thread , [&] {
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 2;
int outputs_size = REUSE_KV_TIMES*(NUM_WARPS / 2) * head_size * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
dim3 grid;
grid.x = (num_heads/num_kv_heads + REUSE_KV_TIMES -1)/REUSE_KV_TIMES * num_kv_heads;
grid.y = max_num_partitions;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment