Commit d074a953 authored by zhuwenwen's avatar zhuwenwen
Browse files

解决大batch长seq崩溃问题

parent 04629132
......@@ -1054,9 +1054,9 @@ void paged_attention_v2_launcher_opt_tc(
static float* max_logits_ptr = nullptr;
static T* tmp_out_ptr = nullptr;
if(exp_sums_ptr == nullptr){
hipMalloc(&exp_sums_ptr, 10000000); // 10m
hipMalloc(&max_logits_ptr, 10000000); // 10m
hipMalloc(&tmp_out_ptr, 400000000); // 400m
hipMalloc(&exp_sums_ptr, 1000000); // 1m
hipMalloc(&max_logits_ptr, 1000000); // 1m
hipMalloc(&tmp_out_ptr, 100000000); // 100m
}
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
......@@ -1183,4 +1183,4 @@ void paged_attention_v1_opt_tc(
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
#undef DIVIDE_ROUND_UP
\ No newline at end of file
......@@ -959,9 +959,9 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
static float* max_logits_ptr = nullptr;
static T* tmp_out_ptr = nullptr;
if(exp_sums_ptr == nullptr){
hipMalloc(&exp_sums_ptr, 10000000); // 10m
hipMalloc(&max_logits_ptr, 10000000); // 10m
hipMalloc(&tmp_out_ptr, 400000000); // 400m
hipMalloc(&exp_sums_ptr, 1000000); // 1m
hipMalloc(&max_logits_ptr, 1000000); // 1m
hipMalloc(&tmp_out_ptr, 100000000); // 100m
}
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment