Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e9528f6d
Unverified
Commit
e9528f6d
authored
Apr 11, 2025
by
DefTruth
Committed by
GitHub
Apr 11, 2025
Browse files
[Kernel] support merge_attn_states CUDA kernel, 3x speedup (#16173)
Signed-off-by:
DefTruth
<
qiustudent_r@163.com
>
parent
51baa9c3
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
519 additions
and
4 deletions
+519
-4
CMakeLists.txt
CMakeLists.txt
+1
-0
csrc/attention/merge_attn_states.cu
csrc/attention/merge_attn_states.cu
+173
-0
csrc/ops.h
csrc/ops.h
+9
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+15
-0
tests/kernels/test_merge_attn_states.py
tests/kernels/test_merge_attn_states.py
+265
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+11
-0
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+1
-2
vllm/attention/ops/merge_attn_states.py
vllm/attention/ops/merge_attn_states.py
+42
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+1
-1
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+1
-1
No files found.
CMakeLists.txt
View file @
e9528f6d
...
@@ -230,6 +230,7 @@ set(VLLM_EXT_SRC
...
@@ -230,6 +230,7 @@ set(VLLM_EXT_SRC
"csrc/cache_kernels.cu"
"csrc/cache_kernels.cu"
"csrc/attention/paged_attention_v1.cu"
"csrc/attention/paged_attention_v1.cu"
"csrc/attention/paged_attention_v2.cu"
"csrc/attention/paged_attention_v2.cu"
"csrc/attention/merge_attn_states.cu"
"csrc/pos_encoding_kernels.cu"
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_kernels.cu"
...
...
csrc/attention/merge_attn_states.cu
0 → 100644
View file @
e9528f6d
#include <optional>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include "attention_dtypes.h"
#include "attention_utils.cuh"
namespace
vllm
{
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
template
<
typename
scalar_t
,
const
uint
NUM_THREADS
>
__global__
void
merge_attn_states_kernel
(
scalar_t
*
output
,
float
*
output_lse
,
const
scalar_t
*
prefix_output
,
const
float
*
prefix_lse
,
const
scalar_t
*
suffix_output
,
const
float
*
suffix_lse
,
const
uint
num_tokens
,
const
uint
num_heads
,
const
uint
head_size
)
{
using
pack_128b_t
=
uint4
;
const
uint
pack_size
=
16
/
sizeof
(
scalar_t
);
const
uint
threads_per_head
=
head_size
/
pack_size
;
const
uint
global_idx
=
blockIdx
.
x
*
NUM_THREADS
+
threadIdx
.
x
;
const
uint
token_head_threads
=
num_tokens
*
num_heads
*
threads_per_head
;
if
(
global_idx
>=
token_head_threads
)
return
;
// global_idx -> token_idx + head_idx + pack_idx
const
uint
token_head_idx
=
global_idx
/
threads_per_head
;
const
uint
pack_idx
=
global_idx
%
threads_per_head
;
const
uint
token_idx
=
token_head_idx
/
num_heads
;
const
uint
head_idx
=
token_head_idx
%
num_heads
;
const
uint
pack_offset
=
pack_idx
*
pack_size
;
// (0~15)*8, etc.
const
uint
head_offset
=
token_idx
*
num_heads
*
head_size
+
head_idx
*
head_size
;
const
scalar_t
*
prefix_head_ptr
=
prefix_output
+
head_offset
;
const
scalar_t
*
suffix_head_ptr
=
suffix_output
+
head_offset
;
scalar_t
*
output_head_ptr
=
output
+
head_offset
;
float
p_lse
=
prefix_lse
[
head_idx
*
num_tokens
+
token_idx
];
float
s_lse
=
suffix_lse
[
head_idx
*
num_tokens
+
token_idx
];
p_lse
=
std
::
isinf
(
p_lse
)
?
-
std
::
numeric_limits
<
float
>::
infinity
()
:
p_lse
;
s_lse
=
std
::
isinf
(
s_lse
)
?
-
std
::
numeric_limits
<
float
>::
infinity
()
:
s_lse
;
const
float
max_lse
=
fmaxf
(
p_lse
,
s_lse
);
p_lse
=
p_lse
-
max_lse
;
s_lse
=
s_lse
-
max_lse
;
const
float
p_se
=
expf
(
p_lse
);
const
float
s_se
=
expf
(
s_lse
);
const
float
out_se
=
p_se
+
s_se
;
const
float
p_scale
=
p_se
/
out_se
;
const
float
s_scale
=
s_se
/
out_se
;
if
(
pack_offset
<
head_size
)
{
// Pack 128b load
pack_128b_t
p_out_pack
=
reinterpret_cast
<
const
pack_128b_t
*>
(
prefix_head_ptr
)[
pack_offset
/
pack_size
];
pack_128b_t
s_out_pack
=
reinterpret_cast
<
const
pack_128b_t
*>
(
suffix_head_ptr
)[
pack_offset
/
pack_size
];
pack_128b_t
o_out_pack
;
#pragma unroll
for
(
uint
i
=
0
;
i
<
pack_size
;
++
i
)
{
// Always use float for FMA to keep high precision.
// half(uint16_t), bfloat16, float -> float.
const
float
p_out_f
=
vllm
::
to_float
(
reinterpret_cast
<
const
scalar_t
*>
(
&
p_out_pack
)[
i
]);
const
float
s_out_f
=
vllm
::
to_float
(
reinterpret_cast
<
const
scalar_t
*>
(
&
s_out_pack
)[
i
]);
// fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
const
float
o_out_f
=
p_out_f
*
p_scale
+
(
s_out_f
*
s_scale
);
// float -> half(uint16_t), bfloat16, float.
vllm
::
from_float
(
reinterpret_cast
<
scalar_t
*>
(
&
o_out_pack
)[
i
],
o_out_f
);
}
// Pack 128b storage
reinterpret_cast
<
pack_128b_t
*>
(
output_head_ptr
)[
pack_offset
/
pack_size
]
=
o_out_pack
;
}
// We only need to write to output_lse once per head.
if
(
output_lse
!=
nullptr
&&
pack_idx
==
0
)
{
float
out_lse
=
logf
(
out_se
)
+
max_lse
;
output_lse
[
head_idx
*
num_tokens
+
token_idx
]
=
out_lse
;
}
}
}
// namespace vllm
// The following macro is used to dispatch the conversion function based on
// the output data type. The FN is a macro that calls a function with
// template<typename scalar_t>.
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn) \
{ \
if (scalar_dtype == at::ScalarType::Float) { \
fn(float); \
} else if (scalar_dtype == at::ScalarType::Half) { \
fn(uint16_t); \
} else if (scalar_dtype == at::ScalarType::BFloat16) { \
fn(__nv_bfloat16); \
} else { \
TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
} \
}
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
{ \
vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block>>>( \
reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
num_heads, head_size); \
}
/*@brief Merges the attention states from prefix and suffix
* into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
*
* @param output [n,h,d] The output tensor to store the merged attention states.
* @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
* @param prefix_output [n,h,d] The prefix attention states.
* @param prefix_lse [h,d] The log-sum-exp values for the prefix attention
* states.
* @param suffix_output [n,h,d] The suffix attention states.
* @param suffix_lse [h,d] The log-sum-exp values for the suffix attention
* states.
*/
template
<
typename
scalar_t
>
void
merge_attn_states_launcher
(
torch
::
Tensor
&
output
,
std
::
optional
<
torch
::
Tensor
>
output_lse
,
const
torch
::
Tensor
&
prefix_output
,
const
torch
::
Tensor
&
prefix_lse
,
const
torch
::
Tensor
&
suffix_output
,
const
torch
::
Tensor
&
suffix_lse
)
{
constexpr
uint
NUM_THREADS
=
128
;
const
uint
num_tokens
=
output
.
size
(
0
);
const
uint
num_heads
=
output
.
size
(
1
);
const
uint
head_size
=
output
.
size
(
2
);
const
uint
pack_size
=
16
/
sizeof
(
scalar_t
);
TORCH_CHECK
(
head_size
%
pack_size
==
0
,
"headsize must be multiple of pack_size:"
,
pack_size
);
float
*
output_lse_ptr
=
nullptr
;
if
(
output_lse
.
has_value
())
{
output_lse_ptr
=
output_lse
.
value
().
data_ptr
<
float
>
();
}
// process one pack elements per thread. float -> 4, half/bf16 -> 8
const
uint
threads_per_head
=
head_size
/
pack_size
;
const
uint
total_threads
=
num_tokens
*
num_heads
*
threads_per_head
;
dim3
block
(
NUM_THREADS
);
dim3
grid
((
total_threads
+
NUM_THREADS
-
1
)
/
NUM_THREADS
);
LAUNCH_MERGE_ATTN_STATES
(
scalar_t
,
NUM_THREADS
);
}
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
{ \
merge_attn_states_launcher<scalar_t>(output, output_lse, prefix_output, \
prefix_lse, suffix_output, \
suffix_lse); \
}
void
merge_attn_states
(
torch
::
Tensor
&
output
,
std
::
optional
<
torch
::
Tensor
>
output_lse
,
const
torch
::
Tensor
&
prefix_output
,
const
torch
::
Tensor
&
prefix_lse
,
const
torch
::
Tensor
&
suffix_output
,
const
torch
::
Tensor
&
suffix_lse
)
{
DISPATCH_BY_SCALAR_DTYPE
(
output
.
dtype
(),
CALL_MERGE_ATTN_STATES_LAUNCHER
);
}
csrc/ops.h
View file @
e9528f6d
...
@@ -52,6 +52,15 @@ void paged_attention_v2(
...
@@ -52,6 +52,15 @@ void paged_attention_v2(
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
const
int64_t
blocksparse_head_sliding_step
);
#ifndef USE_ROCM
void
merge_attn_states
(
torch
::
Tensor
&
output
,
std
::
optional
<
torch
::
Tensor
>
output_lse
,
const
torch
::
Tensor
&
prefix_output
,
const
torch
::
Tensor
&
prefix_lse
,
const
torch
::
Tensor
&
suffix_output
,
const
torch
::
Tensor
&
suffix_lse
);
#endif
void
rms_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
void
rms_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
double
epsilon
);
double
epsilon
);
...
...
csrc/torch_bindings.cpp
View file @
e9528f6d
...
@@ -64,6 +64,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -64,6 +64,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step) -> ()"
);
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v2"
,
torch
::
kCUDA
,
&
paged_attention_v2
);
ops
.
impl
(
"paged_attention_v2"
,
torch
::
kCUDA
,
&
paged_attention_v2
);
#ifndef USE_ROCM
// Merge attn states
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
ops
.
def
(
"merge_attn_states("
" Tensor! output,"
" Tensor!? output_lse,"
" Tensor prefix_output,"
" Tensor prefix_lse,"
" Tensor suffix_output,"
" Tensor suffix_lse) -> ()"
);
ops
.
impl
(
"merge_attn_states"
,
torch
::
kCUDA
,
&
merge_attn_states
);
#endif
// Activation ops
// Activation ops
// Activation function used in SwiGLU.
// Activation function used in SwiGLU.
ops
.
def
(
"silu_and_mul(Tensor! out, Tensor input) -> ()"
);
ops
.
def
(
"silu_and_mul(Tensor! out, Tensor input) -> ()"
);
...
...
tests/kernels/test_merge_attn_states.py
0 → 100644
View file @
e9528f6d
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
import
pytest
import
torch
from
vllm._custom_ops
import
merge_attn_states
as
merge_attn_states_cuda
from
vllm.attention.ops.triton_merge_attn_states
import
(
merge_attn_states
as
merge_attn_states_triton
)
from
vllm.platforms
import
current_platform
# Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
# can be used to combine partial attention results (in the split-KV case)
def
merge_attn_states_torch
(
output
:
torch
.
Tensor
,
# [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
prefix_output
:
torch
.
Tensor
,
# [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
prefix_lse
:
torch
.
Tensor
,
# [NUM_HEADS, NUM_TOKENS]
suffix_output
:
torch
.
Tensor
,
# [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
suffix_lse
:
torch
.
Tensor
,
# [NUM_HEADS, NUM_TOKENS]
output_lse
:
Optional
[
torch
.
Tensor
]
=
None
,
# [NUM_HEADS, NUM_TOKENS]
):
p_lse
=
prefix_lse
s_lse
=
suffix_lse
# inf -> -inf
p_lse
[
p_lse
==
torch
.
inf
]
=
-
torch
.
inf
s_lse
[
s_lse
==
torch
.
inf
]
=
-
torch
.
inf
# max_lse [NUM_HEADS, NUM_TOKENS]
max_lse
=
torch
.
maximum
(
p_lse
,
s_lse
)
p_lse
=
p_lse
-
max_lse
s_lse
=
s_lse
-
max_lse
p_lse_exp
=
torch
.
exp
(
p_lse
)
s_lse_exp
=
torch
.
exp
(
s_lse
)
out_se
=
(
p_lse_exp
+
s_lse_exp
)
if
output_lse
is
not
None
:
output_lse
=
torch
.
log
(
out_se
)
+
max_lse
p_scale
=
p_lse_exp
/
out_se
# [NUM_HEADS, NUM_TOKENS]
s_scale
=
s_lse_exp
/
out_se
# [NUM_HEADS, NUM_TOKENS]
p_scale
=
torch
.
transpose
(
p_scale
,
0
,
1
).
unsqueeze
(
2
)
# [NUM_TOKENS, NUM_HEADS, 1]
s_scale
=
torch
.
transpose
(
s_scale
,
0
,
1
).
unsqueeze
(
2
)
# [NUM_TOKENS, NUM_HEADS, 1]
output
=
prefix_output
*
p_scale
+
suffix_output
*
s_scale
return
output
,
output_lse
NUM_BATCH_TOKENS
=
[
256
,
512
,
613
,
1024
,
1536
,
4096
]
NUM_QUERY_HEADS
=
[
4
,
8
,
16
,
32
,
48
,
64
]
HEAD_SIZES
=
[
32
,
48
,
64
,
96
,
128
,
256
]
DTYPES
=
[
torch
.
float32
,
torch
.
half
,
torch
.
bfloat16
]
all_case_info
:
list
[
tuple
]
=
[]
def
generate_markdown_table
():
global
all_case_info
table_header
=
(
"| tokens | heads | headsize | dtype "
"| device | torch | triton | cuda | speedup |"
)
table_separator
=
"| --- | --- | --- | --- | --- | --- | --- | --- | --- |"
def
shortly_dtype
(
dtype
:
torch
.
dtype
)
->
str
:
return
str
(
dtype
).
removeprefix
(
"torch."
)
def
shortly_device
(
device
:
str
)
->
str
:
return
device
.
removeprefix
(
"NVIDIA"
).
strip
()
print
(
table_header
)
print
(
table_separator
)
for
info
in
all_case_info
:
(
num_tokens
,
num_heads
,
head_size
,
dtype
,
device
,
avg_time_torch_kernel
,
avg_time_triton_kernel
,
avg_time_cuda_kernel
,
performance_improved
)
=
info
dtype
=
shortly_dtype
(
dtype
)
device
=
shortly_device
(
device
)
print
(
f
"|
{
num_tokens
}
|
{
num_heads
}
|
{
head_size
}
"
f
"|
{
dtype
}
|
{
device
}
|
{
avg_time_torch_kernel
:.
5
f
}
ms "
f
"|
{
avg_time_triton_kernel
:.
5
f
}
ms "
f
"|
{
avg_time_cuda_kernel
:.
5
f
}
ms "
f
"|
{
performance_improved
:.
4
f
}
x |"
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_BATCH_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"num_query_heads"
,
NUM_QUERY_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"output_dtype"
,
DTYPES
)
@
torch
.
inference_mode
()
def
test_merge_attn_states
(
num_tokens
:
int
,
num_query_heads
:
int
,
head_size
:
int
,
output_dtype
:
torch
.
dtype
):
if
not
current_platform
.
is_cuda
():
pytest
.
skip
(
'Currently only support compare triton merge_attn_states '
'with custom cuda merge_attn_states kernel'
)
NUM_TOKENS
=
num_tokens
NUM_HEADS
=
num_query_heads
HEAD_SIZE
=
head_size
print
(
f
"
\n
NUM_TOKENS:
{
NUM_TOKENS
}
, NUM_HEADS:
{
NUM_HEADS
}
, "
f
"HEAD_SIZE:
{
HEAD_SIZE
}
, DTYPE:
{
output_dtype
}
, "
f
"Device:
{
current_platform
.
get_device_name
()
}
"
)
# prefix_lse and suffix_lse contain inf and normal values
prefix_lse
=
torch
.
randn
(
NUM_HEADS
,
NUM_TOKENS
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
suffix_lse
=
torch
.
randn
(
NUM_HEADS
,
NUM_TOKENS
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
# Generate boolean masks
mask_prefix
=
torch
.
rand
(
NUM_HEADS
,
NUM_TOKENS
)
<
0.1
mask_suffix
=
torch
.
rand
(
NUM_HEADS
,
NUM_TOKENS
)
<
0.1
# Ensure that the same position is not True at the same time
combined_mask
=
torch
.
logical_and
(
mask_prefix
,
mask_suffix
)
mask_prefix
=
torch
.
logical_and
(
mask_prefix
,
~
combined_mask
)
mask_suffix
=
torch
.
logical_and
(
mask_suffix
,
~
combined_mask
)
prefix_lse
[
mask_prefix
]
=
float
(
'inf'
)
suffix_lse
[
mask_suffix
]
=
float
(
'inf'
)
# Other input tensors (need to be initialized but
# no actual calculation needed)
output
=
torch
.
zeros
((
NUM_TOKENS
,
NUM_HEADS
,
HEAD_SIZE
),
dtype
=
output_dtype
,
device
=
"cuda"
)
output_lse
=
torch
.
zeros
((
NUM_HEADS
,
NUM_TOKENS
),
dtype
=
torch
.
float32
,
device
=
"cuda"
)
prefix_output
=
torch
.
randn
((
NUM_TOKENS
,
NUM_HEADS
,
HEAD_SIZE
),
dtype
=
output_dtype
,
device
=
"cuda"
)
suffix_output
=
torch
.
randn
((
NUM_TOKENS
,
NUM_HEADS
,
HEAD_SIZE
),
dtype
=
output_dtype
,
device
=
"cuda"
)
warmup_times
=
2
repeat_times
=
20
output_torch
=
output
.
clone
()
output_lse_torch
=
output_lse
.
clone
()
total_time_torch_kernel
=
0
start
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
# 0. Run the Torch kernel
prefix_lse_torch
=
prefix_lse
.
clone
()
suffix_lse_torch
=
suffix_lse
.
clone
()
for
_
in
range
(
warmup_times
):
output_torch
,
output_lse_torch
=
merge_attn_states_torch
(
output_torch
,
prefix_output
,
prefix_lse_torch
,
suffix_output
,
suffix_lse_torch
,
output_lse_torch
)
torch
.
cuda
.
synchronize
()
for
_
in
range
(
repeat_times
):
start
.
record
()
output_torch
,
output_lse_torch
=
merge_attn_states_torch
(
output_torch
,
prefix_output
,
prefix_lse_torch
,
suffix_output
,
suffix_lse_torch
,
output_lse_torch
)
end
.
record
()
torch
.
cuda
.
synchronize
()
total_time_torch_kernel
+=
start
.
elapsed_time
(
end
)
avg_time_torch_kernel
=
total_time_torch_kernel
/
repeat_times
# 1. Run the Triton kernel
output_ref_triton
=
output
.
clone
()
output_lse_ref_triton
=
output_lse
.
clone
()
total_time_triton_kernel
=
0
start
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
for
_
in
range
(
warmup_times
):
merge_attn_states_triton
(
output_ref_triton
,
prefix_output
,
prefix_lse
,
suffix_output
,
suffix_lse
,
output_lse_ref_triton
)
torch
.
cuda
.
synchronize
()
for
_
in
range
(
repeat_times
):
start
.
record
()
merge_attn_states_triton
(
output_ref_triton
,
prefix_output
,
prefix_lse
,
suffix_output
,
suffix_lse
,
output_lse_ref_triton
)
end
.
record
()
torch
.
cuda
.
synchronize
()
total_time_triton_kernel
+=
start
.
elapsed_time
(
end
)
avg_time_triton_kernel
=
total_time_triton_kernel
/
repeat_times
# 2. Run the CUDA kernel
total_time_cuda_kernel
=
0
output_cuda
=
output
.
clone
()
output_lse_cuda
=
output_lse
.
clone
()
for
_
in
range
(
warmup_times
):
merge_attn_states_cuda
(
output_cuda
,
prefix_output
,
prefix_lse
,
suffix_output
,
suffix_lse
,
output_lse_cuda
)
torch
.
cuda
.
synchronize
()
for
_
in
range
(
repeat_times
):
start
.
record
()
merge_attn_states_cuda
(
output_cuda
,
prefix_output
,
prefix_lse
,
suffix_output
,
suffix_lse
,
output_lse_cuda
)
end
.
record
()
torch
.
cuda
.
synchronize
()
total_time_cuda_kernel
+=
start
.
elapsed_time
(
end
)
avg_time_cuda_kernel
=
total_time_cuda_kernel
/
repeat_times
# 3. Performance compare
performance_improved
=
avg_time_triton_kernel
/
avg_time_cuda_kernel
print
(
f
" Torch time:
{
avg_time_torch_kernel
:.
6
f
}
ms"
)
print
(
f
"Triton time:
{
avg_time_triton_kernel
:.
6
f
}
ms"
)
print
(
f
" CUDA time:
{
avg_time_cuda_kernel
:.
6
f
}
ms, "
f
"Performance:
{
performance_improved
:.
5
f
}
x"
)
print
(
"-"
*
100
)
# 4. Correctness compare
# Liger Kernel: Efficient Triton Kernels for LLM Training
# https://arxiv.org/pdf/2410.10989, 3.3 Correctness
# use rtol = 1e-2 for bfloat16.
rtol
=
1e-2
if
output_dtype
==
torch
.
bfloat16
else
1e-3
def
diff
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
):
max_diff
=
torch
.
max
(
torch
.
abs
(
a
.
float
()
-
b
.
float
()))
return
max_diff
# Use Triton output as reference because we want to replace
# the Triton kernel with custom CUDA kernel for merge attn
# states operation.
output_ref
=
output_ref_triton
output_lse_ref
=
output_lse_ref_triton
torch
.
testing
.
assert_close
(
output_cuda
.
float
(),
output_ref
.
float
(),
atol
=
1e-3
,
rtol
=
rtol
)
print
(
"Output all match, max abs diff:"
)
print
(
f
"(Triton vs Torch) :
{
diff
(
output_torch
,
output_ref
)
}
"
)
print
(
f
" (CUDA vs Torch) :
{
diff
(
output_torch
,
output_cuda
)
}
"
)
print
(
f
" (CUDA vs Triton):
{
diff
(
output_ref
,
output_cuda
)
}
"
)
print
(
"-"
*
100
)
torch
.
testing
.
assert_close
(
output_lse_cuda
.
float
(),
output_lse_ref
.
float
(),
atol
=
1e-3
,
rtol
=
rtol
)
print
(
"Output LSE all match, max abs diff:"
)
print
(
f
"(Triton vs Torch) :
{
diff
(
output_lse_torch
,
output_lse_ref
)
}
"
)
print
(
f
" (CUDA vs Torch) :
{
diff
(
output_lse_torch
,
output_lse_cuda
)
}
"
)
print
(
f
" (CUDA vs Triton):
{
diff
(
output_lse_ref
,
output_lse_cuda
)
}
"
)
print
(
"-"
*
100
)
print
(
"All output values test passed! All inf values "
"are correctly replaced with -inf."
)
print
(
"-"
*
100
)
device
=
current_platform
.
get_device_name
()
all_case_info
.
append
(
(
NUM_TOKENS
,
NUM_HEADS
,
HEAD_SIZE
,
output_dtype
,
device
,
avg_time_torch_kernel
,
avg_time_triton_kernel
,
avg_time_cuda_kernel
,
performance_improved
))
if
len
(
all_case_info
)
==
(
len
(
NUM_BATCH_TOKENS
)
*
len
(
HEAD_SIZES
)
*
len
(
NUM_QUERY_HEADS
)
*
len
(
DTYPES
)):
generate_markdown_table
()
vllm/_custom_ops.py
View file @
e9528f6d
...
@@ -138,6 +138,17 @@ def mla_decode_kvcache_cpu(
...
@@ -138,6 +138,17 @@ def mla_decode_kvcache_cpu(
block_tables
,
seq_lens
)
block_tables
,
seq_lens
)
# merge attn states ops
def
merge_attn_states
(
output
:
torch
.
Tensor
,
prefix_output
:
torch
.
Tensor
,
prefix_lse
:
torch
.
Tensor
,
suffix_output
:
torch
.
Tensor
,
suffix_lse
:
torch
.
Tensor
,
output_lse
:
Optional
[
torch
.
Tensor
]
=
None
)
->
None
:
torch
.
ops
.
_C
.
merge_attn_states
(
output
,
output_lse
,
prefix_output
,
prefix_lse
,
suffix_output
,
suffix_lse
)
# pos encoding ops
# pos encoding ops
def
rotary_embedding
(
def
rotary_embedding
(
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
...
vllm/attention/backends/mla/common.py
View file @
e9528f6d
...
@@ -204,6 +204,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
...
@@ -204,6 +204,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
compute_slot_mapping
,
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
is_block_tables_empty
)
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearBase
,
RowParallelLinear
,
LinearBase
,
RowParallelLinear
,
UnquantizedLinearMethod
)
UnquantizedLinearMethod
)
...
@@ -217,9 +218,7 @@ from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
...
@@ -217,9 +218,7 @@ from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
if
HAS_TRITON
:
if
HAS_TRITON
:
from
vllm.attention.ops.triton_flash_attention
import
triton_attention
from
vllm.attention.ops.triton_flash_attention
import
triton_attention
from
vllm.attention.ops.triton_merge_attn_states
import
merge_attn_states
else
:
else
:
merge_attn_states
=
None
triton_attention
=
None
triton_attention
=
None
try
:
try
:
...
...
vllm/attention/ops/merge_attn_states.py
0 → 100644
View file @
e9528f6d
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
import
torch
from
vllm.platforms
import
current_platform
def
merge_attn_states
(
output
:
torch
.
Tensor
,
prefix_output
:
torch
.
Tensor
,
prefix_lse
:
torch
.
Tensor
,
suffix_output
:
torch
.
Tensor
,
suffix_lse
:
torch
.
Tensor
,
output_lse
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
# NOTE(DefTruth): Currently, custom merge_attn_states CUDA kernel
# is not support for FP8 dtype, fallback to use Triton kernel.
def
supported_dtypes
(
o
:
torch
.
Tensor
)
->
bool
:
return
o
.
dtype
in
[
torch
.
float32
,
torch
.
half
,
torch
.
bfloat16
]
# NOTE(DefTruth): Currently, custom merge_attn_states CUDA
# kernel load/store 128b(16 bytes) per memory issue within
# thread. Namely, the headsize(headdim) must be multiple of
# pack_size (float32 -> 4, half/bfloat16 -> 8).
def
supported_headdim
(
o
:
torch
.
Tensor
)
->
bool
:
headdim
=
o
.
shape
[
2
]
# [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
if
o
.
dtype
==
torch
.
float32
:
return
headdim
%
4
==
0
return
headdim
%
8
==
0
if
(
current_platform
.
is_cuda
()
and
supported_dtypes
(
output
)
and
supported_headdim
(
output
)):
from
vllm._custom_ops
import
merge_attn_states
return
merge_attn_states
(
output
,
prefix_output
,
prefix_lse
,
suffix_output
,
suffix_lse
,
output_lse
)
else
:
from
vllm.attention.ops.triton_merge_attn_states
import
(
merge_attn_states
)
return
merge_attn_states
(
output
,
prefix_output
,
prefix_lse
,
suffix_output
,
suffix_lse
,
output_lse
)
vllm/v1/attention/backends/flash_attn.py
View file @
e9528f6d
...
@@ -10,7 +10,7 @@ from vllm import _custom_ops as ops
...
@@ -10,7 +10,7 @@ from vllm import _custom_ops as ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
,
AttentionMetadata
,
AttentionType
,
is_quantized_kv_cache
)
is_quantized_kv_cache
)
from
vllm.attention.ops.
triton_
merge_attn_states
import
merge_attn_states
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
...
...
vllm/v1/attention/backends/mla/common.py
View file @
e9528f6d
...
@@ -195,7 +195,7 @@ from vllm import _custom_ops as ops
...
@@ -195,7 +195,7 @@ from vllm import _custom_ops as ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionLayer
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionLayer
,
AttentionMetadata
,
AttentionMetadata
,
MLAAttentionImpl
)
MLAAttentionImpl
)
from
vllm.attention.ops.
triton_
merge_attn_states
import
merge_attn_states
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearBase
,
RowParallelLinear
,
LinearBase
,
RowParallelLinear
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment