Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d9d342d2
Unverified
Commit
d9d342d2
authored
Nov 26, 2025
by
Pleaplusone
Committed by
GitHub
Nov 26, 2025
Browse files
[Performance][MLA][ROCm] Remove redundant D2D copy in deepseek (#27457)
Signed-off-by:
ganyi
<
ygan@amd.com
>
parent
53d7f1f6
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
49 additions
and
41 deletions
+49
-41
csrc/attention/merge_attn_states.cu
csrc/attention/merge_attn_states.cu
+12
-15
csrc/ops.h
csrc/ops.h
+1
-2
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+1
-2
vllm/attention/ops/triton_merge_attn_states.py
vllm/attention/ops/triton_merge_attn_states.py
+17
-6
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+18
-16
No files found.
csrc/attention/merge_attn_states.cu
View file @
d9d342d2
...
@@ -16,7 +16,8 @@ __global__ void merge_attn_states_kernel(
...
@@ -16,7 +16,8 @@ __global__ void merge_attn_states_kernel(
scalar_t
*
output
,
float
*
output_lse
,
const
scalar_t
*
prefix_output
,
scalar_t
*
output
,
float
*
output_lse
,
const
scalar_t
*
prefix_output
,
const
float
*
prefix_lse
,
const
scalar_t
*
suffix_output
,
const
float
*
prefix_lse
,
const
scalar_t
*
suffix_output
,
const
float
*
suffix_lse
,
const
uint
num_tokens
,
const
uint
num_heads
,
const
float
*
suffix_lse
,
const
uint
num_tokens
,
const
uint
num_heads
,
const
uint
head_size
)
{
const
uint
head_size
,
const
uint
prefix_head_stride
,
const
uint
output_head_stride
)
{
using
pack_128b_t
=
uint4
;
using
pack_128b_t
=
uint4
;
const
uint
pack_size
=
16
/
sizeof
(
scalar_t
);
const
uint
pack_size
=
16
/
sizeof
(
scalar_t
);
const
uint
threads_per_head
=
head_size
/
pack_size
;
const
uint
threads_per_head
=
head_size
/
pack_size
;
...
@@ -34,11 +35,13 @@ __global__ void merge_attn_states_kernel(
...
@@ -34,11 +35,13 @@ __global__ void merge_attn_states_kernel(
const
uint
head_idx
=
token_head_idx
%
num_heads
;
const
uint
head_idx
=
token_head_idx
%
num_heads
;
const
uint
pack_offset
=
pack_idx
*
pack_size
;
// (0~15)*8, etc.
const
uint
pack_offset
=
pack_idx
*
pack_size
;
// (0~15)*8, etc.
const
uint
head_offset
=
const
uint
src_head_offset
=
token_idx
*
num_heads
*
prefix_head_stride
+
token_idx
*
num_heads
*
head_size
+
head_idx
*
head_size
;
head_idx
*
prefix_head_stride
;
const
scalar_t
*
prefix_head_ptr
=
prefix_output
+
head_offset
;
const
uint
dst_head_offset
=
token_idx
*
num_heads
*
output_head_stride
+
const
scalar_t
*
suffix_head_ptr
=
suffix_output
+
head_offset
;
head_idx
*
output_head_stride
;
scalar_t
*
output_head_ptr
=
output
+
head_offset
;
const
scalar_t
*
prefix_head_ptr
=
prefix_output
+
src_head_offset
;
const
scalar_t
*
suffix_head_ptr
=
suffix_output
+
src_head_offset
;
scalar_t
*
output_head_ptr
=
output
+
dst_head_offset
;
float
p_lse
=
prefix_lse
[
head_idx
*
num_tokens
+
token_idx
];
float
p_lse
=
prefix_lse
[
head_idx
*
num_tokens
+
token_idx
];
float
s_lse
=
suffix_lse
[
head_idx
*
num_tokens
+
token_idx
];
float
s_lse
=
suffix_lse
[
head_idx
*
num_tokens
+
token_idx
];
...
@@ -140,7 +143,7 @@ __global__ void merge_attn_states_kernel(
...
@@ -140,7 +143,7 @@ __global__ void merge_attn_states_kernel(
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
num_heads, head_size
);
\
num_heads, head_size
, prefix_head_stride, output_head_stride);
\
}
}
/*@brief Merges the attention states from prefix and suffix
/*@brief Merges the attention states from prefix and suffix
...
@@ -166,17 +169,11 @@ void merge_attn_states_launcher(torch::Tensor& output,
...
@@ -166,17 +169,11 @@ void merge_attn_states_launcher(torch::Tensor& output,
const
uint
num_tokens
=
output
.
size
(
0
);
const
uint
num_tokens
=
output
.
size
(
0
);
const
uint
num_heads
=
output
.
size
(
1
);
const
uint
num_heads
=
output
.
size
(
1
);
const
uint
head_size
=
output
.
size
(
2
);
const
uint
head_size
=
output
.
size
(
2
);
const
uint
prefix_head_stride
=
prefix_output
.
stride
(
1
);
const
uint
output_head_stride
=
output
.
stride
(
1
);
const
uint
pack_size
=
16
/
sizeof
(
scalar_t
);
const
uint
pack_size
=
16
/
sizeof
(
scalar_t
);
TORCH_CHECK
(
head_size
%
pack_size
==
0
,
TORCH_CHECK
(
head_size
%
pack_size
==
0
,
"headsize must be multiple of pack_size:"
,
pack_size
);
"headsize must be multiple of pack_size:"
,
pack_size
);
TORCH_CHECK
(
output
.
stride
(
-
2
)
==
head_size
&&
output
.
stride
(
-
1
)
==
1
,
"output heads must be contiguous in memory"
);
TORCH_CHECK
(
prefix_output
.
stride
(
-
2
)
==
head_size
&&
prefix_output
.
stride
(
-
1
)
==
1
,
"prefix_output heads must be contiguous in memory"
);
TORCH_CHECK
(
suffix_output
.
stride
(
-
2
)
==
head_size
&&
suffix_output
.
stride
(
-
1
)
==
1
,
"suffix_output heads must be contiguous in memory"
);
float
*
output_lse_ptr
=
nullptr
;
float
*
output_lse_ptr
=
nullptr
;
if
(
output_lse
.
has_value
())
{
if
(
output_lse
.
has_value
())
{
output_lse_ptr
=
output_lse
.
value
().
data_ptr
<
float
>
();
output_lse_ptr
=
output_lse
.
value
().
data_ptr
<
float
>
();
...
...
csrc/ops.h
View file @
d9d342d2
...
@@ -52,14 +52,13 @@ void paged_attention_v2(
...
@@ -52,14 +52,13 @@ void paged_attention_v2(
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
const
int64_t
blocksparse_head_sliding_step
);
#ifndef USE_ROCM
void
merge_attn_states
(
torch
::
Tensor
&
output
,
void
merge_attn_states
(
torch
::
Tensor
&
output
,
std
::
optional
<
torch
::
Tensor
>
output_lse
,
std
::
optional
<
torch
::
Tensor
>
output_lse
,
const
torch
::
Tensor
&
prefix_output
,
const
torch
::
Tensor
&
prefix_output
,
const
torch
::
Tensor
&
prefix_lse
,
const
torch
::
Tensor
&
prefix_lse
,
const
torch
::
Tensor
&
suffix_output
,
const
torch
::
Tensor
&
suffix_output
,
const
torch
::
Tensor
&
suffix_lse
);
const
torch
::
Tensor
&
suffix_lse
);
#ifndef USE_ROCM
void
convert_vertical_slash_indexes
(
void
convert_vertical_slash_indexes
(
torch
::
Tensor
&
block_count
,
// [BATCH, N_HEADS, NUM_ROWS]
torch
::
Tensor
&
block_count
,
// [BATCH, N_HEADS, NUM_ROWS]
torch
::
Tensor
&
block_offset
,
// [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch
::
Tensor
&
block_offset
,
// [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
...
...
csrc/torch_bindings.cpp
View file @
d9d342d2
...
@@ -63,7 +63,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -63,7 +63,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step) -> ()"
);
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v2"
,
torch
::
kCUDA
,
&
paged_attention_v2
);
ops
.
impl
(
"paged_attention_v2"
,
torch
::
kCUDA
,
&
paged_attention_v2
);
#ifndef USE_ROCM
// Merge attn states
// Merge attn states
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
// can be used to combine partial attention results (in the split-KV case)
...
@@ -76,7 +75,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -76,7 +75,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor suffix_output,"
" Tensor suffix_output,"
" Tensor suffix_lse) -> ()"
);
" Tensor suffix_lse) -> ()"
);
ops
.
impl
(
"merge_attn_states"
,
torch
::
kCUDA
,
&
merge_attn_states
);
ops
.
impl
(
"merge_attn_states"
,
torch
::
kCUDA
,
&
merge_attn_states
);
#ifndef USE_ROCM
ops
.
def
(
ops
.
def
(
"convert_vertical_slash_indexes("
"convert_vertical_slash_indexes("
" Tensor! block_count, Tensor! block_offset, "
" Tensor! block_count, Tensor! block_offset, "
...
...
vllm/attention/ops/triton_merge_attn_states.py
View file @
d9d342d2
...
@@ -20,7 +20,11 @@ def merge_attn_states(
...
@@ -20,7 +20,11 @@ def merge_attn_states(
num_query_heads
=
output
.
shape
[
1
]
num_query_heads
=
output
.
shape
[
1
]
head_size
=
output
.
shape
[
2
]
head_size
=
output
.
shape
[
2
]
padded_head_size
=
triton
.
next_power_of_2
(
head_size
)
padded_head_size
=
triton
.
next_power_of_2
(
head_size
)
# We assume the output stride on num_head is not always as same as the
# `suffix_output` and `prefix_output`, as them might be padded by the attention
# backend.
prefix_head_stride
=
prefix_output
.
stride
(
1
)
output_head_stride
=
output
.
stride
(
1
)
# TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
# TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
merge_attn_states_kernel
[(
num_tokens
,
num_query_heads
)](
merge_attn_states_kernel
[(
num_tokens
,
num_query_heads
)](
output
,
output
,
...
@@ -29,6 +33,8 @@ def merge_attn_states(
...
@@ -29,6 +33,8 @@ def merge_attn_states(
prefix_lse
,
prefix_lse
,
suffix_output
,
suffix_output
,
suffix_lse
,
suffix_lse
,
prefix_head_stride
,
output_head_stride
,
head_size
,
head_size
,
padded_head_size
,
padded_head_size
,
output_lse
is
not
None
,
output_lse
is
not
None
,
...
@@ -43,6 +49,8 @@ def merge_attn_states_kernel(
...
@@ -43,6 +49,8 @@ def merge_attn_states_kernel(
prefix_lse
,
# [NUM_HEADS, NUM_TOKENS]
prefix_lse
,
# [NUM_HEADS, NUM_TOKENS]
suffix_output
,
# [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
suffix_output
,
# [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
suffix_lse
,
# [NUM_HEADS, NUM_TOKENS]
suffix_lse
,
# [NUM_HEADS, NUM_TOKENS]
prefix_head_stride
,
output_head_stride
,
HEAD_SIZE
:
tl
.
constexpr
,
HEAD_SIZE
:
tl
.
constexpr
,
PADDED_HEAD_SIZE
:
tl
.
constexpr
,
PADDED_HEAD_SIZE
:
tl
.
constexpr
,
OUTPUT_LSE
:
tl
.
constexpr
,
OUTPUT_LSE
:
tl
.
constexpr
,
...
@@ -79,15 +87,15 @@ def merge_attn_states_kernel(
...
@@ -79,15 +87,15 @@ def merge_attn_states_kernel(
head_mask
=
head_arange
<
HEAD_SIZE
head_mask
=
head_arange
<
HEAD_SIZE
p_out
=
tl
.
load
(
p_out
=
tl
.
load
(
prefix_output
prefix_output
+
token_idx
*
num_heads
*
HEAD_SIZE
+
token_idx
*
num_heads
*
prefix_head_stride
+
head_idx
*
HEAD_SIZE
+
head_idx
*
prefix_head_stride
+
head_arange
,
+
head_arange
,
mask
=
head_mask
,
mask
=
head_mask
,
)
)
s_out
=
tl
.
load
(
s_out
=
tl
.
load
(
suffix_output
suffix_output
+
token_idx
*
num_heads
*
HEAD_SIZE
+
token_idx
*
num_heads
*
prefix_head_stride
+
head_idx
*
HEAD_SIZE
+
head_idx
*
prefix_head_stride
+
head_arange
,
+
head_arange
,
mask
=
head_mask
,
mask
=
head_mask
,
)
)
...
@@ -99,7 +107,10 @@ def merge_attn_states_kernel(
...
@@ -99,7 +107,10 @@ def merge_attn_states_kernel(
s_scale
=
s_se
/
out_se
s_scale
=
s_se
/
out_se
out
=
p_out
*
p_scale
+
s_out
*
s_scale
out
=
p_out
*
p_scale
+
s_out
*
s_scale
tl
.
store
(
tl
.
store
(
output
+
token_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
+
head_arange
,
output
+
token_idx
*
num_heads
*
output_head_stride
+
head_idx
*
output_head_stride
+
head_arange
,
out
,
out
,
mask
=
head_mask
,
mask
=
head_mask
,
)
)
vllm/v1/attention/backends/mla/common.py
View file @
d9d342d2
...
@@ -1238,15 +1238,13 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
...
@@ -1238,15 +1238,13 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
def
_v_up_proj
(
self
,
x
:
torch
.
Tensor
,
out
:
torch
.
Tensor
):
def
_v_up_proj
(
self
,
x
:
torch
.
Tensor
,
out
:
torch
.
Tensor
):
# Convert from (B, N, L) to (N, B, L)
# Convert from (B, N, L) to (N, B, L)
x
=
x
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
).
transpose
(
0
,
1
)
x
=
x
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
).
transpose
(
0
,
1
)
if
self
.
is_aiter_triton_fp8_bmm_enabled
:
if
self
.
is_aiter_triton_fp8_bmm_enabled
:
out
=
out
.
view
(
-
1
,
self
.
num_heads
,
self
.
v_head_dim
)
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
x
=
rocm_aiter_ops
.
triton_fp8_bmm
(
x
=
rocm_aiter_ops
.
triton_fp8_bmm
(
x
,
self
.
W_V
,
self
.
W_V_scale
,
group_size
=
128
,
transpose_bm
=
True
x
,
self
.
W_V
,
self
.
W_V_scale
,
group_size
=
128
,
transpose_bm
=
True
,
YQ
=
out
)
)
# Convert from (B, N, V) to (B, N * V)
x
=
x
.
reshape
(
-
1
,
self
.
num_heads
*
self
.
v_head_dim
)
# Copy result
out
.
copy_
(
x
)
else
:
else
:
# Convert from (B, N * V) to (N, B, V)
# Convert from (B, N * V) to (N, B, V)
out
=
out
.
view
(
-
1
,
self
.
num_heads
,
self
.
v_head_dim
).
transpose
(
0
,
1
)
out
=
out
.
view
(
-
1
,
self
.
num_heads
,
self
.
v_head_dim
).
transpose
(
0
,
1
)
...
@@ -1824,7 +1822,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
...
@@ -1824,7 +1822,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
MLACommonMetadata
,
attn_metadata
:
MLACommonMetadata
,
k_scale
:
torch
.
Tensor
,
k_scale
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
output
:
torch
.
Tensor
,
)
->
None
:
# TODO (zyongye): Prefill function here
# TODO (zyongye): Prefill function here
assert
attn_metadata
.
prefill
is
not
None
assert
attn_metadata
.
prefill
is
not
None
assert
self
.
dcp_world_size
is
not
None
assert
self
.
dcp_world_size
is
not
None
...
@@ -1837,7 +1836,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
...
@@ -1837,7 +1836,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
output
=
self
.
_run_prefill_new_tokens
(
output
_prefill
=
self
.
_run_prefill_new_tokens
(
prefill
=
attn_metadata
.
prefill
,
prefill
=
attn_metadata
.
prefill
,
q
=
q
,
q
=
q
,
k
=
k
,
k
=
k
,
...
@@ -1846,7 +1845,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
...
@@ -1846,7 +1845,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
)
if
has_context
:
if
has_context
:
suffix_output
,
suffix_lse
=
output
suffix_output
,
suffix_lse
=
output
_prefill
if
self
.
dcp_world_size
>
1
:
if
self
.
dcp_world_size
>
1
:
context_output
,
context_lse
=
(
context_output
,
context_lse
=
(
self
.
_context_parallel_compute_prefill_context
(
self
.
_context_parallel_compute_prefill_context
(
...
@@ -1862,7 +1861,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
...
@@ -1862,7 +1861,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
q
,
kv_c_and_k_pe_cache
,
attn_metadata
,
k_scale
q
,
kv_c_and_k_pe_cache
,
attn_metadata
,
k_scale
)
)
output
=
torch
.
empty_like
(
suffix_output
)
# unpad if necessary
if
self
.
_pad_v
:
context_output
=
context_output
[...,
:
v
.
shape
[
-
1
]]
suffix_output
=
suffix_output
[...,
:
v
.
shape
[
-
1
]]
output
=
output
.
view
(
-
1
,
self
.
num_heads
,
self
.
v_head_dim
)
merge_attn_states
(
merge_attn_states
(
output
=
output
,
output
=
output
,
prefix_output
=
context_output
,
prefix_output
=
context_output
,
...
@@ -1870,12 +1874,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
...
@@ -1870,12 +1874,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
suffix_output
=
suffix_output
,
suffix_output
=
suffix_output
,
suffix_lse
=
suffix_lse
,
suffix_lse
=
suffix_lse
,
)
)
else
:
# unpad if necessary
output_prefill
=
output_prefill
[...,
:
v
.
shape
[
-
1
]].
flatten
(
start_dim
=-
2
)
if
self
.
_pad_v
:
output
.
copy_
(
output_prefill
)
output
=
output
[...,
:
v
.
shape
[
-
1
]]
return
output
.
flatten
(
start_dim
=-
2
)
@
abstractmethod
@
abstractmethod
def
_forward_decode
(
def
_forward_decode
(
...
@@ -1970,13 +1971,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
...
@@ -1970,13 +1971,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
kv_cache
=
kv_cache
.
view
(
current_platform
.
fp8_dtype
())
kv_cache
=
kv_cache
.
view
(
current_platform
.
fp8_dtype
())
if
has_prefill
:
if
has_prefill
:
output
[
num_decode_tokens
:]
=
self
.
_forward_prefill
(
self
.
_forward_prefill
(
prefill_q
,
prefill_q
,
prefill_k_c_normed
,
prefill_k_c_normed
,
prefill_k_pe
,
prefill_k_pe
,
kv_cache
,
kv_cache
,
attn_metadata
,
attn_metadata
,
layer
.
_k_scale
,
layer
.
_k_scale
,
output
=
output
[
num_decode_tokens
:],
)
)
if
has_decode
:
if
has_decode
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment