sglang / Commits / d297cda2

Commit d297cda2, authored Nov 10, 2025 by lizhigong

add merge_state_v2 in sgl_kernel

Parent: 63c8d8d0

Showing 4 changed files with 22 additions and 1 deletion:

  python/sglang/srt/models/deepseek_v2.py          +1   -0
  sgl-kernel/csrc/attention/merge_attn_states.cu   +14  -1
  sgl-kernel/csrc/common_extension_rocm.cc         +6   -0
  sgl-kernel/setup_hip.py                          +1   -0
python/sglang/srt/models/deepseek_v2.py

@@ -185,6 +185,7 @@ elif _is_hip:
     from sglang.srt.layers.quantization.awq_triton import (
         awq_dequantize_triton as awq_dequantize,
     )
+    from sgl_kernel import merge_state_v2
 elif _is_npu:
     import custom_ops  # noqa: F401
     import sgl_kernel_npu  # noqa: F401
sgl-kernel/csrc/attention/merge_attn_states.cu

@@ -4,7 +4,7 @@
 #include <algorithm>
 #include <optional>
-#include "pytorch_extension_utils.h"
+#include "pytorch_extension_utils_rocm.h"

 // Helper functions to convert between different data types
 // (float, half, bfloat16) for the merge attention states kernel.

@@ -27,6 +27,19 @@ inline __device__ void from_float(__nv_bfloat16& d, float s) {
   d = __float2bfloat16(s);
 }

+inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name, const char* b_name) {
+  TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", a.dim(), " vs ", b.dim());
+  for (int i = 0; i < a.dim(); ++i) {
+    TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, ".size(", i, ")");
+  }
+}
+
+#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
+
+#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
+
+#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
+
 // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
 template <typename scalar_t, const uint NUM_THREADS>
 __global__ void merge_attn_states_kernel(
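The kernel guarded by these checks merges two sets of partial attention outputs using their per-head log-sum-exp values, following section 2.2 of the paper linked in the code comment. The snippet below is a minimal PyTorch reference sketch of that merge rule, for orientation only; the function name and the shapes (v: [num_tokens, num_heads, head_dim], s: [num_tokens, num_heads]) are assumptions for illustration, not the kernel's actual interface.

    # Reference sketch only (assumed shapes; not the CUDA kernel's code).
    # v_a, v_b: [num_tokens, num_heads, head_dim]  partial attention outputs
    # s_a, s_b: [num_tokens, num_heads]            log-sum-exp of each partial softmax
    import torch

    def merge_state_reference(v_a, s_a, v_b, s_b):
        s_max = torch.maximum(s_a, s_b)
        w_a = torch.exp(s_a - s_max)  # rescale factor for the first partial, computed stably
        w_b = torch.exp(s_b - s_max)  # rescale factor for the second partial
        denom = w_a + w_b
        v_merged = (w_a.unsqueeze(-1) * v_a + w_b.unsqueeze(-1) * v_b) / denom.unsqueeze(-1)
        s_merged = s_max + torch.log(denom)  # merged log-sum-exp
        return v_merged, s_merged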
sgl-kernel/csrc/common_extension_rocm.cc

@@ -34,6 +34,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   m.def("gelu_quick(Tensor! out, Tensor input) -> ()");
   m.impl("gelu_quick", torch::kCUDA, &gelu_quick);

+  /*
+   * From csrc/attention
+   */
+  m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
+  m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
+
   /*
    * From csrc/allreduce
    */
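For context, the schema registered above can be exercised from Python once the extension is built. The call below is a hypothetical sketch: it assumes the library macro expands to the sgl_kernel namespace (so the op is reachable as torch.ops.sgl_kernel.merge_state_v2) and that the caller pre-allocates the two outputs, as the mutable Tensor! arguments in the schema suggest; the shapes and dtypes are illustrative only.

    # Hypothetical usage sketch; shapes, dtypes, and the torch.ops path are assumptions.
    import torch

    num_tokens, num_heads, head_dim = 8, 16, 128
    v_a = torch.randn(num_tokens, num_heads, head_dim, device="cuda", dtype=torch.half)
    v_b = torch.randn_like(v_a)
    s_a = torch.randn(num_tokens, num_heads, device="cuda", dtype=torch.float32)
    s_b = torch.randn_like(s_a)
    v_merged = torch.empty_like(v_a)  # outputs are written in place (Tensor! in the schema)
    s_merged = torch.empty_like(s_a)

    torch.ops.sgl_kernel.merge_state_v2(v_a, s_a, v_b, s_b, v_merged, s_merged)

The Python-level merge_state_v2 imported in deepseek_v2.py presumably wraps this registered op, but its exact signature is not shown in this diff.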
sgl-kernel/setup_hip.py

@@ -50,6 +50,7 @@ sources = [
     "csrc/moe/moe_topk_softmax_kernels.cu",
     "csrc/speculative/eagle_utils.cu",
     "csrc/kvcacheio/transfer.cu",
+    "csrc/attention/merge_attn_states.cu",
 ]

 cxx_flags = ["-O3"]