Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
12ef7e3b
Unverified
Commit
12ef7e3b
authored
Apr 16, 2025
by
DefTruth
Committed by
GitHub
Apr 15, 2025
Browse files
bugfix: fix merge_state_v2 cuda graph (#5419)
parent
838fa0f2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
12 deletions
+15
-12
sgl-kernel/csrc/attention/merge_attn_states.cu
sgl-kernel/csrc/attention/merge_attn_states.cu
+15
-12
No files found.
sgl-kernel/csrc/attention/merge_attn_states.cu
View file @
12ef7e3b
...
...
@@ -121,18 +121,18 @@ __global__ void merge_attn_states_kernel(
} \
}
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
{ \
merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block>>>( \
reinterpret_cast<scalar_t*>(output.data_ptr()), \
reinterpret_cast<float*>(output_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
reinterpret_cast<float*>(suffix_lse.data_ptr()), \
num_tokens, \
num_heads, \
head_size); \
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS)
\
{
\
merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block
, 0, stream
>>>( \
reinterpret_cast<scalar_t*>(output.data_ptr()),
\
reinterpret_cast<float*>(output_lse.data_ptr()),
\
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()),
\
reinterpret_cast<float*>(prefix_lse.data_ptr()),
\
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()),
\
reinterpret_cast<float*>(suffix_lse.data_ptr()),
\
num_tokens,
\
num_heads,
\
head_size);
\
}
/*@brief Merges the attention states from prefix and suffix
...
...
@@ -170,6 +170,9 @@ void merge_attn_states_launcher(
dim3
block
(
NUM_THREADS
);
dim3
grid
((
total_threads
+
NUM_THREADS
-
1
)
/
NUM_THREADS
);
const
c10
::
cuda
::
OptionalCUDAGuard
device_guard
(
prefix_output
.
device
());
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
LAUNCH_MERGE_ATTN_STATES
(
scalar_t
,
NUM_THREADS
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment