sglang / Commits / 841810f2

Commit 841810f2 (unverified), authored Aug 13, 2025 by henryg; committed by GitHub on Aug 13, 2025.

[Perf] Tunings for SM100 FP8 CUTLASS kernel (#8818)

Parent: 733446dd

Showing 3 changed files with 177 additions and 27 deletions (+177 / -27):

- sgl-kernel/benchmark/bench_fp8_gemm.py (+23 / -3)
- sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu (+126 / -24)
- sgl-kernel/csrc/gemm/math.hpp (+28 / -0, new file)
sgl-kernel/benchmark/bench_fp8_gemm.py (+23 / -3)

 import argparse
 import copy
 import itertools
+from typing import Optional, Tuple

 import torch
 import triton
 from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
+from sgl_kernel import sgl_per_tensor_quant_fp8
 from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
 from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
...
@@ -69,6 +71,21 @@ WEIGHT_SHAPES = {
 }


+def sglang_scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    fp8_type_: torch.dtype = torch.float8_e4m3fn
+    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
+    is_static = True
+    if scale is None:
+        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+        is_static = False
+    sgl_per_tensor_quant_fp8(input, output, scale, is_static)
+    return output, scale
+
+
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["batch_size"],
...
@@ -100,19 +117,22 @@ def benchmark(batch_size, provider, N, K):
     b = torch.ones((N, K), device="cuda") * 5.0
     scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
     scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
-    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
-    b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
-    b_fp8 = b_fp8.t()
     quantiles = [0.5, 0.2, 0.8]
     dtype = torch.float16 if "fp16" in provider else torch.bfloat16
     if "vllm-fp8" in provider:
+        a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
+        b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
+        b_fp8 = b_fp8.t()
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
             quantiles=quantiles,
         )
     elif "sglang-fp8" in provider:
+        a_fp8, scale_a_fp8 = sglang_scaled_fp8_quant(a, scale_a)
+        b_fp8, scale_b_fp8 = sglang_scaled_fp8_quant(b, scale_b)
+        b_fp8 = b_fp8.t()
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: sgl_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
...
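
For reference, the sglang FP8 path this benchmark exercises can be driven directly from Python. The sketch below is not part of the commit: it mirrors the helper added above, quantizes both operands, and runs a single scaled matmul. It assumes sgl_kernel is installed with FP8 support and an FP8-capable CUDA GPU; the shapes and per-row scale layout simply follow the benchmark.

```python
from typing import Optional, Tuple

import torch
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
from sgl_kernel import sgl_per_tensor_quant_fp8


def sglang_scaled_fp8_quant(
    input: torch.Tensor, scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Same idea as the helper in the benchmark: static quantization when a
    # scale is supplied, dynamic (kernel-computed) scale otherwise.
    output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
    is_static = scale is not None
    if scale is None:
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
    sgl_per_tensor_quant_fp8(input, output, scale, is_static)
    return output, scale


M, N, K = 32, 4096, 8192
a = torch.randn((M, K), device="cuda") * 0.1   # activations
b = torch.ones((N, K), device="cuda") * 0.5    # weights, stored row-major (N, K)
scale_a = torch.rand((M,), device="cuda", dtype=torch.float32)
scale_b = torch.rand((N,), device="cuda", dtype=torch.float32)

a_fp8, scale_a = sglang_scaled_fp8_quant(a, scale_a)
b_fp8, scale_b = sglang_scaled_fp8_quant(b, scale_b)

# As in the benchmark, B is handed to the kernel transposed.
out = sgl_scaled_mm(a_fp8, b_fp8.t(), scale_a, scale_b, torch.bfloat16, bias=None)
print(out.shape)  # torch.Size([32, 4096])
```

With M = 32, the new SM100 dispatch below would round M up to 32 and take the 64-row tile bucket.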
sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu (+126 / -24)

@@ -48,6 +48,7 @@ limitations under the License.
 #include <cutlass/gemm/kernel/gemm_universal.hpp>
 #include <cutlass/util/packed_stride.hpp>
+#include "math.hpp"
 #include "utils.h"

 using namespace cute;
...
@@ -1019,8 +1020,18 @@ void sm100_fp8_dispatch_bias(
     const torch::Tensor& scales_a,
     const torch::Tensor& scales_b,
     const c10::optional<torch::Tensor>& bias) {
-  using CTAShape = Shape<_256, _128, _64>;
-  using ClusterShape = Shape<_2, _2, _1>;
+  using CTAShapeDefault = Shape<_256, _128, _64>;
+  using ClusterShapeDefault = Shape<_2, _2, _1>;
+
+  using CTAShape256 = Shape<_128, _128, _128>;
+  using ClusterShape256 = Shape<_2, _1, _1>;
+
+  using CTAShape64 = Shape<_64, _64, _128>;
+  using ClusterShape64 = Shape<_1, _1, _1>;
+
+  using CTAShape16 = Shape<_64, _64, _128>;
+  using ClusterShape16 = Shape<_1, _4, _1>;

   using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
   using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
   using TileSchedulerType = void;
...
@@ -1029,30 +1040,121 @@ void sm100_fp8_dispatch_bias(
   using ElementOutput = OutType;
   using AccumElementType = float;

+  // Gemm type with bias
+  using BiasGemmDefault = DeviceGemmFp8RowwiseSm100<
+      ElementInput, ElementOutput, AccumElementType, CTAShapeDefault, ClusterShapeDefault,
+      MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, true>;
+  using BiasGemm256 = DeviceGemmFp8RowwiseSm100<
+      ElementInput, ElementOutput, AccumElementType, CTAShape256, ClusterShape256,
+      MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, true>;
+  using BiasGemm64 = DeviceGemmFp8RowwiseSm100<
+      ElementInput, ElementOutput, AccumElementType, CTAShape64, ClusterShape64,
+      MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, true>;
+  using BiasGemm16 = DeviceGemmFp8RowwiseSm100<
+      ElementInput, ElementOutput, AccumElementType, CTAShape16, ClusterShape16,
+      MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, true>;
+
+  // Gemm type without bias
+  using GemmDefault = DeviceGemmFp8RowwiseSm100<
+      ElementInput, ElementOutput, AccumElementType, CTAShapeDefault, ClusterShapeDefault,
+      MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, false>;
+  using Gemm256 = DeviceGemmFp8RowwiseSm100<
+      ElementInput, ElementOutput, AccumElementType, CTAShape256, ClusterShape256,
+      MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, false>;
+  using Gemm64 = DeviceGemmFp8RowwiseSm100<
+      ElementInput, ElementOutput, AccumElementType, CTAShape64, ClusterShape64,
+      MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, false>;
+  using Gemm16 = DeviceGemmFp8RowwiseSm100<
+      ElementInput, ElementOutput, AccumElementType, CTAShape16, ClusterShape16,
+      MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, false>;
+
+  // next power of 2 (minimum 16)
+  uint32_t const m = a.size(0);
+  uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));

   if (bias) {
-    using Gemm = DeviceGemmFp8RowwiseSm100<
-        ElementInput, ElementOutput, AccumElementType, CTAShape, ClusterShape,
-        MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, true>;
-    return launch_sm100_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
+    if (mp2 <= 16) {
+      // m in [1, 16]
+      return launch_sm100_fp8_scaled_mm<BiasGemm16, true>(out, a, b, scales_a, scales_b, bias);
+    } else if (mp2 <= 64) {
+      // m in (16, 64]
+      return launch_sm100_fp8_scaled_mm<BiasGemm64, true>(out, a, b, scales_a, scales_b, bias);
+    } else if (mp2 <= 256) {
+      // m in (64, 256]
+      return launch_sm100_fp8_scaled_mm<BiasGemm256, true>(out, a, b, scales_a, scales_b, bias);
+    } else {
+      // m in (256, inf]
+      return launch_sm100_fp8_scaled_mm<BiasGemmDefault, true>(out, a, b, scales_a, scales_b, bias);
+    }
   } else {
-    using Gemm = DeviceGemmFp8RowwiseSm100<
-        ElementInput, ElementOutput, AccumElementType, CTAShape, ClusterShape,
-        MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, false>;
-    return launch_sm100_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
+    if (mp2 <= 16) {
+      // m in [1, 16]
+      return launch_sm100_fp8_scaled_mm<Gemm16, false>(out, a, b, scales_a, scales_b, bias);
+    } else if (mp2 <= 64) {
+      // m in (16, 64]
+      return launch_sm100_fp8_scaled_mm<Gemm64, false>(out, a, b, scales_a, scales_b, bias);
+    } else if (mp2 <= 256) {
+      // m in (64, 256]
+      return launch_sm100_fp8_scaled_mm<Gemm256, false>(out, a, b, scales_a, scales_b, bias);
+    } else {
+      return launch_sm100_fp8_scaled_mm<GemmDefault, false>(out, a, b, scales_a, scales_b, bias);
+    }
   }
 }
...
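
The core of the tuning is the dispatch above: round the GEMM's M dimension up to the next power of two (clamped to at least 16) and pick a smaller CTA/cluster shape for small-batch cases instead of always using the 256x128x64 tile. Below is a Python sketch of that bucket selection, purely for illustration; the shape tuples are copied from the typedefs in this hunk, and the real dispatch is the C++ code.

```python
# (CTA shape M/N/K, cluster shape) per M bucket, as defined in fp8_gemm_kernel.cu.
TILE_CONFIGS = {
    16: ((64, 64, 128), (1, 4, 1)),          # m in [1, 16]
    64: ((64, 64, 128), (1, 1, 1)),          # m in (16, 64]
    256: ((128, 128, 128), (2, 1, 1)),       # m in (64, 256]
    "default": ((256, 128, 64), (2, 2, 1)),  # m > 256
}


def next_pow_2(n: int) -> int:
    # Python equivalent of next_pow_2() in math.hpp.
    return n if n <= 1 else 1 << (n - 1).bit_length()


def pick_tile_config(m: int):
    mp2 = max(16, next_pow_2(m))
    for bucket in (16, 64, 256):
        if mp2 <= bucket:
            return TILE_CONFIGS[bucket]
    return TILE_CONFIGS["default"]


for m in (1, 16, 17, 48, 200, 257, 4096):
    print(m, "->", pick_tile_config(m))
# m = 17 and m = 48 both round to 64 and get the 64x64x128 CTA / 1x1x1 cluster.
```

Because the comparison uses mp2 rather than m itself, every M in a bucket shares one kernel instantiation, keeping the compiled configurations at four per bias variant.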
sgl-kernel/csrc/gemm/math.hpp (new file, mode 100644, +28 / -0)

#pragma once

#include <climits>
#include <iostream>

inline constexpr uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

template <typename A, typename B>
static inline constexpr auto div_ceil(A a, B b) {
  return (a + b - 1) / b;
}

// Round a down to the next multiple of b. The caller is responsible for making
// sure that b is non-zero
template <typename T>
inline constexpr T round_to_previous_multiple_of(T a, T b) {
  return a % b == 0 ? a : (a / b) * b;
}

// Round a up to the next multiple of b. The caller is responsible for making
// sure that b is non-zero
template <typename T>
inline constexpr T round_to_next_multiple_of(T a, T b) {
  return a % b == 0 ? a : ((a / b) + 1) * b;
}
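
As a quick sanity check of the arithmetic these helpers are meant to implement, here is an illustrative Python mirror; it is not part of the commit, and the committed code is the C++ above.

```python
def div_ceil(a: int, b: int) -> int:
    # Mirrors div_ceil() for non-negative integers and non-zero b.
    return (a + b - 1) // b


def round_to_previous_multiple_of(a: int, b: int) -> int:
    return a if a % b == 0 else (a // b) * b


def round_to_next_multiple_of(a: int, b: int) -> int:
    return a if a % b == 0 else (a // b + 1) * b


assert div_ceil(10, 4) == 3
assert round_to_previous_multiple_of(10, 4) == 8
assert round_to_next_multiple_of(10, 4) == 12
# Rounding up to a multiple of b is div_ceil scaled back by b.
assert all(round_to_next_multiple_of(a, 8) == div_ceil(a, 8) * 8 for a in range(1, 200))
```

Of these helpers, only next_pow_2 is consumed by the SM100 dispatch in the hunks shown here; the others are general-purpose utilities.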