sglang · Commit 388e15c0 (Unverified)
Authored Apr 15, 2025 by DefTruth; committed by GitHub, Apr 14, 2025
kernel: support slightly faster merge_state_v2 cuda kernel (#5381)
Parent: 11421a3f
Showing 7 changed files with 638 additions and 4 deletions (+638, -4)
sgl-kernel/CMakeLists.txt (+1, -0)
sgl-kernel/csrc/attention/merge_attn_states.cu (+201, -0)
sgl-kernel/csrc/common_extension.cc (+2, -0)
sgl-kernel/include/sgl_kernel_ops.h (+2, -0)
sgl-kernel/python/sgl_kernel/__init__.py (+1, -0)
sgl-kernel/python/sgl_kernel/attention.py (+35, -4)
sgl-kernel/tests/test_merge_state_v2.py (+396, -0)
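In Python, the new kernel is exposed as sgl_kernel.merge_state_v2, a drop-in counterpart to merge_state (see sgl-kernel/python/sgl_kernel/attention.py below). A minimal usage sketch with hypothetical sizes, assuming a CUDA device and fp16 attention outputs:

    import torch
    from sgl_kernel import merge_state_v2

    # Hypothetical sizes: 1024 tokens, 32 heads, head size 128.
    v_a = torch.randn(1024, 32, 128, dtype=torch.half, device="cuda")  # prefix attention output
    s_a = torch.randn(1024, 32, dtype=torch.float32, device="cuda")    # prefix log-sum-exp
    v_b = torch.randn(1024, 32, 128, dtype=torch.half, device="cuda")  # suffix attention output
    s_b = torch.randn(1024, 32, dtype=torch.float32, device="cuda")    # suffix log-sum-exp

    # Output tensors are allocated by the wrapper when not passed in.
    v_merged, s_merged = merge_state_v2(v_a, s_a, v_b, s_b)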
sgl-kernel/CMakeLists.txt
@@ -170,6 +170,7 @@ string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE
 set(SOURCES
     "csrc/allreduce/custom_all_reduce.cu"
     "csrc/attention/cascade.cu"
+    "csrc/attention/merge_attn_states.cu"
     "csrc/attention/cutlass_mla_kernel.cu"
     "csrc/attention/lightning_attention_decode_kernel.cu"
     "csrc/elementwise/activation.cu"
sgl-kernel/csrc/attention/merge_attn_states.cu (new file, 0 → 100644)
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <algorithm>
#include <optional>

#include "pytorch_extension_utils.h"

// Helper functions to convert between different data types
// (float, half, bfloat16) for the merge attention states kernel.
inline __device__ float to_float(float u) {
  return u;
}

inline __device__ float to_float(half u) {
  return __half2float(u);
}

inline __device__ float to_float(__nv_bfloat16 u) {
  return __bfloat162float(u);
}

inline __device__ void from_float(float& d, float s) {
  d = s;
}

inline __device__ void from_float(half& d, float s) {
  d = __float2half(s);
}

inline __device__ void from_float(__nv_bfloat16& d, float s) {
  d = __float2bfloat16(s);
}

// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
template <typename scalar_t, const uint NUM_THREADS>
__global__ void merge_attn_states_kernel(
    scalar_t* output,
    float* output_lse,
    const scalar_t* prefix_output,
    const float* prefix_lse,
    const scalar_t* suffix_output,
    const float* suffix_lse,
    const uint num_tokens,
    const uint num_heads,
    const uint head_size) {
  using pack_128b_t = uint4;
  const uint pack_size = 16 / sizeof(scalar_t);
  const uint threads_per_head = head_size / pack_size;

  const uint global_idx = blockIdx.x * NUM_THREADS + threadIdx.x;
  const uint token_head_threads = num_tokens * num_heads * threads_per_head;

  if (global_idx >= token_head_threads) return;

  // global_idx -> token_idx + head_idx + pack_idx
  const uint token_head_idx = global_idx / threads_per_head;
  const uint pack_idx = global_idx % threads_per_head;

  const uint token_idx = token_head_idx / num_heads;
  const uint head_idx = token_head_idx % num_heads;

  const uint pack_offset = pack_idx * pack_size;  // (0~15)*8, etc.
  const uint head_offset = token_idx * num_heads * head_size + head_idx * head_size;
  const scalar_t* prefix_head_ptr = prefix_output + head_offset;
  const scalar_t* suffix_head_ptr = suffix_output + head_offset;
  scalar_t* output_head_ptr = output + head_offset;

  // float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
  // float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
  float p_lse = prefix_lse[token_idx * num_heads + head_idx];
  float s_lse = suffix_lse[token_idx * num_heads + head_idx];
  p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
  s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;

  const float max_lse = fmaxf(p_lse, s_lse);
  p_lse = p_lse - max_lse;
  s_lse = s_lse - max_lse;
  const float p_se = expf(p_lse);
  const float s_se = expf(s_lse);
  const float out_se = p_se + s_se;
  const float p_scale = p_se / out_se;
  const float s_scale = s_se / out_se;

  if (pack_offset < head_size) {
    // Pack 128b load
    pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(prefix_head_ptr)[pack_offset / pack_size];
    pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(suffix_head_ptr)[pack_offset / pack_size];
    pack_128b_t o_out_pack;

#pragma unroll
    for (uint i = 0; i < pack_size; ++i) {
      // Always use float for FMA to keep high precision.
      // half(uint16_t), bfloat16, float -> float.
      const float p_out_f = to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
      const float s_out_f = to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
      // fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
      const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
      // float -> half(uint16_t), bfloat16, float.
      from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
    }

    // Pack 128b storage
    reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] = o_out_pack;
  }

  // We only need to write to output_lse once per head.
  if (output_lse != nullptr && pack_idx == 0) {
    float out_lse = logf(out_se) + max_lse;
    output_lse[token_idx * num_heads + head_idx] = out_lse;
  }
}
// The following macro is used to dispatch the conversion function based on
// the output data type. The FN is a macro that calls a function with
// template<typename scalar_t>.
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn) \
{ \
if (scalar_dtype == at::ScalarType::Float) { \
fn(float); \
} else if (scalar_dtype == at::ScalarType::Half) { \
fn(half); \
} else if (scalar_dtype == at::ScalarType::BFloat16) { \
fn(__nv_bfloat16); \
} else { \
TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
} \
}
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
{ \
merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block>>>( \
reinterpret_cast<scalar_t*>(output.data_ptr()), \
reinterpret_cast<float*>(output_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
reinterpret_cast<float*>(suffix_lse.data_ptr()), \
num_tokens, \
num_heads, \
head_size); \
}
/* @brief Merges the attention states from prefix and suffix
 * into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
 *
 * @param output [n,h,d] The output tensor to store the merged attention states.
 * @param output_lse [n,h] Optional tensor to store the log-sum-exp values.
 * @param prefix_output [n,h,d] The prefix attention states.
 * @param prefix_lse [n,h] The log-sum-exp values for the prefix attention
 * states.
 * @param suffix_output [n,h,d] The suffix attention states.
 * @param suffix_lse [n,h] The log-sum-exp values for the suffix attention
 * states.
 */
template <typename scalar_t>
void merge_attn_states_launcher(
    const at::Tensor& prefix_output,  // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    const at::Tensor& prefix_lse,     // [NUM_TOKENS, NUM_HEADS]
    const at::Tensor& suffix_output,  // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    const at::Tensor& suffix_lse,     // [NUM_TOKENS, NUM_HEADS]
    at::Tensor& output,               // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    at::Tensor& output_lse            // [NUM_TOKENS, NUM_HEADS]
) {
  constexpr uint NUM_THREADS = 128;
  const uint num_tokens = output.size(0);
  const uint num_heads = output.size(1);
  const uint head_size = output.size(2);
  const uint pack_size = 16 / sizeof(scalar_t);
  TORCH_CHECK(head_size % pack_size == 0, "headsize must be multiple of pack_size:", pack_size);
  // Process one pack of elements per thread: for float the
  // pack_size is 4, for half/bf16 the pack_size is 8.
  const uint threads_per_head = head_size / pack_size;
  const uint total_threads = num_tokens * num_heads * threads_per_head;

  dim3 block(NUM_THREADS);
  dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);

  LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
}
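// Example launch arithmetic with hypothetical sizes (illustrative only, not part
// of the committed file): num_tokens = 1024, num_heads = 32, head_size = 128 and
// scalar_t = half gives pack_size = 16 / 2 = 8, threads_per_head = 128 / 8 = 16,
// total_threads = 1024 * 32 * 16 = 524288, so the grid is 524288 / 128 = 4096
// blocks of NUM_THREADS = 128; each thread merges one 128-bit pack of 8 elements.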
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
{ merge_attn_states_launcher<scalar_t>(v_a, s_a, v_b, s_b, v_merged, s_merged); }
void merge_state_v2(
    at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged) {
  // Input tensors must be contiguous
  CHECK_INPUT(v_a);  // v_a prefix_output (seq_len, num_heads, head_dim)
  CHECK_INPUT(s_a);  // s_a prefix_lse (seq_len, num_heads)
  CHECK_INPUT(v_b);  // v_b suffix_output (seq_len, num_heads, head_dim)
  CHECK_INPUT(s_b);  // s_b suffix_lse (seq_len, num_heads)
  // v_merged output (seq_len, num_heads, head_dim)
  // s_merged output_lse (seq_len, num_heads)
  auto device = v_a.device();
  CHECK_EQ(s_a.device(), device);
  CHECK_EQ(v_b.device(), device);
  CHECK_EQ(s_b.device(), device);
  CHECK_DIM(3, v_a);
  CHECK_DIM(2, s_a);
  CHECK_DIM(3, v_b);
  CHECK_DIM(2, s_b);
  CHECK_SHAPE(v_a, v_b);
  CHECK_SHAPE(s_a, s_b);
  CHECK_EQ(v_a.size(0), s_a.size(0));
  CHECK_EQ(v_a.size(1), s_b.size(1));
  DISPATCH_BY_SCALAR_DTYPE(v_merged.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
}
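For reference, the merge that merge_attn_states_kernel applies to each (token, head) pair (the formulation of section 2.2 in the paper linked above, restated here from the code) is, with lse_p, lse_s the prefix/suffix log-sum-exp values and o_p, o_s the corresponding attention outputs:

\[ m = \max(\mathrm{lse}_p, \mathrm{lse}_s), \qquad w_p = \frac{e^{\mathrm{lse}_p - m}}{e^{\mathrm{lse}_p - m} + e^{\mathrm{lse}_s - m}}, \qquad w_s = \frac{e^{\mathrm{lse}_s - m}}{e^{\mathrm{lse}_p - m} + e^{\mathrm{lse}_s - m}} \]

\[ o = w_p\,o_p + w_s\,o_s, \qquad \mathrm{lse}_o = \log\!\left(e^{\mathrm{lse}_p - m} + e^{\mathrm{lse}_s - m}\right) + m \]

Subtracting the maximum before exponentiation keeps the weights numerically stable, and any +inf LSE (an empty attention segment) is first mapped to -inf so that its weight becomes zero.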
sgl-kernel/csrc/common_extension.cc
@@ -47,6 +47,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
   m.def("merge_state(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
   m.impl("merge_state", torch::kCUDA, &merge_state);
+  m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
+  m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
   m.def(
       "cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
       "page_table, Tensor workspace) -> ()");
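With the def/impl pair registered above, the op is reachable from Python through the torch.ops namespace; the wrapper in sgl-kernel/python/sgl_kernel/attention.py below calls it exactly this way. A one-line sketch, assuming contiguous CUDA tensors with pre-allocated outputs:

    # Writes the merged attention states into v_merged and s_merged in place.
    torch.ops.sgl_kernel.merge_state_v2.default(v_a, s_a, v_b, s_b, v_merged, s_merged)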
sgl-kernel/include/sgl_kernel_ops.h
@@ -89,6 +89,8 @@ void lightning_attention_decode(
     torch::Tensor new_kv);
 void merge_state(
     at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
+void merge_state_v2(
+    at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
 void cutlass_mla_decode(
     torch::Tensor const& out,
     torch::Tensor const& q_nope_and_q_pe,
sgl-kernel/python/sgl_kernel/__init__.py
@@ -16,6 +16,7 @@ from sgl_kernel.attention import (
     cutlass_mla_get_workspace_size,
     lightning_attention_decode,
     merge_state,
+    merge_state_v2,
 )
 from sgl_kernel.elementwise import (
     apply_rope_with_cos_sin_cache_inplace,
sgl-kernel/python/sgl_kernel/attention.py
-from typing import Tuple
+from typing import Optional, Tuple

 import torch

@@ -10,16 +10,47 @@ def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
 def merge_state(
-    v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
+    v_a: torch.Tensor,
+    s_a: torch.Tensor,
+    v_b: torch.Tensor,
+    s_b: torch.Tensor,
+    v_merged: Optional[torch.Tensor] = None,
+    s_merged: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     s_a = s_a.to(torch.float32)
     s_b = s_b.to(torch.float32)
-    v_merged = torch.empty_like(v_a)
-    s_merged = torch.empty_like(s_a)
+    # Avoid creating new tensors if they are already provided
+    if v_merged is None:
+        v_merged = torch.empty_like(v_a)
+    if s_merged is None:
+        s_merged = torch.empty_like(s_a)
     torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
     return v_merged, s_merged


+def merge_state_v2(
+    v_a: torch.Tensor,
+    s_a: torch.Tensor,
+    v_b: torch.Tensor,
+    s_b: torch.Tensor,
+    v_merged: Optional[torch.Tensor] = None,
+    s_merged: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    s_a = s_a.to(torch.float32)
+    s_b = s_b.to(torch.float32)
+    # TODO(DefTruth): Currently, the custom merge_attn_states kernel
+    # does not support the FP8 data type and non-CUDA devices.
+    # It may be necessary to fall back to using the Triton kernel.
+    # Avoid creating new tensors if they are already provided
+    if v_merged is None:
+        v_merged = torch.empty_like(v_a)
+    if s_merged is None:
+        s_merged = torch.empty_like(s_a)
+    torch.ops.sgl_kernel.merge_state_v2.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
+    return v_merged, s_merged
+
+
 def cutlass_mla_decode(
     q_nope_and_q_pe: torch.Tensor,
     kv_c_and_k_pe_cache: torch.Tensor,
sgl-kernel/tests/test_merge_state_v2.py (new file, 0 → 100644)
from typing import Optional

import pytest
import torch
import triton
import triton.language as tl
from sgl_kernel import merge_state, merge_state_v2


@triton.jit
def merge_state_kernel(
    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_merged
    output_lse,  # [NUM_TOKENS, NUM_HEADS] s_merged
    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_a
    prefix_lse,  # [NUM_TOKENS, NUM_HEADS] s_a
    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_b
    suffix_lse,  # [NUM_TOKENS, NUM_HEADS] s_b
    HEAD_SIZE: tl.constexpr,
    PADDED_HEAD_SIZE: tl.constexpr,
    OUTPUT_LSE: tl.constexpr,
):
    token_idx = tl.program_id(0)
    num_tokens = tl.num_programs(0)
    head_idx = tl.program_id(1)
    num_heads = tl.num_programs(1)

    p_lse = tl.load(prefix_lse + token_idx * num_heads + head_idx)
    s_lse = tl.load(suffix_lse + token_idx * num_heads + head_idx)
    p_lse = float("-inf") if p_lse == float("inf") else p_lse
    s_lse = float("-inf") if s_lse == float("inf") else s_lse

    max_lse = tl.maximum(p_lse, s_lse)
    p_lse = p_lse - max_lse
    s_lse = s_lse - max_lse
    out_se = tl.exp(p_lse) + tl.exp(s_lse)

    if OUTPUT_LSE:
        out_lse = tl.log(out_se) + max_lse
        tl.store(output_lse + token_idx * num_heads + head_idx, out_lse)

    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
    head_mask = head_arange < HEAD_SIZE
    p_out = tl.load(
        prefix_output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
        mask=head_mask,
    )
    s_out = tl.load(
        suffix_output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
        mask=head_mask,
    )

    p_scale = tl.exp(p_lse) / out_se
    s_scale = tl.exp(s_lse) / out_se
    out = p_out * p_scale + s_out * s_scale
    tl.store(
        output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
        out,
        mask=head_mask,
    )
def merge_state_triton(
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output: Optional[torch.Tensor] = None,
    output_lse: Optional[torch.Tensor] = None,
) -> None:
    num_tokens = output.shape[0]
    num_query_heads = output.shape[1]
    head_size = output.shape[2]
    padded_head_size = triton.next_power_of_2(head_size)

    # Avoid creating new tensors if they are already provided
    if output is None:
        output = torch.empty_like(prefix_output)
    if output_lse is None:
        output_lse = torch.empty_like(prefix_lse)

    merge_state_kernel[(num_tokens, num_query_heads)](
        output,
        output_lse,
        prefix_output,
        prefix_lse,
        suffix_output,
        suffix_lse,
        head_size,
        padded_head_size,
        output_lse is not None,
    )
    return output, output_lse
# Naive PyTorch implementation of merge attention states
def merge_state_torch(
    prefix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    prefix_lse: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS]
    suffix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    suffix_lse: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS]
    output: Optional[torch.Tensor] = None,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    output_lse: Optional[torch.Tensor] = None,  # [NUM_TOKENS, NUM_HEADS]
):
    # Avoid creating new tensors if they are already provided
    if output is None:
        output = torch.empty_like(prefix_output)
    if output_lse is None:
        output_lse = torch.empty_like(prefix_lse)
    p_lse = prefix_lse
    s_lse = suffix_lse
    # inf -> -inf
    p_lse[p_lse == torch.inf] = -torch.inf
    s_lse[s_lse == torch.inf] = -torch.inf
    # max_lse [NUM_TOKENS, NUM_HEADS]
    max_lse = torch.maximum(p_lse, s_lse)
    p_lse = p_lse - max_lse
    s_lse = s_lse - max_lse
    p_lse_exp = torch.exp(p_lse)
    s_lse_exp = torch.exp(s_lse)
    out_se = p_lse_exp + s_lse_exp
    if output_lse is not None:
        output_lse = torch.log(out_se) + max_lse
    p_scale = p_lse_exp / out_se
    s_scale = s_lse_exp / out_se
    p_scale = p_scale.unsqueeze(2)  # [NUM_TOKENS, NUM_HEADS, 1]
    s_scale = s_scale.unsqueeze(2)  # [NUM_TOKENS, NUM_HEADS, 1]
    output = prefix_output * p_scale + suffix_output * s_scale
    return output, output_lse
NUM_BATCH_TOKENS = [256, 512, 613, 1024, 1536]
NUM_QUERY_HEADS = [8, 16, 32]
HEAD_SIZES = [32, 48, 64, 128, 256]
DTYPES = [torch.half, torch.bfloat16]
all_case_info: list[tuple] = []
def generate_markdown_table():
    global all_case_info
    table_header = (
        "| tokens | heads | headsize | dtype "
        "| device | torch | triton | v1 | v2 | speedup(vs triton) | speedup(vs v1)|"
    )
    table_separator = (
        "| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |"
    )

    def shortly_dtype(dtype: torch.dtype) -> str:
        return str(dtype).removeprefix("torch.")

    def shortly_device(device: str) -> str:
        return device.removeprefix("NVIDIA").strip()

    print(table_header)
    print(table_separator)
    for info in all_case_info:
        (
            num_tokens,
            num_heads,
            head_size,
            dtype,
            device,
            time_torch,
            time_triton,
            time_v1,
            time_v2,
        ) = info
        dtype = shortly_dtype(dtype)
        device = shortly_device(device)
        improved_triton = time_triton / time_v2
        improved_v1 = time_v1 / time_v2
        print(
            f"| {num_tokens} | {num_heads} | {head_size} "
            f"| {dtype} | {device} | {time_torch:.4f}ms "
            f"| {time_triton:.4f}ms "
            f"| {time_v1:.4f}ms "
            f"| {time_v2:.4f}ms "
            f"| {improved_triton:.4f}x "
            f"| {improved_v1:.4f}x |"
        )
@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
@pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("output_dtype", DTYPES)
@torch.inference_mode()
def test_merge_attn_states(
    num_tokens: int, num_query_heads: int, head_size: int, output_dtype: torch.dtype
):
    if not torch.cuda.is_available():
        pytest.skip(
            "Currently only support compare triton merge_attn_states "
            "with custom cuda merge_attn_states kernel"
        )

    NUM_TOKENS = num_tokens
    NUM_HEADS = num_query_heads
    HEAD_SIZE = head_size

    print(
        f"\nNUM_TOKENS: {NUM_TOKENS}, NUM_HEADS: {NUM_HEADS}, "
        f"HEAD_SIZE: {HEAD_SIZE}, DTYPE: {output_dtype}, "
        f"Device: {torch.cuda.get_device_name()}"
    )

    # prefix_lse and suffix_lse contain inf and normal values
    prefix_lse = torch.randn(NUM_TOKENS, NUM_HEADS, dtype=torch.float32, device="cuda")
    suffix_lse = torch.randn(NUM_TOKENS, NUM_HEADS, dtype=torch.float32, device="cuda")

    # Generate boolean masks
    mask_prefix = torch.rand(NUM_TOKENS, NUM_HEADS) < 0.1
    mask_suffix = torch.rand(NUM_TOKENS, NUM_HEADS) < 0.1
    # Ensure that the same position is not True at the same time
    combined_mask = torch.logical_and(mask_prefix, mask_suffix)
    mask_prefix = torch.logical_and(mask_prefix, ~combined_mask)
    mask_suffix = torch.logical_and(mask_suffix, ~combined_mask)

    prefix_lse[mask_prefix] = float("inf")
    suffix_lse[mask_suffix] = float("inf")

    # Other input tensors (need to be initialized but
    # no actual calculation needed)
    output = torch.zeros((NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda")
    output_lse = torch.zeros((NUM_TOKENS, NUM_HEADS), dtype=torch.float32, device="cuda")
    prefix_output = torch.randn(
        (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
    )
    suffix_output = torch.randn(
        (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
    )

    warmup_times = 2
    repeat_times = 20

    def perf_kernel_fn(
        output_fn: torch.Tensor,
        output_lse_fn: torch.Tensor,
        kernel_fn: callable,
        fn_type: str = "torch",
    ):
        # Avoid inplace inf -> -inf, we have to use prefix_lse
        # and suffix_lse for other kernel.
        if fn_type == "torch":
            prefix_lse_ = prefix_lse.clone()
            suffix_lse_ = suffix_lse.clone()
        else:
            prefix_lse_ = prefix_lse
            suffix_lse_ = suffix_lse
        if fn_type == "cuda_v1":
            # merge_state v1 kernel not support float32
            if output_dtype not in (torch.half, torch.bfloat16):
                return 0, output_fn, output_lse_fn

        total_time = 0
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        try:
            for _ in range(warmup_times):
                output_fn, output_lse_fn = kernel_fn(
                    prefix_output,
                    prefix_lse_,
                    suffix_output,
                    suffix_lse_,
                    output_fn,
                    output_lse_fn,
                )
            torch.cuda.synchronize()

            for _ in range(repeat_times):
                start.record()
                output_fn, output_lse_fn = kernel_fn(
                    prefix_output,
                    prefix_lse_,
                    suffix_output,
                    suffix_lse_,
                    output_fn,
                    output_lse_fn,
                )
                end.record()
                torch.cuda.synchronize()
                total_time += start.elapsed_time(end)

            avg_time = total_time / repeat_times
            return avg_time, output_fn, output_lse_fn
        except Exception as e:
            return 0, output_fn, output_lse_fn

    # 0. Run the Torch kernel
    output_torch = output.clone()
    output_lse_torch = output_lse.clone()
    time_torch, output_torch, output_lse_torch = perf_kernel_fn(
        output_torch, output_lse_torch, merge_state_torch, fn_type="torch"
    )

    # 1. Run the Triton kernel
    output_ref_triton = output.clone()
    output_lse_ref_triton = output_lse.clone()
    time_triton, output_ref_triton, output_lse_ref_triton = perf_kernel_fn(
        output_ref_triton,
        output_lse_ref_triton,
        merge_state_triton,
        fn_type="triton",
    )

    # 2. Run the merge_state V1 kernel
    output_v1 = output.clone()
    output_lse_v1 = output_lse.clone()
    time_v1, output_v1, output_lse_v1 = perf_kernel_fn(
        output_v1, output_lse_v1, merge_state, fn_type="cuda_v1"
    )

    # 3. Run the merge_state V2 kernel
    output_v2 = output.clone()
    output_lse_v2 = output_lse.clone()
    time_v2, output_v2, output_lse_v2 = perf_kernel_fn(
        output_v2, output_lse_v2, merge_state_v2, fn_type="cuda_v2"
    )

    # 4. Performance compare
    improved = time_triton / time_v2
    print(f"  Torch time: {time_torch:.6f}ms")
    print(f" Triton time: {time_triton:.6f}ms")
    print(f"CUDA v1 time: {time_v1:.6f}ms")
    print(f"CUDA v2 time: {time_v2:.6f}ms, Performance: {improved:.5f}x")
    print("-" * 100)

    # 5. Correctness compare
    # Liger Kernel: Efficient Triton Kernels for LLM Training
    # https://arxiv.org/pdf/2410.10989, 3.3 Correctness
    # use rtol = 1e-2 for bfloat16.
    rtol = 1e-2 if output_dtype == torch.bfloat16 else 1e-3

    def diff(a: torch.Tensor, b: torch.Tensor):
        max_diff = torch.max(torch.abs(a.float() - b.float()))
        return max_diff

    # Use Triton output as reference because we want to replace
    # the Triton kernel with custom CUDA kernel for merge attn
    # states operation.
    output_ref = output_ref_triton
    output_lse_ref = output_lse_ref_triton
    torch.testing.assert_close(output_v2.float(), output_ref.float(), atol=1e-3, rtol=rtol)
    print("Output all match, max abs diff:")
    print(f"(Triton vs Torch) : {diff(output_torch, output_ref)}")
    print(f"(CUDA v2 vs Torch) : {diff(output_torch, output_v2)}")
    print(f"(CUDA v2 vs Triton): {diff(output_ref, output_v2)}")
    print("-" * 100)

    torch.testing.assert_close(
        output_lse_v2.float(), output_lse_ref.float(), atol=1e-3, rtol=rtol
    )
    print("Output LSE all match, max abs diff:")
    print(f"(Triton vs Torch) : {diff(output_lse_torch, output_lse_ref)}")
    print(f"(CUDA v2 vs Torch) : {diff(output_lse_torch, output_lse_v2)}")
    print(f"(CUDA v2 vs Triton): {diff(output_lse_ref, output_lse_v2)}")
    print("-" * 100)

    print(
        "All output values test passed! All inf values "
        "are correctly replaced with -inf."
    )
    print("-" * 100)

    device = torch.cuda.get_device_name()
    all_case_info.append(
        (
            NUM_TOKENS,
            NUM_HEADS,
            HEAD_SIZE,
            output_dtype,
            device,
            time_torch,
            time_triton,
            time_v1,
            time_v2,
        )
    )
    if len(all_case_info) == (
        len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * len(NUM_QUERY_HEADS) * len(DTYPES)
    ):
        generate_markdown_table()
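The benchmark table is emitted only once, after the final parameter combination has run. A minimal sketch of how the suite could be invoked (an assumed invocation on a CUDA machine with sgl_kernel installed; -s keeps the per-case prints visible):

    import pytest

    # Assumed path relative to the repository root.
    pytest.main(["-s", "sgl-kernel/tests/test_merge_state_v2.py"])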