Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
598190aa
Unverified
Commit
598190aa
authored
Mar 31, 2026
by
Olya Kozlova
Committed by
GitHub
Mar 31, 2026
Browse files
[fix] Remove trtllm ragged mla prefills (#36540)
Signed-off-by:
Olya Kozlova
<
okozlova@nvidia.com
>
parent
b779eb33
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
185 additions
and
35 deletions
+185
-35
csrc/attention/merge_attn_states.cu
csrc/attention/merge_attn_states.cu
+48
-19
csrc/ops.h
csrc/ops.h
+5
-6
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+2
-1
tests/kernels/attention/test_merge_attn_states.py
tests/kernels/attention/test_merge_attn_states.py
+26
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+8
-1
vllm/model_executor/layers/attention/mla_attention.py
vllm/model_executor/layers/attention/mla_attention.py
+14
-2
vllm/v1/attention/ops/merge_attn_states.py
vllm/v1/attention/ops/merge_attn_states.py
+43
-2
vllm/v1/attention/ops/triton_merge_attn_states.py
vllm/v1/attention/ops/triton_merge_attn_states.py
+39
-2
No files found.
csrc/attention/merge_attn_states.cu
View file @
598190aa
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include <algorithm>
#include <limits>
#include "attention_dtypes.h"
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "attention_utils.cuh"
...
@@ -17,7 +18,7 @@ __global__ void merge_attn_states_kernel(
...
@@ -17,7 +18,7 @@ __global__ void merge_attn_states_kernel(
const
float
*
prefix_lse
,
const
scalar_t
*
suffix_output
,
const
float
*
prefix_lse
,
const
scalar_t
*
suffix_output
,
const
float
*
suffix_lse
,
const
uint
num_tokens
,
const
uint
num_heads
,
const
float
*
suffix_lse
,
const
uint
num_tokens
,
const
uint
num_heads
,
const
uint
head_size
,
const
uint
prefix_head_stride
,
const
uint
head_size
,
const
uint
prefix_head_stride
,
const
uint
output_head_stride
)
{
const
uint
output_head_stride
,
const
uint
prefix_num_tokens
)
{
using
pack_128b_t
=
uint4
;
using
pack_128b_t
=
uint4
;
const
uint
pack_size
=
16
/
sizeof
(
scalar_t
);
const
uint
pack_size
=
16
/
sizeof
(
scalar_t
);
const
uint
threads_per_head
=
head_size
/
pack_size
;
const
uint
threads_per_head
=
head_size
/
pack_size
;
...
@@ -43,6 +44,22 @@ __global__ void merge_attn_states_kernel(
...
@@ -43,6 +44,22 @@ __global__ void merge_attn_states_kernel(
const
scalar_t
*
suffix_head_ptr
=
suffix_output
+
src_head_offset
;
const
scalar_t
*
suffix_head_ptr
=
suffix_output
+
src_head_offset
;
scalar_t
*
output_head_ptr
=
output
+
dst_head_offset
;
scalar_t
*
output_head_ptr
=
output
+
dst_head_offset
;
// If token_idx >= prefix_num_tokens, just copy from suffix
if
(
token_idx
>=
prefix_num_tokens
)
{
if
(
pack_offset
<
head_size
)
{
pack_128b_t
s_out_pack
=
reinterpret_cast
<
const
pack_128b_t
*>
(
suffix_head_ptr
)[
pack_offset
/
pack_size
];
reinterpret_cast
<
pack_128b_t
*>
(
output_head_ptr
)[
pack_offset
/
pack_size
]
=
s_out_pack
;
}
if
(
output_lse
!=
nullptr
&&
pack_idx
==
0
)
{
float
s_lse
=
suffix_lse
[
head_idx
*
num_tokens
+
token_idx
];
output_lse
[
head_idx
*
num_tokens
+
token_idx
]
=
s_lse
;
}
return
;
}
// For tokens within prefix range, merge prefix and suffix
float
p_lse
=
prefix_lse
[
head_idx
*
num_tokens
+
token_idx
];
float
p_lse
=
prefix_lse
[
head_idx
*
num_tokens
+
token_idx
];
float
s_lse
=
suffix_lse
[
head_idx
*
num_tokens
+
token_idx
];
float
s_lse
=
suffix_lse
[
head_idx
*
num_tokens
+
token_idx
];
p_lse
=
std
::
isinf
(
p_lse
)
?
-
std
::
numeric_limits
<
float
>::
infinity
()
:
p_lse
;
p_lse
=
std
::
isinf
(
p_lse
)
?
-
std
::
numeric_limits
<
float
>::
infinity
()
:
p_lse
;
...
@@ -143,7 +160,8 @@ __global__ void merge_attn_states_kernel(
...
@@ -143,7 +160,8 @@ __global__ void merge_attn_states_kernel(
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
num_heads, head_size, prefix_head_stride, output_head_stride); \
num_heads, head_size, prefix_head_stride, output_head_stride, \
prefix_num_tokens); \
}
}
/*@brief Merges the attention states from prefix and suffix
/*@brief Merges the attention states from prefix and suffix
...
@@ -157,14 +175,18 @@ __global__ void merge_attn_states_kernel(
...
@@ -157,14 +175,18 @@ __global__ void merge_attn_states_kernel(
* @param suffix_output [n,h,d] The suffix attention states.
* @param suffix_output [n,h,d] The suffix attention states.
* @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
* @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
* states.
* states.
* @param prefill_tokens_with_context Number of prefill tokens with context
* For the first p tokens (0 <= token_idx < prefill_tokens_with_context), output
* is computed by merging prefix_output and suffix_output. For remaining tokens
* (prefill_tokens_with_context <= token_idx < n), output is copied directly
* from suffix_output.
*/
*/
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
void
merge_attn_states_launcher
(
torch
::
Tensor
&
output
,
void
merge_attn_states_launcher
(
std
::
optional
<
torch
::
Tensor
>
output_lse
,
torch
::
Tensor
&
output
,
std
::
optional
<
torch
::
Tensor
>
output_lse
,
const
torch
::
Tensor
&
prefix_output
,
const
torch
::
Tensor
&
prefix_output
,
const
torch
::
Tensor
&
prefix_lse
,
const
torch
::
Tensor
&
prefix_lse
,
const
torch
::
Tensor
&
suffix_output
,
const
torch
::
Tensor
&
suffix_lse
,
const
torch
::
Tensor
&
suffix_output
,
const
std
::
optional
<
int64_t
>
prefill_tokens_with_context
)
{
const
torch
::
Tensor
&
suffix_lse
)
{
constexpr
uint
NUM_THREADS
=
128
;
constexpr
uint
NUM_THREADS
=
128
;
const
uint
num_tokens
=
output
.
size
(
0
);
const
uint
num_tokens
=
output
.
size
(
0
);
const
uint
num_heads
=
output
.
size
(
1
);
const
uint
num_heads
=
output
.
size
(
1
);
...
@@ -174,6 +196,14 @@ void merge_attn_states_launcher(torch::Tensor& output,
...
@@ -174,6 +196,14 @@ void merge_attn_states_launcher(torch::Tensor& output,
const
uint
pack_size
=
16
/
sizeof
(
scalar_t
);
const
uint
pack_size
=
16
/
sizeof
(
scalar_t
);
TORCH_CHECK
(
head_size
%
pack_size
==
0
,
TORCH_CHECK
(
head_size
%
pack_size
==
0
,
"headsize must be multiple of pack_size:"
,
pack_size
);
"headsize must be multiple of pack_size:"
,
pack_size
);
const
uint
prefix_num_tokens
=
prefill_tokens_with_context
.
has_value
()
?
static_cast
<
uint
>
(
prefill_tokens_with_context
.
value
())
:
num_tokens
;
TORCH_CHECK
(
prefix_num_tokens
<=
num_tokens
,
"prefix_num_tokens must be <= num_tokens"
);
float
*
output_lse_ptr
=
nullptr
;
float
*
output_lse_ptr
=
nullptr
;
if
(
output_lse
.
has_value
())
{
if
(
output_lse
.
has_value
())
{
output_lse_ptr
=
output_lse
.
value
().
data_ptr
<
float
>
();
output_lse_ptr
=
output_lse
.
value
().
data_ptr
<
float
>
();
...
@@ -194,16 +224,15 @@ void merge_attn_states_launcher(torch::Tensor& output,
...
@@ -194,16 +224,15 @@ void merge_attn_states_launcher(torch::Tensor& output,
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
{ \
{ \
merge_attn_states_launcher<scalar_t>(
output, output_lse, prefix_output,
\
merge_attn_states_launcher<scalar_t>(
\
prefix_lse, suffix_output,
\
output, output_lse, prefix_output,
prefix_lse, suffix_output, \
suffix_lse);
\
suffix_lse, prefill_tokens_with_context);
\
}
}
void
merge_attn_states
(
torch
::
Tensor
&
output
,
void
merge_attn_states
(
std
::
optional
<
torch
::
Tensor
>
output_lse
,
torch
::
Tensor
&
output
,
std
::
optional
<
torch
::
Tensor
>
output_lse
,
const
torch
::
Tensor
&
prefix_output
,
const
torch
::
Tensor
&
prefix_output
,
const
torch
::
Tensor
&
prefix_lse
,
const
torch
::
Tensor
&
prefix_lse
,
const
torch
::
Tensor
&
suffix_output
,
const
torch
::
Tensor
&
suffix_lse
,
const
torch
::
Tensor
&
suffix_output
,
std
::
optional
<
int64_t
>
prefill_tokens_with_context
=
std
::
nullopt
)
{
const
torch
::
Tensor
&
suffix_lse
)
{
DISPATCH_BY_SCALAR_DTYPE
(
output
.
dtype
(),
CALL_MERGE_ATTN_STATES_LAUNCHER
);
DISPATCH_BY_SCALAR_DTYPE
(
output
.
dtype
(),
CALL_MERGE_ATTN_STATES_LAUNCHER
);
}
}
csrc/ops.h
View file @
598190aa
...
@@ -53,12 +53,11 @@ void paged_attention_v2(
...
@@ -53,12 +53,11 @@ void paged_attention_v2(
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
const
int64_t
blocksparse_head_sliding_step
);
void
merge_attn_states
(
torch
::
Tensor
&
output
,
void
merge_attn_states
(
std
::
optional
<
torch
::
Tensor
>
output_lse
,
torch
::
Tensor
&
output
,
std
::
optional
<
torch
::
Tensor
>
output_lse
,
const
torch
::
Tensor
&
prefix_output
,
const
torch
::
Tensor
&
prefix_output
,
const
torch
::
Tensor
&
prefix_lse
,
const
torch
::
Tensor
&
prefix_lse
,
const
torch
::
Tensor
&
suffix_output
,
const
torch
::
Tensor
&
suffix_lse
,
const
torch
::
Tensor
&
suffix_output
,
const
std
::
optional
<
int64_t
>
prefill_tokens_with_context
);
const
torch
::
Tensor
&
suffix_lse
);
#ifndef USE_ROCM
#ifndef USE_ROCM
void
convert_vertical_slash_indexes
(
void
convert_vertical_slash_indexes
(
torch
::
Tensor
&
block_count
,
// [BATCH, N_HEADS, NUM_ROWS]
torch
::
Tensor
&
block_count
,
// [BATCH, N_HEADS, NUM_ROWS]
...
...
csrc/torch_bindings.cpp
View file @
598190aa
...
@@ -73,7 +73,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -73,7 +73,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor prefix_output,"
" Tensor prefix_output,"
" Tensor prefix_lse,"
" Tensor prefix_lse,"
" Tensor suffix_output,"
" Tensor suffix_output,"
" Tensor suffix_lse) -> ()"
);
" Tensor suffix_lse,"
" int!? prefill_tokens_with_context) -> ()"
);
ops
.
impl
(
"merge_attn_states"
,
torch
::
kCUDA
,
&
merge_attn_states
);
ops
.
impl
(
"merge_attn_states"
,
torch
::
kCUDA
,
&
merge_attn_states
);
#ifndef USE_ROCM
#ifndef USE_ROCM
ops
.
def
(
ops
.
def
(
...
...
tests/kernels/attention/test_merge_attn_states.py
View file @
598190aa
...
@@ -20,7 +20,11 @@ def merge_attn_states_torch(
...
@@ -20,7 +20,11 @@ def merge_attn_states_torch(
suffix_output
:
torch
.
Tensor
,
# [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
suffix_output
:
torch
.
Tensor
,
# [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
suffix_lse
:
torch
.
Tensor
,
# [NUM_HEADS, NUM_TOKENS]
suffix_lse
:
torch
.
Tensor
,
# [NUM_HEADS, NUM_TOKENS]
output_lse
:
torch
.
Tensor
|
None
=
None
,
# [NUM_HEADS, NUM_TOKENS]
output_lse
:
torch
.
Tensor
|
None
=
None
,
# [NUM_HEADS, NUM_TOKENS]
prefill_tokens_with_context
:
int
|
None
=
None
,
):
):
# Apply prefill_tokens_with_context mask if needed
if
prefill_tokens_with_context
is
None
:
prefill_tokens_with_context
=
output
.
shape
[
0
]
p_lse
=
prefix_lse
p_lse
=
prefix_lse
s_lse
=
suffix_lse
s_lse
=
suffix_lse
# inf -> -inf
# inf -> -inf
...
@@ -28,6 +32,9 @@ def merge_attn_states_torch(
...
@@ -28,6 +32,9 @@ def merge_attn_states_torch(
s_lse
[
s_lse
==
torch
.
inf
]
=
-
torch
.
inf
s_lse
[
s_lse
==
torch
.
inf
]
=
-
torch
.
inf
# max_lse [NUM_HEADS, NUM_TOKENS]
# max_lse [NUM_HEADS, NUM_TOKENS]
max_lse
=
torch
.
maximum
(
p_lse
,
s_lse
)
max_lse
=
torch
.
maximum
(
p_lse
,
s_lse
)
mask
=
torch
.
ones
((
prefix_lse
.
shape
[
1
],
1
,
1
),
device
=
p_lse
.
device
)
mask
[
prefill_tokens_with_context
:].
fill_
(
0
)
p_lse
=
p_lse
-
max_lse
p_lse
=
p_lse
-
max_lse
s_lse
=
s_lse
-
max_lse
s_lse
=
s_lse
-
max_lse
p_lse_exp
=
torch
.
exp
(
p_lse
)
p_lse_exp
=
torch
.
exp
(
p_lse
)
...
@@ -35,11 +42,16 @@ def merge_attn_states_torch(
...
@@ -35,11 +42,16 @@ def merge_attn_states_torch(
out_se
=
p_lse_exp
+
s_lse_exp
out_se
=
p_lse_exp
+
s_lse_exp
if
output_lse
is
not
None
:
if
output_lse
is
not
None
:
output_lse
=
torch
.
log
(
out_se
)
+
max_lse
output_lse
=
torch
.
log
(
out_se
)
+
max_lse
output_lse
[
prefill_tokens_with_context
:]
=
suffix_lse
[
prefill_tokens_with_context
:
]
p_scale
=
p_lse_exp
/
out_se
# [NUM_HEADS, NUM_TOKENS]
p_scale
=
p_lse_exp
/
out_se
# [NUM_HEADS, NUM_TOKENS]
s_scale
=
s_lse_exp
/
out_se
# [NUM_HEADS, NUM_TOKENS]
s_scale
=
s_lse_exp
/
out_se
# [NUM_HEADS, NUM_TOKENS]
p_scale
=
torch
.
transpose
(
p_scale
,
0
,
1
).
unsqueeze
(
2
)
# [NUM_TOKENS, NUM_HEADS, 1]
p_scale
=
torch
.
transpose
(
p_scale
,
0
,
1
).
unsqueeze
(
2
)
# [NUM_TOKENS, NUM_HEADS, 1]
s_scale
=
torch
.
transpose
(
s_scale
,
0
,
1
).
unsqueeze
(
2
)
# [NUM_TOKENS, NUM_HEADS, 1]
s_scale
=
torch
.
transpose
(
s_scale
,
0
,
1
).
unsqueeze
(
2
)
# [NUM_TOKENS, NUM_HEADS, 1]
output
=
prefix_output
*
p_scale
+
suffix_output
*
s_scale
output
.
copy_
(
prefix_output
*
p_scale
*
mask
+
suffix_output
*
(
s_scale
*
mask
+
(
1
-
mask
))
)
return
output
,
output_lse
return
output
,
output_lse
...
@@ -90,13 +102,18 @@ def generate_markdown_table():
...
@@ -90,13 +102,18 @@ def generate_markdown_table():
)
)
@
pytest
.
mark
.
parametrize
(
"prefill_tokens_with_context"
,
[
None
,
128
])
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_BATCH_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_BATCH_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"num_query_heads"
,
NUM_QUERY_HEADS
)
@
pytest
.
mark
.
parametrize
(
"num_query_heads"
,
NUM_QUERY_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"output_dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"output_dtype"
,
DTYPES
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_merge_attn_states
(
def
test_merge_attn_states
(
num_tokens
:
int
,
num_query_heads
:
int
,
head_size
:
int
,
output_dtype
:
torch
.
dtype
prefill_tokens_with_context
:
int
|
None
,
num_tokens
:
int
,
num_query_heads
:
int
,
head_size
:
int
,
output_dtype
:
torch
.
dtype
,
):
):
if
not
current_platform
.
is_cuda
():
if
not
current_platform
.
is_cuda
():
pytest
.
skip
(
pytest
.
skip
(
...
@@ -111,6 +128,7 @@ def test_merge_attn_states(
...
@@ -111,6 +128,7 @@ def test_merge_attn_states(
print
(
print
(
f
"
\n
NUM_TOKENS:
{
NUM_TOKENS
}
, NUM_HEADS:
{
NUM_HEADS
}
, "
f
"
\n
NUM_TOKENS:
{
NUM_TOKENS
}
, NUM_HEADS:
{
NUM_HEADS
}
, "
f
"HEAD_SIZE:
{
HEAD_SIZE
}
, DTYPE:
{
output_dtype
}
, "
f
"HEAD_SIZE:
{
HEAD_SIZE
}
, DTYPE:
{
output_dtype
}
, "
f
"prefill_tokens_with_context:
{
prefill_tokens_with_context
}
, "
f
"Device:
{
current_platform
.
get_device_name
()
}
"
f
"Device:
{
current_platform
.
get_device_name
()
}
"
)
)
...
@@ -164,6 +182,7 @@ def test_merge_attn_states(
...
@@ -164,6 +182,7 @@ def test_merge_attn_states(
suffix_output
,
suffix_output
,
suffix_lse_torch
,
suffix_lse_torch
,
output_lse_torch
,
output_lse_torch
,
prefill_tokens_with_context
,
)
)
torch
.
accelerator
.
synchronize
()
torch
.
accelerator
.
synchronize
()
...
@@ -176,6 +195,7 @@ def test_merge_attn_states(
...
@@ -176,6 +195,7 @@ def test_merge_attn_states(
suffix_output
,
suffix_output
,
suffix_lse_torch
,
suffix_lse_torch
,
output_lse_torch
,
output_lse_torch
,
prefill_tokens_with_context
,
)
)
end
.
record
()
end
.
record
()
torch
.
accelerator
.
synchronize
()
torch
.
accelerator
.
synchronize
()
...
@@ -199,6 +219,7 @@ def test_merge_attn_states(
...
@@ -199,6 +219,7 @@ def test_merge_attn_states(
suffix_output
,
suffix_output
,
suffix_lse
,
suffix_lse
,
output_lse_ref_triton
,
output_lse_ref_triton
,
prefill_tokens_with_context
,
)
)
torch
.
accelerator
.
synchronize
()
torch
.
accelerator
.
synchronize
()
...
@@ -211,6 +232,7 @@ def test_merge_attn_states(
...
@@ -211,6 +232,7 @@ def test_merge_attn_states(
suffix_output
,
suffix_output
,
suffix_lse
,
suffix_lse
,
output_lse_ref_triton
,
output_lse_ref_triton
,
prefill_tokens_with_context
,
)
)
end
.
record
()
end
.
record
()
torch
.
accelerator
.
synchronize
()
torch
.
accelerator
.
synchronize
()
...
@@ -231,6 +253,7 @@ def test_merge_attn_states(
...
@@ -231,6 +253,7 @@ def test_merge_attn_states(
suffix_output
,
suffix_output
,
suffix_lse
,
suffix_lse
,
output_lse_cuda
,
output_lse_cuda
,
prefill_tokens_with_context
,
)
)
torch
.
accelerator
.
synchronize
()
torch
.
accelerator
.
synchronize
()
...
@@ -243,6 +266,7 @@ def test_merge_attn_states(
...
@@ -243,6 +266,7 @@ def test_merge_attn_states(
suffix_output
,
suffix_output
,
suffix_lse
,
suffix_lse
,
output_lse_cuda
,
output_lse_cuda
,
prefill_tokens_with_context
,
)
)
end
.
record
()
end
.
record
()
torch
.
accelerator
.
synchronize
()
torch
.
accelerator
.
synchronize
()
...
...
vllm/_custom_ops.py
View file @
598190aa
...
@@ -264,9 +264,16 @@ def merge_attn_states(
...
@@ -264,9 +264,16 @@ def merge_attn_states(
suffix_output
:
torch
.
Tensor
,
suffix_output
:
torch
.
Tensor
,
suffix_lse
:
torch
.
Tensor
,
suffix_lse
:
torch
.
Tensor
,
output_lse
:
torch
.
Tensor
|
None
=
None
,
output_lse
:
torch
.
Tensor
|
None
=
None
,
prefill_tokens_with_context
:
int
|
None
=
None
,
)
->
None
:
)
->
None
:
torch
.
ops
.
_C
.
merge_attn_states
(
torch
.
ops
.
_C
.
merge_attn_states
(
output
,
output_lse
,
prefix_output
,
prefix_lse
,
suffix_output
,
suffix_lse
output
,
output_lse
,
prefix_output
,
prefix_lse
,
suffix_output
,
suffix_lse
,
prefill_tokens_with_context
,
)
)
...
...
vllm/model_executor/layers/attention/mla_attention.py
View file @
598190aa
...
@@ -1181,6 +1181,7 @@ class MLACommonPrefillMetadata:
...
@@ -1181,6 +1181,7 @@ class MLACommonPrefillMetadata:
padded_local_cu_seq_lens
:
torch
.
Tensor
|
None
=
None
padded_local_cu_seq_lens
:
torch
.
Tensor
|
None
=
None
cu_seq_lens_lst
:
list
[
list
[
int
]]
|
None
=
None
cu_seq_lens_lst
:
list
[
list
[
int
]]
|
None
=
None
chunk_size
:
int
|
None
=
None
chunk_size
:
int
|
None
=
None
prefill_tokens_with_context
:
int
|
None
=
None
block_table
:
torch
.
Tensor
block_table
:
torch
.
Tensor
query_start_loc
:
torch
.
Tensor
query_start_loc
:
torch
.
Tensor
...
@@ -1743,6 +1744,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -1743,6 +1744,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
prefill_query_start_loc
=
(
prefill_query_start_loc
=
(
query_start_loc
[
reqs_start
:]
-
query_start_loc
[
reqs_start
]
query_start_loc
[
reqs_start
:]
-
query_start_loc
[
reqs_start
]
)
)
prefill_query_start_loc_cpu
=
(
query_start_loc_cpu
[
reqs_start
:]
-
query_start_loc_cpu
[
reqs_start
]
)
chunked_context_metadata
=
None
chunked_context_metadata
=
None
if
max_context_len_cpu
>
0
:
if
max_context_len_cpu
>
0
:
...
@@ -1864,6 +1868,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -1864,6 +1868,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
if
self
.
_use_cudnn_prefill
if
self
.
_use_cudnn_prefill
else
MLACommonPrefillMetadata
.
ChunkedContextMetadata
else
MLACommonPrefillMetadata
.
ChunkedContextMetadata
)
)
prefill_tokens_with_context
=
None
if
num_prefills_with_context_cpu
>
0
:
prefill_tokens_with_context
=
prefill_query_start_loc_cpu
[
num_prefills_with_context_cpu
].
item
()
if
self
.
dcp_world_size
>
1
:
if
self
.
dcp_world_size
>
1
:
chunked_context_metadata
=
chunked_context_metadata_cls
(
chunked_context_metadata
=
chunked_context_metadata_cls
(
cu_seq_lens
=
cu_seq_lens_cpu
.
to
(
device
,
non_blocking
=
True
),
cu_seq_lens
=
cu_seq_lens_cpu
.
to
(
device
,
non_blocking
=
True
),
...
@@ -1883,6 +1892,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -1883,6 +1892,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
),
),
cu_seq_lens_lst
=
cu_seq_lens_cpu
.
tolist
(),
cu_seq_lens_lst
=
cu_seq_lens_cpu
.
tolist
(),
chunk_size
=
padded_local_max_context_chunk_across_ranks
,
chunk_size
=
padded_local_max_context_chunk_across_ranks
,
prefill_tokens_with_context
=
prefill_tokens_with_context
,
)
)
else
:
else
:
chunked_context_metadata
=
chunked_context_metadata_cls
(
chunked_context_metadata
=
chunked_context_metadata_cls
(
...
@@ -1896,6 +1906,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -1896,6 +1906,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
),
),
chunk_total_token
=
chunk_total_token
,
chunk_total_token
=
chunk_total_token
,
workspace
=
self
.
chunked_prefill_workspace
,
workspace
=
self
.
chunked_prefill_workspace
,
prefill_tokens_with_context
=
prefill_tokens_with_context
,
)
)
if
self
.
_use_cudnn_prefill
:
if
self
.
_use_cudnn_prefill
:
...
@@ -2382,14 +2393,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -2382,14 +2393,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
assert
prefill
.
chunked_context
.
seq_lens
[
chunk_idx
]
is
not
None
assert
prefill
.
chunked_context
.
seq_lens
[
chunk_idx
]
is
not
None
assert
prefill
.
workspace_buffer
is
not
None
assert
prefill
.
workspace_buffer
is
not
None
out
=
torch
.
zeros
(
out
=
torch
.
empty
(
q
.
shape
[
0
],
q
.
shape
[
0
],
q
.
shape
[
1
],
q
.
shape
[
1
],
v
.
shape
[
2
],
v
.
shape
[
2
],
device
=
q
.
device
,
device
=
q
.
device
,
dtype
=
prefill
.
output_dtype
,
dtype
=
prefill
.
output_dtype
,
)
)
prefill
.
workspace_buffer
.
fill_
(
0
)
attn_out
,
lse
=
trtllm_ragged_attention_deepseek
(
attn_out
,
lse
=
trtllm_ragged_attention_deepseek
(
query
=
q
,
query
=
q
,
...
@@ -2691,6 +2701,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -2691,6 +2701,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
)
)
if
has_context
:
if
has_context
:
assert
prefill_metadata
.
chunked_context
is
not
None
suffix_output
,
suffix_lse
=
output_prefill
suffix_output
,
suffix_lse
=
output_prefill
if
self
.
dcp_world_size
>
1
:
if
self
.
dcp_world_size
>
1
:
context_output
,
context_lse
=
(
context_output
,
context_lse
=
(
...
@@ -2719,6 +2730,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -2719,6 +2730,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
prefix_lse
=
context_lse
,
prefix_lse
=
context_lse
,
suffix_output
=
suffix_output
,
suffix_output
=
suffix_output
,
suffix_lse
=
suffix_lse
,
suffix_lse
=
suffix_lse
,
prefill_tokens_with_context
=
prefill_metadata
.
chunked_context
.
prefill_tokens_with_context
,
)
)
else
:
else
:
output_prefill
=
output_prefill
[...,
:
v
.
shape
[
-
1
]].
flatten
(
start_dim
=-
2
)
output_prefill
=
output_prefill
[...,
:
v
.
shape
[
-
1
]].
flatten
(
start_dim
=-
2
)
...
...
vllm/v1/attention/ops/merge_attn_states.py
View file @
598190aa
...
@@ -13,7 +13,36 @@ def merge_attn_states(
...
@@ -13,7 +13,36 @@ def merge_attn_states(
suffix_output
:
torch
.
Tensor
,
suffix_output
:
torch
.
Tensor
,
suffix_lse
:
torch
.
Tensor
,
suffix_lse
:
torch
.
Tensor
,
output_lse
:
torch
.
Tensor
|
None
=
None
,
output_lse
:
torch
.
Tensor
|
None
=
None
,
prefill_tokens_with_context
:
int
|
None
=
None
,
)
->
None
:
)
->
None
:
"""Merge partial attention outputs from prefix (KV cache) and suffix
(new tokens) into a single output tensor using the log-sum-exp (LSE)
rescaling method described in section 2.2 of
https://www.arxiv.org/pdf/2501.01005.
For tokens that have prefix context (token index < prefill_tokens_with_context),
the prefix and suffix partial outputs are combined as a weighted sum.
For tokens without prefix context, the suffix output is copied directly.
Args:
output: Output tensor of shape [NUM_TOKENS, NUM_HEADS, HEAD_SIZE].
prefix_output: Partial attention output over the prefix (KV cache),
shape [NUM_TOKENS, NUM_HEADS, HEAD_SIZE].
prefix_lse: Log-sum-exp values for the prefix attention,
shape [NUM_HEADS, NUM_TOKENS].
suffix_output: Partial attention output over the suffix (new KV),
shape [NUM_TOKENS, NUM_HEADS, HEAD_SIZE].
suffix_lse: Log-sum-exp values for the suffix attention,
shape [NUM_HEADS, NUM_TOKENS].
output_lse: Optional tensor to store the merged LSE values,
shape [NUM_HEADS, NUM_TOKENS]. If None, LSE is not written out.
prefill_tokens_with_context: Number of prefill tokens that have
prefix context and therefore require merging. Tokens at indices
>= this value are decode or context-free prefill tokens whose
output is taken directly from suffix_output. If None, all tokens
are treated as having context.
"""
# NOTE(DefTruth): Currently, custom merge_attn_states CUDA kernel
# NOTE(DefTruth): Currently, custom merge_attn_states CUDA kernel
# does not support FP8 dtype, fallback to use Triton kernel.
# does not support FP8 dtype, fallback to use Triton kernel.
def
supported_dtypes
(
o
:
torch
.
Tensor
)
->
bool
:
def
supported_dtypes
(
o
:
torch
.
Tensor
)
->
bool
:
...
@@ -37,11 +66,23 @@ def merge_attn_states(
...
@@ -37,11 +66,23 @@ def merge_attn_states(
from
vllm._custom_ops
import
merge_attn_states
from
vllm._custom_ops
import
merge_attn_states
return
merge_attn_states
(
return
merge_attn_states
(
output
,
prefix_output
,
prefix_lse
,
suffix_output
,
suffix_lse
,
output_lse
output
,
prefix_output
,
prefix_lse
,
suffix_output
,
suffix_lse
,
output_lse
,
prefill_tokens_with_context
,
)
)
else
:
else
:
from
vllm.v1.attention.ops.triton_merge_attn_states
import
merge_attn_states
from
vllm.v1.attention.ops.triton_merge_attn_states
import
merge_attn_states
return
merge_attn_states
(
return
merge_attn_states
(
output
,
prefix_output
,
prefix_lse
,
suffix_output
,
suffix_lse
,
output_lse
output
,
prefix_output
,
prefix_lse
,
suffix_output
,
suffix_lse
,
output_lse
,
prefill_tokens_with_context
,
)
)
vllm/v1/attention/ops/triton_merge_attn_states.py
View file @
598190aa
...
@@ -15,6 +15,7 @@ def merge_attn_states(
...
@@ -15,6 +15,7 @@ def merge_attn_states(
suffix_output
:
torch
.
Tensor
,
suffix_output
:
torch
.
Tensor
,
suffix_lse
:
torch
.
Tensor
,
suffix_lse
:
torch
.
Tensor
,
output_lse
:
torch
.
Tensor
|
None
=
None
,
output_lse
:
torch
.
Tensor
|
None
=
None
,
prefill_tokens_with_context
:
int
|
None
=
None
,
)
->
None
:
)
->
None
:
num_tokens
=
output
.
shape
[
0
]
num_tokens
=
output
.
shape
[
0
]
num_query_heads
=
output
.
shape
[
1
]
num_query_heads
=
output
.
shape
[
1
]
...
@@ -25,6 +26,11 @@ def merge_attn_states(
...
@@ -25,6 +26,11 @@ def merge_attn_states(
# backend.
# backend.
prefix_head_stride
=
prefix_output
.
stride
(
1
)
prefix_head_stride
=
prefix_output
.
stride
(
1
)
output_head_stride
=
output
.
stride
(
1
)
output_head_stride
=
output
.
stride
(
1
)
# If prefill_tokens_with_context is None, all tokens should use prefix context
if
prefill_tokens_with_context
is
None
:
prefill_tokens_with_context
=
num_tokens
# TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
# TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
merge_attn_states_kernel
[(
num_tokens
,
num_query_heads
)](
merge_attn_states_kernel
[(
num_tokens
,
num_query_heads
)](
output
,
output
,
...
@@ -38,6 +44,7 @@ def merge_attn_states(
...
@@ -38,6 +44,7 @@ def merge_attn_states(
head_size
,
head_size
,
padded_head_size
,
padded_head_size
,
output_lse
is
not
None
,
output_lse
is
not
None
,
prefill_tokens_with_context
,
)
)
...
@@ -54,12 +61,44 @@ def merge_attn_states_kernel(
...
@@ -54,12 +61,44 @@ def merge_attn_states_kernel(
HEAD_SIZE
:
tl
.
constexpr
,
HEAD_SIZE
:
tl
.
constexpr
,
PADDED_HEAD_SIZE
:
tl
.
constexpr
,
PADDED_HEAD_SIZE
:
tl
.
constexpr
,
OUTPUT_LSE
:
tl
.
constexpr
,
OUTPUT_LSE
:
tl
.
constexpr
,
prefill_tokens_with_context
:
tl
.
constexpr
,
):
):
token_idx
=
tl
.
program_id
(
0
)
token_idx
=
tl
.
program_id
(
0
)
num_tokens
=
tl
.
num_programs
(
0
)
num_tokens
=
tl
.
num_programs
(
0
)
head_idx
=
tl
.
program_id
(
1
)
head_idx
=
tl
.
program_id
(
1
)
num_heads
=
tl
.
num_programs
(
1
)
num_heads
=
tl
.
num_programs
(
1
)
prefix_mask
=
token_idx
<
prefill_tokens_with_context
head_arange
=
tl
.
arange
(
0
,
PADDED_HEAD_SIZE
)
head_mask
=
head_arange
<
HEAD_SIZE
# For tokens without context (token_idx >= prefill_tokens_with_context),
# directly copy from suffix_output
if
not
prefix_mask
:
s_lse
=
tl
.
load
(
suffix_lse
+
head_idx
*
num_tokens
+
token_idx
)
if
OUTPUT_LSE
:
tl
.
store
(
output_lse
+
head_idx
*
num_tokens
+
token_idx
,
s_lse
)
s_out
=
tl
.
load
(
suffix_output
+
token_idx
*
num_heads
*
prefix_head_stride
+
head_idx
*
prefix_head_stride
+
head_arange
,
mask
=
head_mask
,
)
tl
.
store
(
output
+
token_idx
*
num_heads
*
output_head_stride
+
head_idx
*
output_head_stride
+
head_arange
,
s_out
,
mask
=
head_mask
,
)
return
# For tokens with context (token_idx < prefill_tokens_with_context),
# perform normal merge operation
p_lse
=
tl
.
load
(
prefix_lse
+
head_idx
*
num_tokens
+
token_idx
)
p_lse
=
tl
.
load
(
prefix_lse
+
head_idx
*
num_tokens
+
token_idx
)
s_lse
=
tl
.
load
(
suffix_lse
+
head_idx
*
num_tokens
+
token_idx
)
s_lse
=
tl
.
load
(
suffix_lse
+
head_idx
*
num_tokens
+
token_idx
)
...
@@ -83,8 +122,6 @@ def merge_attn_states_kernel(
...
@@ -83,8 +122,6 @@ def merge_attn_states_kernel(
out_lse
=
tl
.
log
(
out_se
)
+
max_lse
out_lse
=
tl
.
log
(
out_se
)
+
max_lse
tl
.
store
(
output_lse
+
head_idx
*
num_tokens
+
token_idx
,
out_lse
)
tl
.
store
(
output_lse
+
head_idx
*
num_tokens
+
token_idx
,
out_lse
)
head_arange
=
tl
.
arange
(
0
,
PADDED_HEAD_SIZE
)
head_mask
=
head_arange
<
HEAD_SIZE
p_out
=
tl
.
load
(
p_out
=
tl
.
load
(
prefix_output
prefix_output
+
token_idx
*
num_heads
*
prefix_head_stride
+
token_idx
*
num_heads
*
prefix_head_stride
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment