Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e82ee40d
Unverified
Commit
e82ee40d
authored
Apr 16, 2025
by
DefTruth
Committed by
GitHub
Apr 16, 2025
Browse files
[Bugfix][Kernel] fix potential cuda graph broken for merge_attn_states kernel (#16693)
Signed-off-by:
DefTruth
<
qiustudent_r@163.com
>
parent
facbe2a1
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
10 deletions
+15
-10
csrc/attention/merge_attn_states.cu
csrc/attention/merge_attn_states.cu
+15
-10
No files found.
csrc/attention/merge_attn_states.cu
View file @
e82ee40d
...
@@ -107,7 +107,8 @@ __global__ void merge_attn_states_kernel(
...
@@ -107,7 +107,8 @@ __global__ void merge_attn_states_kernel(
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
{ \
{ \
vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block>>>( \
vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
...
@@ -122,10 +123,10 @@ __global__ void merge_attn_states_kernel(
...
@@ -122,10 +123,10 @@ __global__ void merge_attn_states_kernel(
* @param output [n,h,d] The output tensor to store the merged attention states.
* @param output [n,h,d] The output tensor to store the merged attention states.
* @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
* @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
* @param prefix_output [n,h,d] The prefix attention states.
* @param prefix_output [n,h,d] The prefix attention states.
* @param prefix_lse [h,
d
] The log-sum-exp values for the prefix attention
* @param prefix_lse [h,
n
] The log-sum-exp values for the prefix attention
* states.
* states.
* @param suffix_output [n,h,d] The suffix attention states.
* @param suffix_output [n,h,d] The suffix attention states.
* @param suffix_lse [h,
d
] The log-sum-exp values for the suffix attention
* @param suffix_lse [h,
n
] The log-sum-exp values for the suffix attention
* states.
* states.
*/
*/
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
...
@@ -146,13 +147,17 @@ void merge_attn_states_launcher(torch::Tensor& output,
...
@@ -146,13 +147,17 @@ void merge_attn_states_launcher(torch::Tensor& output,
if
(
output_lse
.
has_value
())
{
if
(
output_lse
.
has_value
())
{
output_lse_ptr
=
output_lse
.
value
().
data_ptr
<
float
>
();
output_lse_ptr
=
output_lse
.
value
().
data_ptr
<
float
>
();
}
}
// process one pack elements per thread. float -> 4, half/bf16 -> 8
// Process one pack elements per thread. for float, the
// pack_size is 4 for half/bf16, the pack_size is 8.
const
uint
threads_per_head
=
head_size
/
pack_size
;
const
uint
threads_per_head
=
head_size
/
pack_size
;
const
uint
total_threads
=
num_tokens
*
num_heads
*
threads_per_head
;
const
uint
total_threads
=
num_tokens
*
num_heads
*
threads_per_head
;
dim3
block
(
NUM_THREADS
);
dim3
block
(
NUM_THREADS
);
dim3
grid
((
total_threads
+
NUM_THREADS
-
1
)
/
NUM_THREADS
);
dim3
grid
((
total_threads
+
NUM_THREADS
-
1
)
/
NUM_THREADS
);
const
c10
::
cuda
::
OptionalCUDAGuard
device_guard
(
prefix_output
.
device
());
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
LAUNCH_MERGE_ATTN_STATES
(
scalar_t
,
NUM_THREADS
);
LAUNCH_MERGE_ATTN_STATES
(
scalar_t
,
NUM_THREADS
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment