Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2b8d795b
Commit
2b8d795b
authored
Sep 11, 2025
by
zhuwenwen
Browse files
add rocm merge_attn_states
parent
f572ca96
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
5 additions
and
3 deletions
+5
-3
csrc/ops.h
csrc/ops.h
+2
-1
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+2
-1
vllm/attention/ops/merge_attn_states.py
vllm/attention/ops/merge_attn_states.py
+1
-1
No files found.
csrc/ops.h
View file @
2b8d795b
...
...
@@ -172,7 +172,6 @@ void paged_attention_v2_opt_tc_with_mask(
const
int64_t
attn_masks_stride
=
0
);
#ifndef USE_ROCM
void
merge_attn_states
(
torch
::
Tensor
&
output
,
std
::
optional
<
torch
::
Tensor
>
output_lse
,
const
torch
::
Tensor
&
prefix_output
,
...
...
@@ -180,6 +179,8 @@ void merge_attn_states(torch::Tensor& output,
const
torch
::
Tensor
&
suffix_output
,
const
torch
::
Tensor
&
suffix_lse
);
#ifndef USE_ROCM
void
convert_vertical_slash_indexes
(
torch
::
Tensor
&
block_count
,
// [BATCH, N_HEADS, NUM_ROWS]
torch
::
Tensor
&
block_offset
,
// [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
...
...
csrc/torch_bindings.cpp
View file @
2b8d795b
...
...
@@ -220,7 +220,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int attn_masks_stride) -> ()"
);
ops
.
impl
(
"paged_attention_v2_opt_tc_with_mask"
,
torch
::
kCUDA
,
&
paged_attention_v2_opt_tc_with_mask
);
#ifndef USE_ROCM
// Merge attn states
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
...
...
@@ -234,6 +233,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor suffix_lse) -> ()"
);
ops
.
impl
(
"merge_attn_states"
,
torch
::
kCUDA
,
&
merge_attn_states
);
#ifndef USE_ROCM
ops
.
def
(
"convert_vertical_slash_indexes("
" Tensor! block_count, Tensor! block_offset, "
...
...
vllm/attention/ops/merge_attn_states.py
View file @
2b8d795b
...
...
@@ -31,7 +31,7 @@ def merge_attn_states(
return
headdim
%
4
==
0
return
headdim
%
8
==
0
if
(
current_platform
.
is_cuda
()
and
supported_dtypes
(
output
)
if
(
current_platform
.
is_cuda
()
or
current_platform
.
is_rocm
()
and
supported_dtypes
(
output
)
and
supported_headdim
(
output
)):
from
vllm._custom_ops
import
merge_attn_states
return
merge_attn_states
(
output
,
prefix_output
,
prefix_lse
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment