sglang commit b62e7e99 (unverified)
Authored Apr 12, 2025 by Yineng Zhang; committed by GitHub, Apr 12, 2025

feat: adapt merge_state (#5337)

parent 7d3b7c87
Showing 8 changed files with 224 additions and 3 deletions (+224, -3)
Files changed:
  .github/workflows/pr-test-sgl-kernel.yml    +7   -1
  sgl-kernel/CMakeLists.txt                   +4   -0
  sgl-kernel/csrc/attention/cascade.cu        +55  -0
  sgl-kernel/csrc/common_extension.cc         +2   -0
  sgl-kernel/include/sgl_kernel_ops.h         +2   -0
  sgl-kernel/python/sgl_kernel/__init__.py    +1   -0
  sgl-kernel/python/sgl_kernel/attention.py   +15  -2
  sgl-kernel/tests/test_merge_state.py        +138 -0
.github/workflows/pr-test-sgl-kernel.yml

@@ -44,6 +44,12 @@ jobs:
           cuda-version: '12.8'
     name: Build Wheel (CUDA ${{ matrix.cuda-version }})
     steps:
+      - name: Skip unnecessary builds on push to main
+        if: github.event_name == 'push' && (matrix.cuda-version == '11.8' || matrix.cuda-version == '12.8')
+        run: |
+          echo "Skipping CUDA ${{ matrix.cuda-version }} build on push to main"
+          exit 0
+
       - name: Cleanup
         run: |
           sudo rm -rf $GITHUB_WORKSPACE/* || true

@@ -87,7 +93,7 @@ jobs:
      - name: Install
        run: |
          bash scripts/ci_install_dependency.sh
-         pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.7.2
+         pip3 install torch==2.5.1 && pip3 install pytest
          pip3 uninstall sgl-kernel -y || true
          pip3 install sgl-kernel/dist/*whl --force-reinstall --no-deps
          pip3 list | grep sgl-kernel
sgl-kernel/CMakeLists.txt

@@ -25,6 +25,8 @@ find_package(Torch REQUIRED)
 # clean Torch Flag
 clear_cuda_arches(CMAKE_FLAG)
 
+set_property(GLOBAL PROPERTY CUDA_SEPARABLE_COMPILATION ON)
+
 include(FetchContent)
 
 # cutlass

@@ -104,6 +106,7 @@ set(SGL_KERNEL_CUDA_FLAGS
   "--expt-relaxed-constexpr"
   "-Xcompiler=-Wconversion"
   "-Xcompiler=-fno-strict-aliasing"
+  "--threads=16"
 )
 
 option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF)

@@ -160,6 +163,7 @@ string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE
 set(SOURCES
   "csrc/allreduce/custom_all_reduce.cu"
+  "csrc/attention/cascade.cu"
   "csrc/attention/cutlass_mla_kernel.cu"
   "csrc/attention/lightning_attention_decode_kernel.cu"
   "csrc/elementwise/activation.cu"
sgl-kernel/csrc/attention/cascade.cu (new file, mode 100644)

// Adapted from
// https://github.com/flashinfer-ai/flashinfer/blob/55576c626421b5ee7e7ebe74afd26465c8ae863f/csrc/cascade.cu
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <flashinfer/attention/cascade.cuh>

#include "pytorch_extension_utils.h"

using namespace flashinfer;

void merge_state(
    at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged) {
  CHECK_INPUT(v_a);
  CHECK_INPUT(s_a);
  CHECK_INPUT(v_b);
  CHECK_INPUT(s_b);
  auto device = v_a.device();
  CHECK_EQ(s_a.device(), device);
  CHECK_EQ(v_b.device(), device);
  CHECK_EQ(s_b.device(), device);
  CHECK_DIM(3, v_a);
  CHECK_DIM(2, s_a);
  CHECK_DIM(3, v_b);
  CHECK_DIM(2, s_b);
  CHECK_SHAPE(v_a, v_b);
  CHECK_SHAPE(s_a, s_b);
  CHECK_EQ(v_a.size(0), s_a.size(0));
  CHECK_EQ(v_a.size(1), s_b.size(1));
  unsigned int seq_len = v_a.size(0);
  unsigned int num_heads = v_a.size(1);
  unsigned int head_dim = v_a.size(2);

  const c10::cuda::OptionalCUDAGuard device_guard(v_a.device());
  auto stream = at::cuda::getCurrentCUDAStream();
  bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(v_a.scalar_type(), c_type, [&] {
    cudaError_t status = MergeState(
        static_cast<c_type*>(v_a.data_ptr()),
        static_cast<float*>(s_a.data_ptr()),
        static_cast<c_type*>(v_b.data_ptr()),
        static_cast<float*>(s_b.data_ptr()),
        static_cast<c_type*>(v_merged.data_ptr()),
        static_cast<float*>(s_merged.data_ptr()),
        seq_len,
        num_heads,
        head_dim,
        stream);
    TORCH_CHECK(status == cudaSuccess, "MergeState kernel launch failed: ", cudaGetErrorString(status));
    return true;
  });
  TORCH_CHECK(success, "MergeState kernel launch failed: unsupported data type");
}
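For intuition, merge_state combines two partial attention results computed over different key/value chunks, each given as a value tensor plus a per-head log-sum-exp score, into the result that attention over the union of both chunks would produce. The commit dispatches this to flashinfer's MergeState kernel; the sketch below is not part of the change, only an illustrative pure-PyTorch restatement of the same merge rule, using the base-2 exp/log convention that the Triton reference in the new test uses.

import torch

def merge_state_reference(v_a, s_a, v_b, s_b):
    # Illustrative sketch only (not the commit's code path).
    # v_*: [seq_len, num_heads, head_dim]; s_*: [seq_len, num_heads],
    # per-head log-sum-exp scores in base 2 (as in the Triton reference).
    s_max = torch.maximum(s_a, s_b)            # numerically stable pivot
    w_a = torch.exp2(s_a - s_max)              # rescale each state to the pivot
    w_b = torch.exp2(s_b - s_max)
    d = w_a + w_b                              # merged normalizer
    v = (v_a.float() * w_a.unsqueeze(-1) + v_b.float() * w_b.unsqueeze(-1)) / d.unsqueeze(-1)
    s = s_max + torch.log2(d)                  # merged log-sum-exp
    return v.to(v_a.dtype), s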
sgl-kernel/csrc/common_extension.cc

@@ -45,6 +45,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
       "new_kv) -> ()");
   m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
+  m.def("merge_state(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
+  m.impl("merge_state", torch::kCUDA, &merge_state);
   m.def(
       "cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
       "page_table, Tensor workspace) -> ()");
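In the registered schema, the Tensor! annotations on v_merged and s_merged mark them as outputs written in place, so a caller of the raw op preallocates them. A minimal sketch of driving the registered op directly, assuming a built sgl-kernel wheel and a CUDA device (the Python wrapper added further down does this on the caller's behalf):

import torch
import sgl_kernel  # noqa: F401 -- importing the package registers the torch.ops.sgl_kernel ops

seq_len, num_heads, head_dim = 8, 4, 64
v_a = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
s_a = torch.randn(seq_len, num_heads, dtype=torch.float32, device="cuda")
v_b = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
s_b = torch.randn(seq_len, num_heads, dtype=torch.float32, device="cuda")

# Outputs are declared Tensor! in the schema: allocated here, filled in place by the kernel.
v_merged = torch.empty_like(v_a)
s_merged = torch.empty_like(s_a)
torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_merged, s_merged)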
sgl-kernel/include/sgl_kernel_ops.h

@@ -87,6 +87,8 @@ void lightning_attention_decode(
     const torch::Tensor& slope,
     torch::Tensor output,
     torch::Tensor new_kv);
 
+void merge_state(
+    at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
 
 void cutlass_mla_decode(
     torch::Tensor const& out,
     torch::Tensor const& q_nope_and_q_pe,
sgl-kernel/python/sgl_kernel/__init__.py

@@ -15,6 +15,7 @@ from sgl_kernel.attention import (
     cutlass_mla_decode,
     cutlass_mla_get_workspace_size,
     lightning_attention_decode,
+    merge_state,
 )
 from sgl_kernel.elementwise import (
     apply_rope_with_cos_sin_cache_inplace,
sgl-kernel/python/sgl_kernel/attention.py

+from typing import Tuple
+
 import torch

@@ -7,6 +9,17 @@ def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
     )
 
 
+def merge_state(
+    v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    s_a = s_a.to(torch.float32)
+    s_b = s_b.to(torch.float32)
+    v_merged = torch.empty_like(v_a)
+    s_merged = torch.empty_like(s_a)
+    torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
+    return v_merged, s_merged
+
+
 def cutlass_mla_decode(
     q_nope_and_q_pe: torch.Tensor,
     kv_c_and_k_pe_cache: torch.Tensor,

@@ -54,7 +67,7 @@ def cutlass_mla_decode(
         (B_q, H, D_latent), device=q_nope_and_q_pe.device, dtype=q_nope_and_q_pe.dtype
     )
-    torch.ops.sgl_kernel.cutlass_mla_decode(
+    torch.ops.sgl_kernel.cutlass_mla_decode.default(
         out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace
     )
     return out

@@ -63,6 +76,6 @@ def cutlass_mla_decode(
 def cutlass_mla_get_workspace_size(
     max_seq_len: int, num_batches: int, sm_count: int = 0
 ) -> int:
-    return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size(
+    return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size.default(
         max_seq_len, num_batches, sm_count
     )
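A short usage sketch of the new Python wrapper (shapes are arbitrary and chosen only for illustration): the wrapper casts the scores to float32, allocates the output tensors itself, and returns them, so the call site reads like an ordinary function.

import torch
from sgl_kernel import merge_state

seq_len, num_heads, head_dim = 16, 8, 128
# Two partial attention states over the same queries, e.g. from two KV-cache chunks.
v_a = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
s_a = torch.randn(seq_len, num_heads, dtype=torch.float32, device="cuda")
v_b = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
s_b = torch.randn(seq_len, num_heads, dtype=torch.float32, device="cuda")

v_merged, s_merged = merge_state(v_a, s_a, v_b, s_b)
print(v_merged.shape, s_merged.shape)  # torch.Size([16, 8, 128]) torch.Size([16, 8])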
sgl-kernel/tests/test_merge_state.py (new file, mode 100644)

# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/55576c626421b5ee7e7ebe74afd26465c8ae863f/flashinfer/triton/kernels/cascade.py
from typing import List

import pytest
import torch
import triton
import triton.language as tl
from sgl_kernel import merge_state


def check_input(x: torch.Tensor):
    assert x.is_cuda, f"{str(x)} must be a CUDA Tensor"
    assert x.is_contiguous(), f"{str(x)} must be contiguous"


def check_dim(d, x: torch.Tensor):
    assert x.dim() == d, f"{str(x)} must be a {d}D tensor"


def check_shape(a: torch.Tensor, b: torch.Tensor):
    assert a.dim() == b.dim(), "tensors should have same dim"
    for i in range(a.dim()):
        assert a.size(i) == b.size(i), f"tensors shape mismatch, {a.size()} and {b.size()}"


def check_device(tensors: List[torch.Tensor]):
    device = tensors[0].device
    for t in tensors:
        assert t.device == device, f"All tensors should be on the same device, but got {device} and {t.device}"


@triton.jit
def state_merge(o, m, d, other_o, other_m, other_d):
    m_max = tl.maximum(m, other_m)
    d = d * tl.exp2(m - m_max) + other_d * tl.exp2(other_m - m_max)
    o = o * tl.exp2(m - m_max) + other_o * tl.exp2(other_m - m_max)
    return o, m_max, d


@triton.jit
def state_normalize(o, m, d):
    o = o / d
    return o, m, d


@triton.jit
def state_get_lse(o, m, d):
    return m + tl.log2(d)


@triton.jit
def merge_state_kernel(
    v_a_ptr,
    s_a_ptr,
    v_b_ptr,
    s_b_ptr,
    v_merged_ptr,
    s_merged_ptr,
    num_heads,
    head_dim,
    bdx: tl.constexpr,
    bdy: tl.constexpr,
):
    pos = tl.program_id(axis=0)
    for tx in tl.range(bdx):
        for head_idx in tl.range(bdy):
            s_a_val = tl.load(s_a_ptr + pos * num_heads + head_idx)
            s_b_val = tl.load(s_b_ptr + pos * num_heads + head_idx)
            offsets = (pos * num_heads + head_idx) * head_dim + tx
            v_a = tl.load(v_a_ptr + offsets)
            v_b = tl.load(v_b_ptr + offsets)
            v_merged, s_max, d = state_merge(
                o=v_a, m=s_a_val, d=1, other_o=v_b, other_m=s_b_val, other_d=1
            )
            v_merged, s_max, d = state_normalize(v_merged, s_max, d)
            v_merged_offset = (pos * num_heads + head_idx) * head_dim + tx
            tl.store(v_merged_ptr + v_merged_offset, v_merged)
            if s_merged_ptr:
                tl.store(
                    s_merged_ptr + pos * num_heads + head_idx,
                    tl.log2(d) + s_max,
                )


def merge_state_triton(v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor):
    check_input(v_a)
    check_input(s_a)
    check_input(v_b)
    check_input(s_b)
    check_device([v_a, s_a, v_b, s_b])
    check_dim(3, v_a)
    check_dim(2, s_a)
    check_dim(3, v_b)
    check_dim(2, s_b)
    check_shape(v_a, v_b)
    check_shape(s_a, s_b)
    assert v_a.size(0) == s_a.size(0)
    assert v_a.size(1) == s_b.size(1)
    s_a = s_a.to(torch.float32)
    s_b = s_b.to(torch.float32)
    seq_len = v_a.size(0)
    num_heads = v_a.size(1)
    head_dim = v_a.size(2)
    v_merged = torch.empty_like(v_a).to(s_a.device)
    s_merged = torch.empty((seq_len, num_heads)).to(s_a.device)
    bdx = head_dim
    bdy = num_heads

    merge_state_kernel[lambda meta: (seq_len,)](
        v_a, s_a, v_b, s_b, v_merged, s_merged, num_heads, head_dim, bdx=bdx, bdy=bdy
    )

    return v_merged, s_merged


@pytest.mark.parametrize("seq_len", [2048])
@pytest.mark.parametrize("num_heads", [32])
@pytest.mark.parametrize("head_dim", [128])
def test_merge_state(seq_len, num_heads, head_dim):
    va = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
    sa = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
    vb = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
    sb = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
    v_merged, s_merged = merge_state_triton(va, sa, vb, sb)
    v_merged_std, s_merged_std = merge_state(va, sa, vb, sb)

    assert torch.allclose(v_merged, v_merged_std, atol=1e-2)
    assert torch.allclose(s_merged, s_merged_std, atol=1e-2)
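The test checks the CUDA op against the Triton reference with an absolute tolerance of 1e-2. Assuming a CUDA-capable GPU, Triton, pytest, and an installed sgl-kernel build, it can be run on its own with: python3 -m pytest sgl-kernel/tests/test_merge_state.py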