Commit d074a953 authored by zhuwenwen
Browse files

解决大batch长seq崩溃问题

parent 04629132
...@@ -1054,9 +1054,9 @@ void paged_attention_v2_launcher_opt_tc( ...@@ -1054,9 +1054,9 @@ void paged_attention_v2_launcher_opt_tc(
static float* max_logits_ptr = nullptr; static float* max_logits_ptr = nullptr;
static T* tmp_out_ptr = nullptr; static T* tmp_out_ptr = nullptr;
if(exp_sums_ptr == nullptr){ if(exp_sums_ptr == nullptr){
hipMalloc(&exp_sums_ptr, 10000000); // 10m hipMalloc(&exp_sums_ptr, 1000000); // 1m
hipMalloc(&max_logits_ptr, 10000000); // 10m hipMalloc(&max_logits_ptr, 1000000); // 1m
hipMalloc(&tmp_out_ptr, 400000000); // 400m hipMalloc(&tmp_out_ptr, 100000000); // 100m
} }
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
...@@ -1183,4 +1183,4 @@ void paged_attention_v1_opt_tc( ...@@ -1183,4 +1183,4 @@ void paged_attention_v1_opt_tc(
#undef WARP_SIZE #undef WARP_SIZE
#undef MAX #undef MAX
#undef MIN #undef MIN
#undef DIVIDE_ROUND_UP #undef DIVIDE_ROUND_UP
\ No newline at end of file
...@@ -959,9 +959,9 @@ void paged_attention_v2_launcher_opt_tc_with_mask( ...@@ -959,9 +959,9 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
static float* max_logits_ptr = nullptr; static float* max_logits_ptr = nullptr;
static T* tmp_out_ptr = nullptr; static T* tmp_out_ptr = nullptr;
if(exp_sums_ptr == nullptr){ if(exp_sums_ptr == nullptr){
hipMalloc(&exp_sums_ptr, 10000000); // 10m hipMalloc(&exp_sums_ptr, 1000000); // 1m
hipMalloc(&max_logits_ptr, 10000000); // 10m hipMalloc(&max_logits_ptr, 1000000); // 1m
hipMalloc(&tmp_out_ptr, 400000000); // 400m hipMalloc(&tmp_out_ptr, 100000000); // 100m
} }
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment