sglang · Commit 04b35190 (unverified)

Add dsv3 fused a gemm to sgl-kernel (#7630)

Authored Jun 29, 2025 by Ke Bao; committed via GitHub on Jun 29, 2025
Parent: 071a1f51

Showing 9 changed files with 800 additions and 0 deletions (+800 −0)
sgl-kernel/CMakeLists.txt                           +1   −0
sgl-kernel/benchmark/bench_dsv3_fused_a_gemm.py     +57  −0
sgl-kernel/csrc/common_extension.cc                 +3   −0
sgl-kernel/csrc/gemm/dsv3_fused_a_gemm.cu           +672 −0
sgl-kernel/include/sgl_kernel_ops.h                 +2   −0
sgl-kernel/include/utils.h                          +17  −0
sgl-kernel/python/sgl_kernel/__init__.py            +1   −0
sgl-kernel/python/sgl_kernel/gemm.py                +15  −0
sgl-kernel/tests/test_dsv3_fused_a_gemm.py          +32  −0
sgl-kernel/CMakeLists.txt

@@ -221,6 +221,7 @@ set(SOURCES
    "csrc/elementwise/rope.cu"
    "csrc/gemm/awq_kernel.cu"
    "csrc/gemm/bmm_fp8.cu"
    "csrc/gemm/dsv3_fused_a_gemm.cu"
    "csrc/gemm/fp8_blockwise_gemm_kernel.cu"
    "csrc/gemm/fp8_gemm_kernel.cu"
    "csrc/gemm/int8_gemm_kernel.cu"
sgl-kernel/benchmark/bench_dsv3_fused_a_gemm.py (new file, mode 100644)

import argparse

import torch
import torch.nn.functional as F
import triton
import triton.testing

from sgl_kernel import dsv3_fused_a_gemm


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=[i + 1 for i in range(16)],
        x_log=False,
        line_arg="impl",
        line_vals=["torch", "sgl-kernel"],
        line_names=["torch (bf16)", "dsv3_fused_a_gemm"],
        styles=[("blue", "-"), ("orange", "-")],
        ylabel="TFLOPs",
        plot_name="bf16 dsv3 fused a GEMM throughput",
        args={},
    )
)
def benchmark(num_tokens, impl):
    kHdIn = 7168
    kHdOut = 2112
    M, K, N = num_tokens, kHdIn, kHdOut

    mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
    mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").transpose(0, 1)

    quantiles = [0.5, 0.2, 0.8]

    if impl == "torch":

        def runner():
            F.linear(mat_a, mat_b.T)

    elif impl == "sgl-kernel":

        def runner():
            dsv3_fused_a_gemm(mat_a, mat_b)

    ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)

    def tflops(t_ms):
        flops = 2 * M * K * N
        return flops / (t_ms * 1e-3) / 1e12

    return tflops(ms), tflops(max_ms), tflops(min_ms)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    benchmark.run(print_data=True, show_plots=True, save_path="bench_dsv3_gemm")
sgl-kernel/csrc/common_extension.cc

@@ -141,6 +141,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
      " Tensor! output_scale, Tensor! input_scale) -> ()");
  m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);

  m.def("dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
  m.impl("dsv3_fused_a_gemm", torch::kCUDA, &dsv3_fused_a_gemm);

  // Compute NVFP4 experts quantization.
  m.def(
      "scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
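For reference, a minimal sketch (not part of this diff) of how the registered schema can be invoked directly from Python. The `Tensor! output` annotation marks the first argument as mutated in place, and the `-> ()` return type means the op writes its result into that pre-allocated tensor rather than returning one; the `dsv3_fused_a_gemm` wrapper added to `sgl_kernel/gemm.py` below does exactly this.

import torch
import sgl_kernel  # noqa: F401  -- importing the package registers the sgl_kernel torch library

M, K, N = 4, 7168, 2112  # DSV3 "A" GEMM shapes used by the benchmark and test in this commit
mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda")
mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").transpose(0, 1)
output = torch.empty((M, N), dtype=torch.bfloat16, device="cuda")

# The op fills `output` in place and returns nothing, matching the "-> ()" schema.
torch.ops.sgl_kernel.dsv3_fused_a_gemm.default(output, mat_a, mat_b)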
sgl-kernel/csrc/gemm/dsv3_fused_a_gemm.cu (new file, mode 100644; +672 −0)

This diff is collapsed in the viewer and its contents are not shown here.
sgl-kernel/include/sgl_kernel_ops.h

@@ -201,6 +201,8 @@ void bmm_fp8(
    int64_t cublas_handle,
    int64_t cuda_stream);

void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);

/*
 * From csrc/moe
 */
sgl-kernel/include/utils.h

@@ -241,6 +241,23 @@ inline int getSMVersion() {
  return sm_major * 10 + sm_minor;
}

inline bool getBoolEnv(char const* name) {
  char const* env = std::getenv(name);
  return env && env[0] == '1' && env[1] == '\0';
}

inline bool getEnvEnablePDL() {
  static std::once_flag flag;
  static bool enablePDL = false;
  std::call_once(flag, [&]() {
    if (getSMVersion() >= 90) {
      // PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1`
      enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL");
    }
  });
  return enablePDL;
}

// SGLANG_SHFL_XOR_* adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/cuda_compat.h#L19-L28
#ifndef USE_ROCM
#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor_sync((mask), (var), (lane_mask))
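A side note on the utils.h additions: getEnvEnablePDL() latches, on first use, whether programmatic dependent launch (PDL) should be enabled, and only on SM90+ GPUs. Whether the new dsv3 kernel itself consults this helper is not visible here (its diff is collapsed above), but assuming it does, PDL would be opted into from Python roughly as in the sketch below; the flag must be set before the first kernel launch because the result is cached by std::call_once.

import os

# Hypothetical opt-in: set the environment variable before any sgl-kernel GEMM runs,
# since getEnvEnablePDL() caches its answer the first time it is called.
os.environ["TRTLLM_ENABLE_PDL"] = "1"

import torch
from sgl_kernel import dsv3_fused_a_gemm
# Subsequent dsv3_fused_a_gemm calls may now use programmatic dependent launch on SM90+.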
sgl-kernel/python/sgl_kernel/__init__.py

@@ -33,6 +33,7 @@ from sgl_kernel.gemm import (
    awq_dequantize,
    bmm_fp8,
    cutlass_scaled_fp4_mm,
    dsv3_fused_a_gemm,
    fp8_blockwise_scaled_mm,
    fp8_scaled_mm,
    int8_scaled_mm,
sgl-kernel/python/sgl_kernel/gemm.py

@@ -82,6 +82,21 @@ def bmm_fp8(
    return out


def dsv3_fused_a_gemm(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if output is None:
        output = torch.empty(
            (mat_a.shape[0], mat_b.shape[1]),
            device=mat_a.device,
            dtype=mat_a.dtype,
        )
    torch.ops.sgl_kernel.dsv3_fused_a_gemm.default(output, mat_a, mat_b)
    return output


def sgl_per_token_group_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
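A minimal usage sketch of the new public wrapper, using the bf16 DSV3 shapes exercised by the benchmark and test in this commit (7168 input features, 2112 output features). Note that mat_b is created as an (N, K) weight and passed as its transpose, so that mat_b.shape[1] gives the output width the wrapper uses when allocating the result.

import torch
from sgl_kernel import dsv3_fused_a_gemm

num_tokens = 8  # small decode-style batch, matching the benchmark's x_vals range
mat_a = torch.randn((num_tokens, 7168), dtype=torch.bfloat16, device="cuda")
mat_b = torch.randn((2112, 7168), dtype=torch.bfloat16, device="cuda").transpose(0, 1)

out = dsv3_fused_a_gemm(mat_a, mat_b)  # output tensor is allocated internally when not provided
print(out.shape)  # torch.Size([8, 2112])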
sgl-kernel/tests/test_dsv3_fused_a_gemm.py (new file, mode 100644)

import pytest
import torch
import torch.nn.functional as F

from sgl_kernel import dsv3_fused_a_gemm


@pytest.mark.parametrize("num_tokens", [i + 1 for i in range(16)])
def test_dsv3_fused_a_gemm(num_tokens):
    kHdIn = 7168
    kHdOut = 2112

    mat_a = torch.randn(
        (num_tokens, kHdIn), dtype=torch.bfloat16, device="cuda"
    ).contiguous()
    mat_b = torch.randn(
        (kHdOut, kHdIn), dtype=torch.bfloat16, device="cuda"
    ).transpose(0, 1)
    output = torch.empty(
        (num_tokens, kHdOut), dtype=torch.bfloat16, device="cuda"
    ).contiguous()

    ref = F.linear(mat_a, mat_b.T)
    output = dsv3_fused_a_gemm(mat_a, mat_b)

    assert torch.allclose(
        output, ref, rtol=1e-2, atol=1e-3
    ), "Fused GEMM output mismatch with torch.nn.functional.linear reference"


if __name__ == "__main__":
    pytest.main([__file__])