Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
5c9c275b
Unverified
Commit
5c9c275b
authored
Jul 27, 2025
by
Elfie Guo
Committed by
GitHub
Jul 27, 2025
Browse files
Use FlashInfer FP4 gemm. (#8241)
parent
bf0f448f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
230 additions
and
5 deletions
+230
-5
python/sglang/srt/layers/quantization/modelopt_quant.py
python/sglang/srt/layers/quantization/modelopt_quant.py
+20
-5
sgl-kernel/benchmark/bench_fp4_gemm.py
sgl-kernel/benchmark/bench_fp4_gemm.py
+210
-0
No files found.
python/sglang/srt/layers/quantization/modelopt_quant.py
100644 → 100755
View file @
5c9c275b
...
...
@@ -35,10 +35,20 @@ if TYPE_CHECKING:
from
sglang.srt.layers.moe.topk
import
TopKOutput
if
is_cuda
():
from
sgl_kernel
import
cutlass_scaled_fp4_mm
,
scaled_fp4_quant
from
sgl_kernel
import
scaled_fp4_quant
try
:
from
flashinfer
import
mm_fp4
as
fp4_gemm
enable_flashinfer_fp4_gemm
=
True
except
ImportError
:
if
is_cuda
():
from
sgl_kernel
import
cutlass_scaled_fp4_mm
as
fp4_gemm
else
:
fp4_gemm
=
None
enable_flashinfer_fp4_gemm
=
False
try
:
from
flashinfer
import
fp4_quantize
as
fp4_quantize
from
flashinfer.fused_moe
import
cutlass_fused_moe
as
flashinfer_cutlass_fused_moe
except
ImportError
:
flashinfer_cutlass_fused_moe
=
None
...
...
@@ -683,11 +693,16 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
assert
layer
.
weight_scale_interleaved
.
dtype
==
torch
.
float8_e4m3fn
assert
layer
.
alpha
.
dtype
==
torch
.
float32
out
=
cutlass_scaled_fp4_mm
(
w
=
layer
.
weight
w_scale_interleaved
=
layer
.
weight_scale_interleaved
if
enable_flashinfer_fp4_gemm
:
w
=
layer
.
weight
.
T
w_scale_interleaved
=
layer
.
weight_scale_interleaved
.
T
out
=
fp4_gemm
(
x_fp4
,
layer
.
weight
,
w
,
x_scale_interleaved
,
layer
.
weight
_scale_interleaved
,
w
_scale_interleaved
,
layer
.
alpha
,
output_dtype
,
)
...
...
sgl-kernel/benchmark/bench_fp4_gemm.py
0 → 100755
View file @
5c9c275b
import
argparse
import
copy
import
csv
import
itertools
import
pytest
import
torch
import
triton
from
flashinfer
import
mm_fp4
from
sgl_kernel
import
cutlass_scaled_fp4_mm
,
scaled_fp4_quant
# Largest magnitude representable in FP4 E2M1 (used to derive per-tensor scales).
FLOAT4_E2M1_MAX = 6.0
# Largest finite value representable in float8_e4m3fn, queried from torch.
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
def get_weight_shapes(args):
    """Return the list of [N, K] weight shapes to benchmark.

    The shape sets correspond to DeepSeek-R1-style projection weights
    (see the model name printed by the driver below) sharded at
    tensor-parallel size 4 or 8. Any other ``--tp-sizes`` value falls
    back to the union of both sets.
    """
    shapes_by_tp = {
        (4,): [[1024, 3584], [7168, 256], [7168, 2304], [9216, 3584]],
        (8,): [[512, 3584], [7168, 128], [7168, 1152], [4608, 3584]],
    }
    requested = tuple(args.tp_sizes)
    if requested in shapes_by_tp:
        return shapes_by_tp[requested]
    # Unrecognized TP configuration: benchmark every known shape.
    return shapes_by_tp[(4,)] + shapes_by_tp[(8,)]
@triton.testing.perf_report(
    triton.testing.Benchmark(
        # Sweep the M (batch) dimension along the x-axis.
        x_names=["batch_size"],
        x_vals=[
            1,
            2,
            4,
            8,
            16,
            32,
            64,
            128,
            256,
            512,
            1024,
            2048,
            3072,
            4096,
            8192,
            16384,
        ],
        # x_vals = [64],
        x_log=False,
        # One benchmark line per FP4 GEMM backend.
        line_arg="provider",
        line_vals=["cutlass", "cudnn", "trtllm"],
        line_names=["baseline cutlass fp4", "cudnn fp4", "trtllm fp4"],
        styles=[("red", "solid"), ("blue", "solid"), ("green", "solid")],
        ylabel="latency (ms)",
        plot_name="fp4_gemm_benchmark",
        args={},
    )
)
def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
    """Time one FP4 GEMM backend on an (M, K) x (N, K)^T problem.

    Args:
        batch_size: the M dimension, injected per x_vals by perf_report.
        provider: one of "cutlass", "cudnn", "trtllm" (see line_vals).
        N, K: weight shape from get_weight_shapes(); the incoming K is
            doubled below before generating data — presumably it counts
            packed FP4 pairs (confirm against the shape tables).
        dtype: high-precision torch dtype used for inputs and output.
        correctness: when True, cross-check cudnn and trtllm outputs
            against the cutlass result.
        csv_file: when truthy, append a [provider, M, N, K, ms] row.

    Returns:
        (median_ms, min_ms, max_ms) from triton.testing.do_bench_cudagraph.
        NOTE(review): if provider ever falls outside line_vals, ``ms`` is
        unbound and the return raises NameError.
    """
    M = batch_size
    # Incoming K is treated as the packed size; the logical reduction
    # dimension is twice that (two FP4 values per packed element) —
    # assumption inferred from the 2x factor, confirm.
    packed_k = K
    K = 2 * packed_k
    # High-precision reference operands; quantized to FP4 below.
    a_dtype = torch.randn((M, K), dtype=dtype, device="cuda")
    b_dtype = torch.randn((N, K), dtype=dtype, device="cuda")
    # Per-tensor global scales mapping each operand's max magnitude into
    # the representable FP8-scale x FP4-value range.
    a_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
    ).to(torch.float32)
    b_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
    ).to(torch.float32)
    # Combined dequantization factor applied by the GEMM epilogue.
    alpha = 1.0 / (a_global_scale * b_global_scale)
    a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)
    # print("a_fp4", a_fp4)
    b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)
    # Preallocated output buffer reused by every flashinfer mm_fp4 call.
    res_fi = torch.empty((M, N), dtype=dtype, device="cuda")
    quantiles = [0.5, 0.2, 0.8]
    if provider == "cutlass":
        # Baseline: sgl_kernel cutlass GEMM (allocates its own output).
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: cutlass_scaled_fp4_mm(
                a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
            ),
            quantiles=quantiles,
        )
    if provider == "cudnn":
        # flashinfer mm_fp4 default backend; expects transposed B/scales.
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: mm_fp4(
                a_fp4,
                b_fp4.T,
                a_scale_interleaved,
                b_scale_interleaved.T,
                alpha,
                dtype,
                res_fi,
            ),
            quantiles=quantiles,
        )
    if provider == "trtllm":
        # The trtllm backend takes scales as uint8 instead of float8.
        # NOTE(review): .to(torch.uint8) is a value cast, not a bitcast —
        # confirm this is what the backend expects.
        a_scale_interleaved = a_scale_interleaved.to(torch.uint8)
        b_scale_interleaved = b_scale_interleaved.to(torch.uint8)
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: mm_fp4(
                a_fp4,
                b_fp4.T,
                a_scale_interleaved,
                b_scale_interleaved.T,
                alpha,
                dtype,
                res_fi,
                backend="trtllm",
            ),
            quantiles=quantiles,
        )
    if correctness:
        # NOTE(review): when provider == "trtllm", the scales were already
        # re-cast to uint8 above, so this cutlass reference receives uint8
        # scales — verify the reference is still valid in that path.
        res_cutlass = cutlass_scaled_fp4_mm(
            a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
        )
        # mm_fp4 writes its result into res_fi in place.
        mm_fp4(
            a_fp4,
            b_fp4.T,
            a_scale_interleaved,
            b_scale_interleaved.T,
            alpha,
            dtype,
            res_fi,
            backend="cudnn",
        )
        assert torch.allclose(
            res_fi, res_cutlass, atol=1e-3, rtol=1e-3
        ), "cudnn fp4 doesn't match cutlass fp4"
        mm_fp4(
            a_fp4,
            b_fp4.T,
            a_scale_interleaved,
            b_scale_interleaved.T,
            alpha,
            dtype,
            res_fi,
            backend="trtllm",
        )
        assert torch.allclose(
            res_fi, res_cutlass, atol=1e-3, rtol=1e-3
        ), "trtllm fp4 doesn't match cutlass fp4"
    if csv_file:
        # Append-mode: the __main__ driver writes the header row once.
        with open(csv_file, "a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([provider, M, N, K, ms])
    return ms, min_ms, max_ms
def parse_dtype(name):
    """Convert a CLI string like 'bfloat16' or 'torch.bfloat16' to a torch.dtype.

    The original code passed ``type=torch.dtype`` to argparse, but
    ``torch.dtype`` cannot be called with a string, so any explicit
    ``--dtype`` argument raised TypeError; only the default ever worked.

    Raises:
        argparse.ArgumentTypeError: if *name* does not name a torch dtype.
    """
    key = name[len("torch."):] if name.startswith("torch.") else name
    dtype = getattr(torch, key, None)
    if not isinstance(dtype, torch.dtype):
        raise argparse.ArgumentTypeError(f"unknown torch dtype: {name!r}")
    return dtype


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    parser.add_argument(
        "--dtype",
        type=parse_dtype,
        default=torch.bfloat16,
        help="Data type",
    )
    parser.add_argument(
        "--correctness",
        action="store_true",
        help="Check correctness",
    )
    parser.add_argument(
        "--csv",
        type=str,
        default="results_cutlass_cudnn.csv",
        help="CSV file to save results",
    )
    args = parser.parse_args()

    if args.csv:
        # Truncate and write the header once; benchmark() appends rows.
        with open(args.csv, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["provider", "m", "n", "k", "time_ms"])

    # One benchmark sweep (over batch sizes and providers) per weight shape.
    NKs = get_weight_shapes(args)
    for N, K in NKs:
        print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path="bench_fp4_res",
            N=N,
            K=K,
            dtype=args.dtype,
            correctness=args.correctness,
            csv_file=args.csv,
        )

    print("Benchmark finished!")
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment