sglang / Commits

Commit dd1e2689 (Unverified)
Authored Sep 07, 2025 by Jianying, committed by GitHub on Sep 06, 2025
CUTLASS fp8 blockwise gemm support of sm120 (#9969)
Parent: 9a7ced4e
Showing 1 changed file with 185 additions and 0 deletions.
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu  +185  -0
@@ -195,6 +195,176 @@ void sm100_fp8_blockwise_dispatch_shape(
  }
}
template <
    typename OutType,
    typename MmaTileShape,
    typename PerSmTileShape,
    typename EpilogueTileShape,
    typename ScalesPerTile,
    int TileSizeM_ = 128,
    class ClusterShape = Shape<_1, _1, _1>>
void launch_sm120_fp8_blockwise_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  using ElementBlockScale = float;
  // A matrix configuration
  using ElementA = cutlass::float_e4m3_t;        // Element type for A matrix operand
  using LayoutATag = cutlass::layout::RowMajor;  // Layout type for A matrix operand
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;  // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

  // B matrix configuration
  using ElementB = cutlass::float_e4m3_t;           // Element type for B matrix operand
  using LayoutBTag = cutlass::layout::ColumnMajor;  // Layout type for B matrix operand
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;  // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

  // C/D matrix configuration
  using ElementD = OutType;                      // Element type for D matrix operand
  using ElementC = void;                         // Element type for C matrix operand
  using LayoutCTag = cutlass::layout::RowMajor;  // Layout type for C matrix operand
  using LayoutDTag = cutlass::layout::RowMajor;  // Layout type for D matrix operand
  constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;  // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
  constexpr int AlignmentC = AlignmentD;         // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)

  // Kernel functional config
  using ElementAccumulator = float;                      // Element type for internal accumulation
  using ArchTag = cutlass::arch::Sm120;                  // Tag indicating the minimum SM that supports the intended feature
  using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag - changed from OpClassBlockScaledTensorOp
  static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
  static constexpr int ScaleGranularityM = size<0>(MmaTileShape{}) / ScaleMsPerTile;
  static constexpr int ScaleGranularityN = size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
  static constexpr int ScaleGranularityK = size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});
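  // (Editorial note, not part of the commit.) With the defaults chosen in
  // sm120_fp8_blockwise_dispatch_shape below, MmaTileShape = (128, 128, 128) and
  // ScalesPerTile = (128, 1, 1), these evaluate to ScaleGranularityM = 1,
  // ScaleGranularityN = 128, and ScaleGranularityK = 128: one scale per row of A
  // per 128-wide K block, and one scale per 128x128 block of B.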
  using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig<
      ScaleGranularityM,
      ScaleGranularityN,
      ScaleGranularityK,
      cute::UMMA::Major::MN,
      cute::UMMA::Major::K>;  // FP8 block-wise scaling configuration

  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());  // Layout type for SFA matrix operand
  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());  // Layout type for SFB matrix operand
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      PerSmTileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementAccumulator,
      ElementC,
      LayoutCTag,
      AlignmentC,
      ElementD,
      LayoutDTag,
      AlignmentD,
      cutlass::epilogue::collective::EpilogueScheduleAuto  // Epilogue schedule policy
      >::CollectiveOp;
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementA,
      cute::tuple<LayoutATag, LayoutSFA>,
      AlignmentA,
      ElementB,
      cute::tuple<LayoutBTag, LayoutSFB>,
      AlignmentB,
      ElementAccumulator,
      MmaTileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
      cutlass::gemm::collective::KernelScheduleAuto  // Kernel schedule policy. Auto defaults to cooperative kernel schedule
      >::CollectiveOp;
  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue,
      void>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  Gemm gemm_op;
  int m = a.size(0);
  int k = a.size(1);
  int n = b.size(1);

  auto a_ptr = static_cast<ElementA*>(a.data_ptr());
  auto b_ptr = static_cast<ElementB*>(b.data_ptr());
  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  auto scales_a_ptr = static_cast<ElementBlockScale*>(scales_a.data_ptr());
  auto scales_b_ptr = static_cast<ElementBlockScale*>(scales_b.data_ptr());
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideD = typename Gemm::GemmKernel::StrideD;
  using StrideC = typename Gemm::GemmKernel::StrideD;

  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  StrideC stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
  LayoutSFA layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
  LayoutSFB layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
  typename GemmKernel::MainloopArguments mainloop_args{
      a_ptr, stride_a, b_ptr, stride_b, scales_a_ptr, layout_SFA, scales_b_ptr, layout_SFB};

  typename GemmKernel::EpilogueArguments epilogue_args{{}, c_ptr, stride_c, c_ptr, stride_c};
  epilogue_args.thread.alpha = 1.0f;

  typename Gemm::Arguments args = {
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},
      mainloop_args,
      epilogue_args,
  };
  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess, cutlassGetStatusString(can_implement));

  size_t workspace_size = gemm_op.get_workspace_size(args);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  auto init_status = gemm_op.initialize(args, workspace.get());
  TORCH_CHECK(init_status == cutlass::Status::kSuccess, cutlassGetStatusString(init_status));

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
  auto status = gemm_op.run(stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status));
}
template <typename OutType>
void sm120_fp8_blockwise_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  using MmaTileShape = Shape<_128, _128, _128>;
  using PerSmTileShape = Shape<_128, _128, _128>;
  using EpilogueTileShape = Shape<_128, _64>;
  using ScalesPerTile = Shape<_128, _1, _1>;
  launch_sm120_fp8_blockwise_scaled_mm<OutType, MmaTileShape, PerSmTileShape, EpilogueTileShape, ScalesPerTile>(
      out, a, b, scales_a, scales_b);
}
torch::Tensor fp8_blockwise_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
@@ -275,6 +445,21 @@ torch::Tensor fp8_blockwise_scaled_mm(
  }
#endif
#endif
#if defined(CUTLASS_ARCH_MMA_SM120A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080
  if (sm_version == 120) {
    if (out_dtype == torch::kBFloat16) {
      sm120_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b);
    } else {
      sm120_fp8_blockwise_dispatch_shape<cutlass::half_t>(out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b);
    }
    return out_padded.slice(0, 0, original_rows);
  }
#endif
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
}
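For orientation, a minimal caller sketch follows. It is not part of the commit: the full fp8_blockwise_scaled_mm signature is truncated in the context above, so the trailing scale and dtype parameters are assumed from how the diff's body uses scales_a, scales_b, and out_dtype, the scale tensor shapes follow the layout implied by the dispatch defaults, and M, N, K are taken to be multiples of 128 to match the 128x128x128 tiles.

// Hypothetical usage sketch (not part of the commit).
#include <torch/torch.h>

// Assumed full signature; only mat_a and mat_b are visible in the truncated context above.
torch::Tensor fp8_blockwise_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype);

torch::Tensor sm120_blockwise_mm_example(int64_t M, int64_t N, int64_t K) {
  // FP8 e4m3 operands: A row-major [M, K], B column-major [K, N] (built as [N, K] and transposed).
  auto a = torch::randn({M, K}, torch::kCUDA).to(at::kFloat8_e4m3fn);
  auto b = torch::randn({N, K}, torch::kCUDA).to(at::kFloat8_e4m3fn).t();

  // FP32 scales, laid out as implied by the dispatch defaults: per-row-per-128-K-block for A,
  // per-128x128-block for B. The exact shapes the op expects are an assumption here.
  auto f32 = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  auto scales_a = torch::rand({M, K / 128}, f32);
  auto scales_b = torch::rand({N / 128, K / 128}, f32);

  // On a device reporting SM version 120 with CUDA >= 12.8, this routes to the new sm120 path above.
  return fp8_blockwise_scaled_mm(a, b, scales_a, scales_b, torch::kBFloat16);
}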