Commit 259da693 authored by zhangshao's avatar zhangshao
Browse files

增加4倍的临时空间,解决大batch 大seq时,临时空间不足访存越界的情况

parent e7a963f5
......@@ -918,8 +918,8 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_v2_reduce_kern
max_num_partitions,PARTITION_SIZE); \
}else if(max_num_partitions>64){ \
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, 128>), \
dim3(reduce_grid), dim3(128), reduce_shared_mem_size, stream, out_ptr, \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, NUM_THREADS>), \
dim3(reduce_grid), dim3(NUM_THREADS), reduce_shared_mem_size, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions,PARTITION_SIZE);}
......@@ -1054,9 +1054,9 @@ void paged_attention_v2_launcher_opt_tc(
static float* max_logits_ptr = nullptr;
static T* tmp_out_ptr = nullptr;
if(exp_sums_ptr == nullptr){
hipMalloc(&exp_sums_ptr, 1000000); // 1m
hipMalloc(&max_logits_ptr, 1000000); // 1m
hipMalloc(&tmp_out_ptr, 100000000); // 100m
hipMalloc(&exp_sums_ptr, 10000000); // 10m
hipMalloc(&max_logits_ptr, 10000000); // 10m
hipMalloc(&tmp_out_ptr, 400000000); // 400m
}
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
......@@ -1066,7 +1066,7 @@ void paged_attention_v2_launcher_opt_tc(
if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2){
constexpr int HEAD_SIZE=128;
int reusekv, num_thread,max_num_partitions,PARTITION_SIZE=512;
if(!is_half)PARTITION_SIZE=256;
if(!is_half&&max_seq_len<=8192)PARTITION_SIZE=256;
get_numberthread_and_reuse_kv_v2(num_thread,reusekv,PARTITION_SIZE,max_num_partitions,num_seqs,max_seq_len,num_heads,num_kv_heads,num_blocks);
if(PA_PARTITION_SIZE!=0){
PARTITION_SIZE=PA_PARTITION_SIZE;
......@@ -1076,7 +1076,6 @@ void paged_attention_v2_launcher_opt_tc(
if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE;
if(PA_USE_V1!=0)max_num_partitions=1;
if(max_num_partitions==1)PARTITION_SIZE=max_seq_len;
assert(num_seqs*num_heads*max_num_partitions*head_size<=100000000);
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
REUSEKV_SWITCH(reusekv,[&] {
NUM_THREADS_SWITCH(num_thread , [&] {
......
......@@ -907,8 +907,8 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_v2_reduce_kern
max_num_partitions,PARTITION_SIZE); \
}else if(max_num_partitions>64){ \
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, 128>), \
dim3(reduce_grid), dim3(128), reduce_shared_mem_size, stream, out_ptr, \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, NUM_THREADS>), \
dim3(reduce_grid), dim3(NUM_THREADS), reduce_shared_mem_size, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions,PARTITION_SIZE);}
......@@ -959,9 +959,9 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
static float* max_logits_ptr = nullptr;
static T* tmp_out_ptr = nullptr;
if(exp_sums_ptr == nullptr){
hipMalloc(&exp_sums_ptr, 1000000); // 1m
hipMalloc(&max_logits_ptr, 1000000); // 1m
hipMalloc(&tmp_out_ptr, 100000000); // 100m
hipMalloc(&exp_sums_ptr, 10000000); // 10m
hipMalloc(&max_logits_ptr, 10000000); // 10m
hipMalloc(&tmp_out_ptr, 400000000); // 400m
}
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
......@@ -971,7 +971,7 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2){
constexpr int HEAD_SIZE=128;
int reusekv, num_thread,max_num_partitions,PARTITION_SIZE=512;
if(!is_half)PARTITION_SIZE=256;
if(!is_half&&max_seq_len<=8192)PARTITION_SIZE=256;
get_numberthread_and_reuse_kv_v2(num_thread,reusekv,PARTITION_SIZE,max_num_partitions,num_seqs,max_seq_len,num_heads,num_kv_heads,num_blocks);
if(PA_PARTITION_SIZE!=0){
PARTITION_SIZE=PA_PARTITION_SIZE;
......@@ -981,7 +981,6 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE;
if(PA_USE_V1!=0)max_num_partitions=1;
if(max_num_partitions==1)PARTITION_SIZE=max_seq_len;
assert(num_seqs*num_heads*max_num_partitions*head_size<=100000000);
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
REUSEKV_SWITCH(reusekv,[&] {
NUM_THREADS_SWITCH(num_thread , [&] {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment