"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5b1af9ab8244484c5fc28a89fdbb1557106d897f"
Unverified Commit 12ef7e3b authored by DefTruth's avatar DefTruth Committed by GitHub
Browse files

bugfix: fix merge_state_v2 cuda graph (#5419)

parent 838fa0f2
...@@ -121,18 +121,18 @@ __global__ void merge_attn_states_kernel( ...@@ -121,18 +121,18 @@ __global__ void merge_attn_states_kernel(
} \ } \
} }
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \ #define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
{ \ { \
merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block>>>( \ merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block, 0, stream>>>( \
reinterpret_cast<scalar_t*>(output.data_ptr()), \ reinterpret_cast<scalar_t*>(output.data_ptr()), \
reinterpret_cast<float*>(output_lse.data_ptr()), \ reinterpret_cast<float*>(output_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \ reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
reinterpret_cast<float*>(prefix_lse.data_ptr()), \ reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \ reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
reinterpret_cast<float*>(suffix_lse.data_ptr()), \ reinterpret_cast<float*>(suffix_lse.data_ptr()), \
num_tokens, \ num_tokens, \
num_heads, \ num_heads, \
head_size); \ head_size); \
} }
/*@brief Merges the attention states from prefix and suffix /*@brief Merges the attention states from prefix and suffix
...@@ -170,6 +170,9 @@ void merge_attn_states_launcher( ...@@ -170,6 +170,9 @@ void merge_attn_states_launcher(
dim3 block(NUM_THREADS); dim3 block(NUM_THREADS);
dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS); dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
auto stream = at::cuda::getCurrentCUDAStream();
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS); LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment