zhaoyu6 / sglang · Commit 5c9c275b (unverified)

Use FlashInfer FP4 gemm. (#8241)

Authored Jul 27, 2025 by Elfie Guo; committed via GitHub on Jul 27, 2025.
Parent: bf0f448f

2 changed files, with 230 additions and 5 deletions:

  python/sglang/srt/layers/quantization/modelopt_quant.py  (+20, −5)
  sgl-kernel/benchmark/bench_fp4_gemm.py                    (+210, −0)
python/sglang/srt/layers/quantization/modelopt_quant.py (mode changed: 100644 → 100755)
...
```diff
@@ -35,10 +35,20 @@ if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput
 
 if is_cuda():
-    from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
+    from sgl_kernel import scaled_fp4_quant
+
+try:
+    from flashinfer import mm_fp4 as fp4_gemm
+
+    enable_flashinfer_fp4_gemm = True
+except ImportError:
+    if is_cuda():
+        from sgl_kernel import cutlass_scaled_fp4_mm as fp4_gemm
+    else:
+        fp4_gemm = None
+    enable_flashinfer_fp4_gemm = False
 
 try:
+    from flashinfer import fp4_quantize as fp4_quantize
     from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
 except ImportError:
     flashinfer_cutlass_fused_moe = None
```
...
```diff
@@ -683,11 +693,16 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn
         assert layer.alpha.dtype == torch.float32
 
-        out = cutlass_scaled_fp4_mm(
+        w = layer.weight
+        w_scale_interleaved = layer.weight_scale_interleaved
+        if enable_flashinfer_fp4_gemm:
+            w = layer.weight.T
+            w_scale_interleaved = layer.weight_scale_interleaved.T
+        out = fp4_gemm(
             x_fp4,
-            layer.weight,
+            w,
             x_scale_interleaved,
-            layer.weight_scale_interleaved,
+            w_scale_interleaved,
             layer.alpha,
             output_dtype,
         )
```
...
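The only call-site difference between the two backends is layout: the CUTLASS kernel takes the weight and its block scale as stored, while FlashInfer's `mm_fp4` is called with both transposed, which is what the `if enable_flashinfer_fp4_gemm` branch handles. A hypothetical helper packaging the same branch (the names mirror the diff; the helper itself is not in the commit):

```python
def fp4_gemm_unified(x_fp4, x_scale, weight, weight_scale, alpha, out_dtype):
    """Hypothetical wrapper around the dispatch in the hunk above.

    FlashInfer's mm_fp4 expects the weight and its interleaved block scale
    transposed relative to sgl_kernel's cutlass_scaled_fp4_mm, so they are
    flipped only when the FlashInfer backend was selected at import time.
    Assumes fp4_gemm and enable_flashinfer_fp4_gemm come from the import
    block in the first hunk.
    """
    if enable_flashinfer_fp4_gemm:
        weight, weight_scale = weight.T, weight_scale.T
    return fp4_gemm(x_fp4, weight, x_scale, weight_scale, alpha, out_dtype)
```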
sgl-kernel/benchmark/bench_fp4_gemm.py (new file, mode 100755)
```python
import argparse
import copy
import csv
import itertools

import pytest
import torch
import triton
from flashinfer import mm_fp4
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max


def get_weight_shapes(args):
    models_tps = args.tp_sizes
    if models_tps == [4]:
        return [[1024, 3584], [7168, 256], [7168, 2304], [9216, 3584]]
    if models_tps == [8]:
        return [[512, 3584], [7168, 128], [7168, 1152], [4608, 3584]]
    return [
        [1024, 3584],
        [7168, 256],
        [7168, 2304],
        [9216, 3584],
        [512, 3584],
        [7168, 128],
        [7168, 1152],
        [4608, 3584],
    ]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[
            1,
            2,
            4,
            8,
            16,
            32,
            64,
            128,
            256,
            512,
            1024,
            2048,
            3072,
            4096,
            8192,
            16384,
        ],
        # x_vals = [64],
        x_log=False,
        line_arg="provider",
        line_vals=["cutlass", "cudnn", "trtllm"],
        line_names=["baseline cutlass fp4", "cudnn fp4", "trtllm fp4"],
        styles=[("red", "solid"), ("blue", "solid"), ("green", "solid")],
        ylabel="latency (ms)",
        plot_name="fp4_gemm_benchmark",
        args={},
    )
)
def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
    M = batch_size
    packed_k = K
    K = 2 * packed_k
    a_dtype = torch.randn((M, K), dtype=dtype, device="cuda")
    b_dtype = torch.randn((N, K), dtype=dtype, device="cuda")
    a_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
    ).to(torch.float32)
    b_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
    ).to(torch.float32)
    alpha = 1.0 / (a_global_scale * b_global_scale)
    a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)
    # print("a_fp4", a_fp4)
    b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)
    res_fi = torch.empty((M, N), dtype=dtype, device="cuda")

    quantiles = [0.5, 0.2, 0.8]
    if provider == "cutlass":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: cutlass_scaled_fp4_mm(
                a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
            ),
            quantiles=quantiles,
        )
    if provider == "cudnn":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: mm_fp4(
                a_fp4,
                b_fp4.T,
                a_scale_interleaved,
                b_scale_interleaved.T,
                alpha,
                dtype,
                res_fi,
            ),
            quantiles=quantiles,
        )
    if provider == "trtllm":
        a_scale_interleaved = a_scale_interleaved.to(torch.uint8)
        b_scale_interleaved = b_scale_interleaved.to(torch.uint8)
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: mm_fp4(
                a_fp4,
                b_fp4.T,
                a_scale_interleaved,
                b_scale_interleaved.T,
                alpha,
                dtype,
                res_fi,
                backend="trtllm",
            ),
            quantiles=quantiles,
        )

    if correctness:
        res_cutlass = cutlass_scaled_fp4_mm(
            a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
        )
        mm_fp4(
            a_fp4,
            b_fp4.T,
            a_scale_interleaved,
            b_scale_interleaved.T,
            alpha,
            dtype,
            res_fi,
            backend="cudnn",
        )
        assert torch.allclose(
            res_fi, res_cutlass, atol=1e-3, rtol=1e-3
        ), "cudnn fp4 doesn't match cutlass fp4"
        mm_fp4(
            a_fp4,
            b_fp4.T,
            a_scale_interleaved,
            b_scale_interleaved.T,
            alpha,
            dtype,
            res_fi,
            backend="trtllm",
        )
        assert torch.allclose(
            res_fi, res_cutlass, atol=1e-3, rtol=1e-3
        ), "trtllm fp4 doesn't match cutlass fp4"

    if csv_file:
        with open(csv_file, "a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([provider, M, N, K, ms])

    return ms, min_ms, max_ms


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    parser.add_argument(
        "--dtype",
        type=torch.dtype,
        default=torch.bfloat16,
        help="Data type",
    )
    parser.add_argument(
        "--correctness",
        action="store_true",
        help="Check correctness",
    )
    parser.add_argument(
        "--csv",
        type=str,
        default="results_cutlass_cudnn.csv",
        help="CSV file to save results",
    )
    args = parser.parse_args()

    if args.csv:
        with open(args.csv, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["provider", "m", "n", "k", "time_ms"])

    NKs = get_weight_shapes(args)
    for N, K in NKs:
        print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path="bench_fp4_res",
            N=N,
            K=K,
            dtype=args.dtype,
            correctness=args.correctness,
            csv_file=args.csv,
        )

    print("Benchmark finished!")
```
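Two details of the harness are worth spelling out. First, `K = 2 * packed_k`: the shape table stores the packed K dimension (two FP4 values per byte), so the logical reduction dimension is twice that. Second, the global scales follow the two-level FP4 recipe: each tensor is mapped into the range representable by an FP8 (E4M3) block scale times an FP4 (E2M1) element, and `alpha = 1 / (a_global_scale * b_global_scale)` undoes both scales in the GEMM epilogue. A standalone sketch of that arithmetic (illustrative; only the constants and formulas are taken from the file above):

```python
import torch

FLOAT4_E2M1_MAX = 6.0                                   # largest FP4 (E2M1) magnitude
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

a = torch.randn(64, 128)
b = torch.randn(32, 128)

# Map each tensor's peak value onto the top of the two-level representable
# range (an E4M3 block scale times an E2M1 element).
a_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1)
b_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b.flatten(), dim=-1)

# Scalar applied in the GEMM epilogue to undo both global scales.
alpha = 1.0 / (a_global_scale * b_global_scale)
print(f"a_scale={a_global_scale.item():.2f} "
      f"b_scale={b_global_scale.item():.2f} alpha={alpha.item():.6e}")
```

To run the benchmark itself, an invocation like `python3 bench_fp4_gemm.py --tp-sizes 8 --correctness` exercises the TP=8 shapes, checks the cudnn and trtllm results against the CUTLASS baseline, and appends per-shape latencies to the CSV (flags as defined in the argparse block above).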