Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
85ed8e0a
Unverified
Commit
85ed8e0a
authored
Sep 07, 2025
by
Qi Yuhang
Committed by
GitHub
Sep 06, 2025
Browse files
Optimize nvfp4 block scaled gemm kernel when M is small. (#10101)
parent
dd1e2689
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
114 additions
and
30 deletions
+114
-30
sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu
sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu
+114
-30
No files found.
sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu
View file @
85ed8e0a
...
@@ -38,27 +38,74 @@ limitations under the License.
...
@@ -38,27 +38,74 @@ limitations under the License.
using
namespace
cute
;
using
namespace
cute
;
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
//
Kernel Perf config
//
Config(half_t/bfloat16_t) for M <= 128
template
<
typename
T
>
template
<
typename
T
>
struct
KernelTraits
{
struct
KernelConfigM128
{
using
OutputType
=
T
;
using
MmaTileShape
=
Shape
<
_128
,
_256
,
_256
>
;
using
ClusterShape
=
Shape
<
int
,
int
,
_1
>
;
using
EpilogueTile
=
Shape
<
_128
,
_64
>
;
// Avoid register spilling
using
EpilogueSchedule
=
cutlass
::
epilogue
::
TmaWarpSpecialized1Sm
;
using
MainloopSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecialized1SmNvf4Sm100
;
const
static
dim3
preferred_cluster
;
const
static
dim3
fallback_cluster
;
};
template
<
typename
T
>
const
dim3
KernelConfigM128
<
T
>::
preferred_cluster
(
1
,
4
,
1
);
template
<
typename
T
>
const
dim3
KernelConfigM128
<
T
>::
fallback_cluster
(
1
,
2
,
1
);
// Config(half_t/bfloat16_t) for M <= 256
template
<
typename
T
>
struct
KernelConfigM256
{
using
OutputType
=
T
;
using
MmaTileShape
=
Shape
<
_256
,
_256
,
_256
>
;
using
MmaTileShape
=
Shape
<
_256
,
_256
,
_256
>
;
using
ClusterShape
=
Shape
<
int
,
int
,
_1
>
;
using
ClusterShape
=
Shape
<
int
,
int
,
_1
>
;
using
EpilogueTile
=
Shape
<
_128
,
_64
>
;
using
EpilogueTile
=
Shape
<
_128
,
_64
>
;
// Avoid register spilling
using
EpilogueSchedule
=
cutlass
::
epilogue
::
TmaWarpSpecialized2Sm
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
TmaWarpSpecialized2Sm
;
using
MainloopSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecialized2SmNvf4Sm100
;
using
MainloopSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecialized2SmNvf4Sm100
;
const
static
dim3
preferred_cluster
;
const
static
dim3
fallback_cluster
;
};
};
template
<
typename
T
>
const
dim3
KernelConfigM256
<
T
>::
preferred_cluster
(
2
,
4
,
1
);
template
<
typename
T
>
const
dim3
KernelConfigM256
<
T
>::
fallback_cluster
(
2
,
1
,
1
);
template
<
>
// Default config(half_t/bfloat16_t) for M > 256
struct
KernelTraits
<
float
>
{
template
<
typename
T
>
struct
KernelConfigDefault
{
using
OutputType
=
T
;
using
MmaTileShape
=
Shape
<
_256
,
_256
,
_256
>
;
using
ClusterShape
=
Shape
<
int
,
int
,
_1
>
;
using
EpilogueTile
=
Shape
<
_128
,
_64
>
;
// Avoid register spilling
using
EpilogueSchedule
=
cutlass
::
epilogue
::
TmaWarpSpecialized2Sm
;
using
MainloopSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecialized2SmNvf4Sm100
;
const
static
dim3
preferred_cluster
;
const
static
dim3
fallback_cluster
;
};
template
<
typename
T
>
const
dim3
KernelConfigDefault
<
T
>::
preferred_cluster
(
4
,
4
,
1
);
template
<
typename
T
>
const
dim3
KernelConfigDefault
<
T
>::
fallback_cluster
(
2
,
1
,
1
);
struct
KernelConfigFp32
{
using
OutputType
=
float
;
using
MmaTileShape
=
Shape
<
_128
,
_128
,
_256
>
;
using
MmaTileShape
=
Shape
<
_128
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
int
,
int
,
_1
>
;
using
ClusterShape
=
Shape
<
int
,
int
,
_1
>
;
using
EpilogueTile
=
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
;
using
EpilogueTile
=
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
TmaWarpSpecialized1Sm
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
TmaWarpSpecialized1Sm
;
using
MainloopSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecialized1SmNvf4Sm100
;
using
MainloopSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecialized1SmNvf4Sm100
;
const
static
dim3
preferred_cluster
;
const
static
dim3
fallback_cluster
;
};
};
const
dim3
KernelConfigFp32
::
preferred_cluster
=
dim3
(
1
,
4
,
1
);
const
dim3
KernelConfigFp32
::
fallback_cluster
=
dim3
(
1
,
2
,
1
);
template
<
typename
T
>
template
<
typename
KernelConfig
>
struct
Fp4GemmSm100
{
struct
Fp4GemmSm100
{
using
Config
=
KernelConfig
;
// For generating args
using
OutputType
=
typename
KernelConfig
::
OutputType
;
// A matrix configuration
// A matrix configuration
using
ElementA
=
cutlass
::
nv_float4_t
<
cutlass
::
float_e2m1_t
>
;
using
ElementA
=
cutlass
::
nv_float4_t
<
cutlass
::
float_e2m1_t
>
;
using
LayoutATag
=
cutlass
::
layout
::
RowMajor
;
using
LayoutATag
=
cutlass
::
layout
::
RowMajor
;
...
@@ -70,8 +117,8 @@ struct Fp4GemmSm100 {
...
@@ -70,8 +117,8 @@ struct Fp4GemmSm100 {
static
constexpr
int
AlignmentB
=
32
;
static
constexpr
int
AlignmentB
=
32
;
// C/D matrix configuration
// C/D matrix configuration
using
ElementD
=
T
;
using
ElementD
=
OutputType
;
using
ElementC
=
T
;
using
ElementC
=
OutputType
;
using
LayoutCTag
=
cutlass
::
layout
::
RowMajor
;
using
LayoutCTag
=
cutlass
::
layout
::
RowMajor
;
using
LayoutDTag
=
cutlass
::
layout
::
RowMajor
;
using
LayoutDTag
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
AlignmentD
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
static
constexpr
int
AlignmentD
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
...
@@ -82,15 +129,15 @@ struct Fp4GemmSm100 {
...
@@ -82,15 +129,15 @@ struct Fp4GemmSm100 {
using
OperatorClass
=
cutlass
::
arch
::
OpClassBlockScaledTensorOp
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassBlockScaledTensorOp
;
// Kernel Perf config
// Kernel Perf config
using
MmaTileShape
=
typename
Kernel
Traits
<
T
>
::
MmaTileShape
;
using
MmaTileShape
=
typename
Kernel
Config
::
MmaTileShape
;
using
ClusterShape
=
typename
Kernel
Traits
<
T
>
::
ClusterShape
;
using
ClusterShape
=
typename
Kernel
Config
::
ClusterShape
;
using
EpilogueTile
=
typename
Kernel
Traits
<
T
>
::
EpilogueTile
;
using
EpilogueTile
=
typename
Kernel
Config
::
EpilogueTile
;
using
EpilogueSchedule
=
typename
Kernel
Traits
<
T
>
::
EpilogueSchedule
;
using
EpilogueSchedule
=
typename
Kernel
Config
::
EpilogueSchedule
;
using
MainloopSchedule
=
typename
Kernel
Traits
<
T
>
::
MainloopSchedule
;
using
MainloopSchedule
=
typename
Kernel
Config
::
MainloopSchedule
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
ArchTag
,
cutlass
::
arch
::
OpClassTensorOp
,
OperatorClass
,
MmaTileShape
,
MmaTileShape
,
ClusterShape
,
ClusterShape
,
EpilogueTile
,
EpilogueTile
,
...
@@ -182,19 +229,15 @@ typename T::Gemm::Arguments args_from_options(
...
@@ -182,19 +229,15 @@ typename T::Gemm::Arguments args_from_options(
layout_SFB
},
layout_SFB
},
{
// Epilogue arguments
{
// Epilogue arguments
{},
// epilogue.thread
{},
// epilogue.thread
static_cast
<
ElementD
const
*>
(
D
.
data_ptr
())
,
nullptr
,
stride_D
,
stride_D
,
static_cast
<
ElementD
*>
(
D
.
data_ptr
()),
static_cast
<
ElementD
*>
(
D
.
data_ptr
()),
stride_D
}};
stride_D
}};
auto
&
fusion_args
=
arguments
.
epilogue
.
thread
;
auto
&
fusion_args
=
arguments
.
epilogue
.
thread
;
fusion_args
.
alpha_ptr
=
static_cast
<
ElementCompute
const
*>
(
alpha
.
data_ptr
());
fusion_args
.
alpha_ptr
=
static_cast
<
ElementCompute
const
*>
(
alpha
.
data_ptr
());
if
constexpr
(
std
::
is_same_v
<
T
,
float
>
)
{
using
KernelConfig
=
typename
T
::
Config
;
arguments
.
hw_info
.
cluster_shape
=
dim3
(
1
,
4
,
1
);
arguments
.
hw_info
.
cluster_shape
=
KernelConfig
::
preferred_cluster
;
arguments
.
hw_info
.
cluster_shape_fallback
=
dim3
(
1
,
1
,
1
);
arguments
.
hw_info
.
cluster_shape_fallback
=
KernelConfig
::
fallback_cluster
;
}
else
{
arguments
.
hw_info
.
cluster_shape
=
dim3
(
4
,
4
,
1
);
arguments
.
hw_info
.
cluster_shape_fallback
=
dim3
(
2
,
1
,
1
);
}
return
arguments
;
return
arguments
;
}
}
...
@@ -210,11 +253,10 @@ void runGemm(
...
@@ -210,11 +253,10 @@ void runGemm(
int64_t
n
,
int64_t
n
,
int64_t
k
,
int64_t
k
,
cudaStream_t
stream
)
{
cudaStream_t
stream
)
{
typename
Fp4GemmSm100
<
T
>::
Gemm
gemm
;
typename
T
::
Gemm
gemm
;
auto
arguments
=
args_from_options
<
T
>
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
,
m
,
n
,
k
);
auto
arguments
=
args_from_options
<
Fp4GemmSm100
<
T
>>
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
,
m
,
n
,
k
);
size_t
workspace_size
=
Fp4GemmSm100
<
T
>
::
Gemm
::
get_workspace_size
(
arguments
);
size_t
workspace_size
=
T
::
Gemm
::
get_workspace_size
(
arguments
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
A
.
device
());
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
A
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
...
@@ -224,9 +266,51 @@ void runGemm(
...
@@ -224,9 +266,51 @@ void runGemm(
CUTLASS_CHECK
(
gemm
.
run
(
arguments
,
workspace
.
data_ptr
(),
stream
));
CUTLASS_CHECK
(
gemm
.
run
(
arguments
,
workspace
.
data_ptr
(),
stream
));
}
}
// Dispatch function to select appropriate config based on M
template
<
typename
OutType
>
void
cutlassFp4GemmDispatch
(
torch
::
Tensor
&
D
,
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
torch
::
Tensor
const
&
A_sf
,
torch
::
Tensor
const
&
B_sf
,
torch
::
Tensor
const
&
alpha
,
int64_t
m
,
int64_t
n
,
int64_t
k
,
cudaStream_t
stream
)
{
if
(
m
<=
128
)
{
// m in [1, 128]
runGemm
<
Fp4GemmSm100
<
KernelConfigM128
<
OutType
>>>
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
,
m
,
n
,
k
,
stream
);
}
else
if
(
m
<=
256
)
{
// m in (128, 256]
runGemm
<
Fp4GemmSm100
<
KernelConfigM256
<
OutType
>>>
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
,
m
,
n
,
k
,
stream
);
}
else
{
// m in (256, inf)
runGemm
<
Fp4GemmSm100
<
KernelConfigDefault
<
OutType
>>>
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
,
m
,
n
,
k
,
stream
);
}
}
// Dispatch function to select appropriate config based on M
template
<
>
void
cutlassFp4GemmDispatch
<
float
>
(
torch
::
Tensor
&
D
,
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
torch
::
Tensor
const
&
A_sf
,
torch
::
Tensor
const
&
B_sf
,
torch
::
Tensor
const
&
alpha
,
int64_t
m
,
int64_t
n
,
int64_t
k
,
cudaStream_t
stream
)
{
runGemm
<
Fp4GemmSm100
<
KernelConfigFp32
>>
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
,
m
,
n
,
k
,
stream
);
}
#else
#else
template
<
typename
T
>
template
<
typename
T
>
void
runGemm
(
void
cutlassFp4GemmDispatch
(
at
::
Tensor
&
D
,
at
::
Tensor
&
D
,
at
::
Tensor
const
&
A
,
at
::
Tensor
const
&
A
,
at
::
Tensor
const
&
B
,
at
::
Tensor
const
&
B
,
...
@@ -358,11 +442,11 @@ void cutlass_scaled_fp4_mm_sm100a(
...
@@ -358,11 +442,11 @@ void cutlass_scaled_fp4_mm_sm100a(
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
A
.
get_device
());
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
A
.
get_device
());
if
(
out_dtype
==
at
::
ScalarType
::
Half
)
{
if
(
out_dtype
==
at
::
ScalarType
::
Half
)
{
runGemm
<
cutlass
::
half_t
>
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
,
m
,
n
,
k
,
stream
);
cutlassFp4GemmDispatch
<
cutlass
::
half_t
>
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
,
m
,
n
,
k
,
stream
);
}
else
if
(
out_dtype
==
at
::
ScalarType
::
BFloat16
)
{
}
else
if
(
out_dtype
==
at
::
ScalarType
::
BFloat16
)
{
runGemm
<
cutlass
::
bfloat16_t
>
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
,
m
,
n
,
k
,
stream
);
cutlassFp4GemmDispatch
<
cutlass
::
bfloat16_t
>
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
,
m
,
n
,
k
,
stream
);
}
else
if
(
out_dtype
==
at
::
ScalarType
::
Float
)
{
}
else
if
(
out_dtype
==
at
::
ScalarType
::
Float
)
{
runGemm
<
float
>
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
,
m
,
n
,
k
,
stream
);
cutlassFp4GemmDispatch
<
float
>
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
,
m
,
n
,
k
,
stream
);
}
else
{
}
else
{
TORCH_CHECK
(
false
,
"Unsupported output data type of nvfp4 mm"
);
TORCH_CHECK
(
false
,
"Unsupported output data type of nvfp4 mm"
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment