Commit bac73e5d authored by zhangshao
Browse files

解决pa v1 tc 部分 size bug (Fix the shared-memory size bug in the paged-attention v1 TC path)

parent 2c7f740a
...@@ -302,6 +302,7 @@ __device__ void paged_attention_kernel_TC( ...@@ -302,6 +302,7 @@ __device__ void paged_attention_kernel_TC(
} }
} }
} }
// if(blockIdx.x==0)printf("%d,qkmax=%f\n",threadIdx.x,qk_max[0]);
// Perform reduction across the threads in the same warp to get the // Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet). // max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value. // The 0-th thread of each thread group already has its max qk value.
...@@ -401,33 +402,30 @@ __device__ void paged_attention_kernel_TC( ...@@ -401,33 +402,30 @@ __device__ void paged_attention_kernel_TC(
} }
} }
__syncthreads(); __syncthreads();
using floatV_t = __attribute__( (__vector_size__(NUM_ROWS_PER_THREAD * sizeof(float)) )) float;
// Perform reduction across warps. // Perform reduction across warps.
for(int reuse_kv_idx=0; reuse_kv_idx<q_boundary; reuse_kv_idx++) { for(int reuse_kv_idx=0; reuse_kv_idx<q_boundary; reuse_kv_idx++) {
float* out_smem = reinterpret_cast<float*>(shared_mem); if constexpr (NUM_THREADS>64){
#pragma unroll floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem);
for (int i = NUM_WARPS; i > 1; i /= 2) { #pragma unroll
int mid = i / 2; for (int i = NUM_WARPS; i > 1; i /= 2) {
// Upper warps write to shared memory. int mid = i / 2;
if (warp_idx >= mid && warp_idx < i) { // Upper warps write to shared memory.
float* dst = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE)+(warp_idx - mid) * HEAD_SIZE]; if (warp_idx >= mid && warp_idx < i) {
#pragma unroll out_smem[(warp_idx - mid) * 64+lane]=*(floatV_t*)(accs[reuse_kv_idx]);
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane + i * WARP_SIZE;
dst[row_idx] = accs[reuse_kv_idx][i];
} }
} __syncthreads();
__syncthreads(); // Lower warps update the output.
if (warp_idx < mid) {
// Lower warps update the output. floatV_t tmp=out_smem[thread_idx];
if (warp_idx < mid) { #pragma unroll
const float* src = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE)+warp_idx * HEAD_SIZE]; for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
#pragma unroll accs[reuse_kv_idx][i] += tmp[i];
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { }
const int row_idx = lane + i * WARP_SIZE;
accs[reuse_kv_idx][i] += src[row_idx];
} }
__syncthreads();
} }
__syncthreads();
} }
// Write the final output. // Write the final output.
if (warp_idx == 0) { if (warp_idx == 0) {
...@@ -487,37 +485,35 @@ __device__ void paged_attention_kernel_TC( ...@@ -487,37 +485,35 @@ __device__ void paged_attention_kernel_TC(
} }
} }
} }
} }
__syncthreads(); if constexpr (NUM_THREADS>64){
// Perform reduction across warps. __syncthreads();
for(int reuse_kv_idx=0; reuse_kv_idx<GROUPS; reuse_kv_idx++) { using floatV_t = __attribute__( (__vector_size__(NUM_ROWS_PER_THREAD * sizeof(float)) )) float;
float* out_smem = reinterpret_cast<float*>(shared_mem); // Perform reduction across warps.
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) { for(int reuse_kv_idx=0; reuse_kv_idx<GROUPS; reuse_kv_idx++) {
int mid = i / 2;
// Upper warps write to shared memory. floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem);
if (warp_idx >= mid && warp_idx < i) { #pragma unroll
float* dst = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE)+(warp_idx - mid) * HEAD_SIZE]; for (int i = NUM_WARPS; i > 1; i /= 2) {
#pragma unroll int mid = i / 2;
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { // Upper warps write to shared memory.
const int row_idx = lane + i * WARP_SIZE; if (warp_idx >= mid && warp_idx < i) {
dst[row_idx] = accs[reuse_kv_idx][i]; out_smem[(warp_idx - mid) * 64+lane]=*(floatV_t*)(accs[reuse_kv_idx]);
} }
} __syncthreads();
__syncthreads(); // Lower warps update the output.
if (warp_idx < mid) { if (warp_idx < mid) {
const float* src = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE)+warp_idx * HEAD_SIZE]; floatV_t tmp=out_smem[thread_idx];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane + i * WARP_SIZE; accs[reuse_kv_idx][i] += tmp[i];
accs[reuse_kv_idx][i] += src[row_idx]; }
} }
__syncthreads();
} }
__syncthreads();
} }
// Write the final output.
} }
if (warp_idx == 0) { if (warp_idx == 0) {
for(int g=0;g<reuse_group;g++){ for(int g=0;g<reuse_group;g++){
int reusekvid=g*4+rows; int reusekvid=g*4+rows;
...@@ -842,16 +838,10 @@ void paged_attention_v1_launcher_opt_tc( ...@@ -842,16 +838,10 @@ void paged_attention_v1_launcher_opt_tc(
NUM_THREADS_SWITCH(num_thread , [&] { NUM_THREADS_SWITCH(num_thread , [&] {
//constexpr int NUM_THREADS = WARP_SIZE * REUSE_KV_TIMES; //constexpr int NUM_THREADS = WARP_SIZE * REUSE_KV_TIMES;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES * padded_max_seq_len * 2; int logits_size = REUSE_KV_TIMES * padded_max_seq_len * 2;
int outputs_size = REUSE_KV_TIMES * (NUM_WARPS / 2) * head_size * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
if(REUSE_KV_TIMES==1)outputs_size=0; if (NUM_WARPS==64)outputs_size=0;
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here!
int shared_mem_size = ::max(logits_size, outputs_size); int shared_mem_size = ::max(logits_size, outputs_size);
if(num_heads == num_kv_heads) shared_mem_size = ::max(12 * 1024, shared_mem_size);
// int shared_mem_size = ::max(31*1024, ::max(logits_size, outputs_size));
// std::cout<<"shared_mem_size = "<<shared_mem_size<<std::endl;
// printf("REUSE_KV_TIMES=%d,use_vmac=%d\n",REUSE_KV_TIMES,(int)use_vmac);
dim3 grid((num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads, 1,num_seqs); dim3 grid((num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads, 1,num_seqs);
dim3 block(NUM_THREADS); dim3 block(NUM_THREADS);
if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d\n", if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d\n",
...@@ -1054,7 +1044,7 @@ void paged_attention_v2_launcher_opt_tc( ...@@ -1054,7 +1044,7 @@ void paged_attention_v2_launcher_opt_tc(
NUM_THREADS_SWITCH(num_thread , [&] { NUM_THREADS_SWITCH(num_thread , [&] {
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 2; int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 2;
int outputs_size = REUSE_KV_TIMES*(NUM_WARPS / 2) * head_size * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
dim3 grid; dim3 grid;
grid.x = (num_heads/num_kv_heads + REUSE_KV_TIMES -1)/REUSE_KV_TIMES * num_kv_heads; grid.x = (num_heads/num_kv_heads + REUSE_KV_TIMES -1)/REUSE_KV_TIMES * num_kv_heads;
grid.y = max_num_partitions; grid.y = max_num_partitions;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment