zhaoyu6 / sglang
Commit 04b35190 (unverified)
Authored Jun 29, 2025 by Ke Bao; committed by GitHub on Jun 29, 2025

Add dsv3 fused a gemm to sgl-kernel (#7630)

Parent: 071a1f51
Showing 9 changed files with 800 additions and 0 deletions (+800, -0):
sgl-kernel/CMakeLists.txt (+1, -0)
sgl-kernel/benchmark/bench_dsv3_fused_a_gemm.py (+57, -0)
sgl-kernel/csrc/common_extension.cc (+3, -0)
sgl-kernel/csrc/gemm/dsv3_fused_a_gemm.cu (+672, -0)
sgl-kernel/include/sgl_kernel_ops.h (+2, -0)
sgl-kernel/include/utils.h (+17, -0)
sgl-kernel/python/sgl_kernel/__init__.py (+1, -0)
sgl-kernel/python/sgl_kernel/gemm.py (+15, -0)
sgl-kernel/tests/test_dsv3_fused_a_gemm.py (+32, -0)
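At a glance, everything above is plumbing for a single new Python entry point, dsv3_fused_a_gemm, exported from sgl_kernel. A minimal usage sketch, assuming a CUDA device and the bf16 shapes used by the benchmark and test in this commit (7168 input features, 2112 output features, small token counts); the tensor names are illustrative, not part of the API:

import torch
from sgl_kernel import dsv3_fused_a_gemm

# Shapes mirror the benchmark/test below: a few tokens, hidden 7168 -> 2112, bf16 on CUDA.
mat_a = torch.randn((8, 7168), dtype=torch.bfloat16, device="cuda")
# The weight is passed as a (K, N) view, i.e. an (N, K) tensor transposed.
mat_b = torch.randn((2112, 7168), dtype=torch.bfloat16, device="cuda").transpose(0, 1)

out = dsv3_fused_a_gemm(mat_a, mat_b)  # bf16 tensor of shape (8, 2112)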
sgl-kernel/CMakeLists.txt
@@ -221,6 +221,7 @@ set(SOURCES
     "csrc/elementwise/rope.cu"
     "csrc/gemm/awq_kernel.cu"
     "csrc/gemm/bmm_fp8.cu"
+    "csrc/gemm/dsv3_fused_a_gemm.cu"
     "csrc/gemm/fp8_blockwise_gemm_kernel.cu"
     "csrc/gemm/fp8_gemm_kernel.cu"
     "csrc/gemm/int8_gemm_kernel.cu"
sgl-kernel/benchmark/bench_dsv3_fused_a_gemm.py
new file (mode 100644)

import argparse

import torch
import torch.nn.functional as F
import triton
import triton.testing

from sgl_kernel import dsv3_fused_a_gemm


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=[i + 1 for i in range(16)],
        x_log=False,
        line_arg="impl",
        line_vals=["torch", "sgl-kernel"],
        line_names=["torch (bf16)", "dsv3_fused_a_gemm"],
        styles=[("blue", "-"), ("orange", "-")],
        ylabel="TFLOPs",
        plot_name="bf16 dsv3 fused a GEMM throughput",
        args={},
    )
)
def benchmark(num_tokens, impl):
    kHdIn = 7168
    kHdOut = 2112
    M, K, N = num_tokens, kHdIn, kHdOut

    mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
    mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").transpose(0, 1)

    quantiles = [0.5, 0.2, 0.8]

    if impl == "torch":

        def runner():
            F.linear(mat_a, mat_b.T)

    elif impl == "sgl-kernel":

        def runner():
            dsv3_fused_a_gemm(mat_a, mat_b)

    ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)

    def tflops(t_ms):
        flops = 2 * M * K * N
        return flops / (t_ms * 1e-3) / 1e12

    return tflops(ms), tflops(max_ms), tflops(min_ms)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    args = parser.parse_args()

    benchmark.run(print_data=True, show_plots=True, save_path="bench_dsv3_gemm")
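To put the numbers the tflops helper reports into perspective: at the largest benchmarked size (16 tokens) one GEMM performs 2*M*K*N = 2*16*7168*2112, roughly 4.8e8 FLOPs, so the kernel only reaches tens of TFLOPs if each call finishes within a few microseconds. A small sketch of that arithmetic; the 10 microsecond timing is an illustrative assumption, not a measured result:

# Worked example of the benchmark's tflops() formula.
M, K, N = 16, 7168, 2112
t_ms = 0.01                        # hypothetical do_bench result: 10 us per call
flops = 2 * M * K * N              # 484,442,112 FLOPs per GEMM
tflops = flops / (t_ms * 1e-3) / 1e12
print(f"{tflops:.1f} TFLOPs")      # ~48.4 with the assumed timing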
sgl-kernel/csrc/common_extension.cc
@@ -141,6 +141,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       " Tensor! output_scale, Tensor! input_scale) -> ()");
   m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
 
+  m.def("dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
+  m.impl("dsv3_fused_a_gemm", torch::kCUDA, &dsv3_fused_a_gemm);
+
   // Compute NVFP4 experts quantization.
   m.def(
       "scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
sgl-kernel/csrc/gemm/dsv3_fused_a_gemm.cu
new file (mode 100644)
This diff is collapsed in the web view; the 672-line kernel source is not shown here.
sgl-kernel/include/sgl_kernel_ops.h
@@ -201,6 +201,8 @@ void bmm_fp8(
     int64_t cublas_handle,
     int64_t cuda_stream);
 
+void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);
+
 /*
  * From csrc/moe
  */
sgl-kernel/include/utils.h
@@ -241,6 +241,23 @@ inline int getSMVersion() {
   return sm_major * 10 + sm_minor;
 }
 
+inline bool getBoolEnv(char const* name) {
+  char const* env = std::getenv(name);
+  return env && env[0] == '1' && env[1] == '\0';
+}
+
+inline bool getEnvEnablePDL() {
+  static std::once_flag flag;
+  static bool enablePDL = false;
+  std::call_once(flag, [&]() {
+    if (getSMVersion() >= 90) {
+      // PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1`
+      enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL");
+    }
+  });
+  return enablePDL;
+}
+
 // SGLANG_SHFL_XOR_* adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/cuda_compat.h#L19-L28
 #ifndef USE_ROCM
 #define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor_sync((mask), (var), (lane_mask))
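The getEnvEnablePDL helper above reads TRTLLM_ENABLE_PDL exactly once (guarded by std::call_once) and only honors it on SM90+ GPUs, so PDL has to be opted into before the first kernel that checks the flag runs. A minimal sketch of enabling it from the Python side; exporting the variable in the shell before launching the process works just as well:

import os

# getBoolEnv() accepts only the exact string "1"; any other value leaves PDL disabled.
os.environ["TRTLLM_ENABLE_PDL"] = "1"

# ... import sgl_kernel and run ops afterwards; on SM < 90 the flag is simply ignored.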
sgl-kernel/python/sgl_kernel/__init__.py
@@ -33,6 +33,7 @@ from sgl_kernel.gemm import (
     awq_dequantize,
     bmm_fp8,
     cutlass_scaled_fp4_mm,
+    dsv3_fused_a_gemm,
     fp8_blockwise_scaled_mm,
     fp8_scaled_mm,
     int8_scaled_mm,
sgl-kernel/python/sgl_kernel/gemm.py
@@ -82,6 +82,21 @@ def bmm_fp8(
     return out
 
 
+def dsv3_fused_a_gemm(
+    mat_a: torch.Tensor,
+    mat_b: torch.Tensor,
+    output: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    if output is None:
+        output = torch.empty(
+            (mat_a.shape[0], mat_b.shape[1]),
+            device=mat_a.device,
+            dtype=mat_a.dtype,
+        )
+    torch.ops.sgl_kernel.dsv3_fused_a_gemm.default(output, mat_a, mat_b)
+    return output
+
+
 def sgl_per_token_group_quant_fp8(
     input: torch.Tensor,
     output_q: torch.Tensor,
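A short usage sketch of the wrapper above. When output is omitted, the wrapper allocates a tensor of shape (mat_a.shape[0], mat_b.shape[1]) with mat_a's dtype and device; passing a preallocated buffer skips that allocation, since the underlying op fills output in place (its schema is "dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"). Tensor names and shapes here are illustrative:

import torch
from sgl_kernel import dsv3_fused_a_gemm

mat_a = torch.randn((4, 7168), dtype=torch.bfloat16, device="cuda")
mat_b = torch.randn((2112, 7168), dtype=torch.bfloat16, device="cuda").transpose(0, 1)

# Default path: the wrapper allocates the (4, 2112) bf16 output itself.
out = dsv3_fused_a_gemm(mat_a, mat_b)

# Reusing a preallocated buffer (e.g. inside a decode loop) avoids the per-call allocation.
buf = torch.empty((4, 2112), dtype=torch.bfloat16, device="cuda")
out = dsv3_fused_a_gemm(mat_a, mat_b, output=buf)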
sgl-kernel/tests/test_dsv3_fused_a_gemm.py
new file (mode 100644)

import pytest
import torch
import torch.nn.functional as F

from sgl_kernel import dsv3_fused_a_gemm


@pytest.mark.parametrize("num_tokens", [i + 1 for i in range(16)])
def test_dsv3_fused_a_gemm(num_tokens):
    kHdIn = 7168
    kHdOut = 2112

    mat_a = torch.randn(
        (num_tokens, kHdIn), dtype=torch.bfloat16, device="cuda"
    ).contiguous()
    mat_b = torch.randn((kHdOut, kHdIn), dtype=torch.bfloat16, device="cuda").transpose(0, 1)
    output = torch.empty(
        (num_tokens, kHdOut), dtype=torch.bfloat16, device="cuda"
    ).contiguous()

    ref = F.linear(mat_a, mat_b.T)
    output = dsv3_fused_a_gemm(mat_a, mat_b)

    assert torch.allclose(
        output, ref, rtol=1e-2, atol=1e-3
    ), "Fused GEMM output mismatch with torch.nn.functional.linear reference"


if __name__ == "__main__":
    pytest.main([__file__])