Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
dd1e2689
"llm/llama.cpp/gen_common.sh" did not exist on "5e7fd6906f4653fa671aa5d2e2d4dd5bdf17fd36"
Unverified
Commit
dd1e2689
authored
Sep 07, 2025
by
Jianying
Committed by
GitHub
Sep 06, 2025
Browse files
CUTLASS fp8 blockwise gemm support of sm120 (#9969)
parent
9a7ced4e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
185 additions
and
0 deletions
+185
-0
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
+185
-0
No files found.
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
View file @
dd1e2689
...
@@ -195,6 +195,176 @@ void sm100_fp8_blockwise_dispatch_shape(
...
@@ -195,6 +195,176 @@ void sm100_fp8_blockwise_dispatch_shape(
}
}
}
}
template
<
typename
OutType
,
typename
MmaTileShape
,
typename
PerSmTileShape
,
typename
EpilogueTileShape
,
typename
ScalesPerTile
,
int
TileSizeM_
=
128
,
class
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
>
void
launch_sm120_fp8_blockwise_scaled_mm
(
torch
::
Tensor
&
out
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
scales_a
,
const
torch
::
Tensor
&
scales_b
)
{
using
ElementBlockScale
=
float
;
// A matrix configuration
using
ElementA
=
cutlass
::
float_e4m3_t
;
// Element type for A matrix operand
using
LayoutATag
=
cutlass
::
layout
::
RowMajor
;
// Layout type for A matrix operand
constexpr
int
AlignmentA
=
128
/
cutlass
::
sizeof_bits
<
ElementA
>::
value
;
// Memory access granularity/alignment of A matrix in units of
// elements (up to 16 bytes)
// B matrix configuration
using
ElementB
=
cutlass
::
float_e4m3_t
;
// Element type for B matrix operand
using
LayoutBTag
=
cutlass
::
layout
::
ColumnMajor
;
// Layout type for B matrix operand
constexpr
int
AlignmentB
=
128
/
cutlass
::
sizeof_bits
<
ElementB
>::
value
;
// Memory access granularity/alignment of B matrix in units of
// elements (up to 16 bytes)
// C/D matrix configuration
using
ElementD
=
OutType
;
// Element type for D matrix operand
using
ElementC
=
void
;
// Element type for C matrix operand
using
LayoutCTag
=
cutlass
::
layout
::
RowMajor
;
// Layout type for C matrix operand
using
LayoutDTag
=
cutlass
::
layout
::
RowMajor
;
// Layout type for D matrix operand
constexpr
int
AlignmentD
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
// Memory access granularity/alignment of C matrix in units of
// elements (up to 16 bytes)
constexpr
int
AlignmentC
=
AlignmentD
;
// Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Kernel functional config
using
ElementAccumulator
=
float
;
// Element type for internal accumulation
using
ArchTag
=
cutlass
::
arch
::
Sm120
;
// Tag indicating the minimum SM that supports the intended feature
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
// Operator class tag - changed from OpClassBlockScaledTensorOp
static
constexpr
int
ScaleMsPerTile
=
size
<
0
>
(
ScalesPerTile
{});
static
constexpr
int
ScaleGranularityM
=
size
<
0
>
(
MmaTileShape
{})
/
ScaleMsPerTile
;
static
constexpr
int
ScaleGranularityN
=
size
<
1
>
(
MmaTileShape
{})
/
size
<
1
>
(
ScalesPerTile
{});
static
constexpr
int
ScaleGranularityK
=
size
<
2
>
(
MmaTileShape
{})
/
size
<
2
>
(
ScalesPerTile
{});
using
ScaleConfig
=
cutlass
::
detail
::
Sm120BlockwiseScaleConfig
<
ScaleGranularityM
,
ScaleGranularityN
,
ScaleGranularityK
,
cute
::
UMMA
::
Major
::
MN
,
cute
::
UMMA
::
Major
::
K
>
;
// FP8 Block-wise scaling configuration
using
LayoutSFA
=
decltype
(
ScaleConfig
::
deduce_layoutSFA
());
// Layout type for SFA matrix operand
using
LayoutSFB
=
decltype
(
ScaleConfig
::
deduce_layoutSFB
());
// Layout type for SFB matrix operand
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
PerSmTileShape
,
ClusterShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementAccumulator
,
ElementAccumulator
,
ElementC
,
LayoutCTag
,
AlignmentC
,
ElementD
,
LayoutDTag
,
AlignmentD
,
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
// Epilogue schedule policy
>::
CollectiveOp
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
ElementA
,
cute
::
tuple
<
LayoutATag
,
LayoutSFA
>
,
AlignmentA
,
ElementB
,
cute
::
tuple
<
LayoutBTag
,
LayoutSFB
>
,
AlignmentB
,
ElementAccumulator
,
MmaTileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
cutlass
::
gemm
::
collective
::
KernelScheduleAuto
// Kernel schedule policy. Auto defaults to cooperative kernel
// schedule
>::
CollectiveOp
;
using
GemmKernel
=
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
Shape
<
int
,
int
,
int
,
int
>
,
// Indicates ProblemShape
CollectiveMainloop
,
CollectiveEpilogue
,
void
>
;
using
Gemm
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernel
>
;
Gemm
gemm_op
;
int
m
=
a
.
size
(
0
);
int
k
=
a
.
size
(
1
);
int
n
=
b
.
size
(
1
);
auto
a_ptr
=
static_cast
<
ElementA
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementB
*>
(
b
.
data_ptr
());
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
auto
scales_a_ptr
=
static_cast
<
ElementBlockScale
*>
(
scales_a
.
data_ptr
());
auto
scales_b_ptr
=
static_cast
<
ElementBlockScale
*>
(
scales_b
.
data_ptr
());
using
StrideA
=
typename
Gemm
::
GemmKernel
::
StrideA
;
using
StrideB
=
typename
Gemm
::
GemmKernel
::
StrideB
;
using
StrideD
=
typename
Gemm
::
GemmKernel
::
StrideD
;
using
StrideC
=
typename
Gemm
::
GemmKernel
::
StrideD
;
StrideA
stride_a
=
cutlass
::
make_cute_packed_stride
(
StrideA
{},
cute
::
make_shape
(
m
,
k
,
1
));
StrideB
stride_b
=
cutlass
::
make_cute_packed_stride
(
StrideB
{},
cute
::
make_shape
(
n
,
k
,
1
));
StrideC
stride_c
=
cutlass
::
make_cute_packed_stride
(
StrideC
{},
cute
::
make_shape
(
m
,
n
,
1
));
LayoutSFA
layout_SFA
=
ScaleConfig
::
tile_atom_to_shape_SFA
(
make_shape
(
m
,
n
,
k
,
1
));
LayoutSFB
layout_SFB
=
ScaleConfig
::
tile_atom_to_shape_SFB
(
make_shape
(
m
,
n
,
k
,
1
));
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
a_ptr
,
stride_a
,
b_ptr
,
stride_b
,
scales_a_ptr
,
layout_SFA
,
scales_b_ptr
,
layout_SFB
};
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{{},
c_ptr
,
stride_c
,
c_ptr
,
stride_c
};
epilogue_args
.
thread
.
alpha
=
1.0
f
;
typename
Gemm
::
Arguments
args
=
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGemm
,
{
m
,
n
,
k
,
1
},
mainloop_args
,
epilogue_args
,
};
auto
can_implement
=
gemm_op
.
can_implement
(
args
);
TORCH_CHECK
(
can_implement
==
cutlass
::
Status
::
kSuccess
,
cutlassGetStatusString
(
can_implement
))
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
cutlass
::
device_memory
::
allocation
<
uint8_t
>
workspace
(
workspace_size
);
auto
init_status
=
gemm_op
.
initialize
(
args
,
workspace
.
get
());
TORCH_CHECK
(
init_status
==
cutlass
::
Status
::
kSuccess
,
cutlassGetStatusString
(
init_status
));
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
auto
status
=
gemm_op
.
run
(
stream
);
TORCH_CHECK
(
status
==
cutlass
::
Status
::
kSuccess
,
cutlassGetStatusString
(
status
))
}
template
<
typename
OutType
>
void
sm120_fp8_blockwise_dispatch_shape
(
torch
::
Tensor
&
out
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
scales_a
,
const
torch
::
Tensor
&
scales_b
)
{
using
MmaTileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
PerSmTileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
EpilogueTileShape
=
Shape
<
_128
,
_64
>
;
using
ScalesPerTile
=
Shape
<
_128
,
_1
,
_1
>
;
launch_sm120_fp8_blockwise_scaled_mm
<
OutType
,
MmaTileShape
,
PerSmTileShape
,
EpilogueTileShape
,
ScalesPerTile
>
(
out
,
a
,
b
,
scales_a
,
scales_b
);
}
torch
::
Tensor
fp8_blockwise_scaled_mm
(
torch
::
Tensor
fp8_blockwise_scaled_mm
(
const
torch
::
Tensor
&
mat_a
,
const
torch
::
Tensor
&
mat_a
,
const
torch
::
Tensor
&
mat_b
,
const
torch
::
Tensor
&
mat_b
,
...
@@ -275,6 +445,21 @@ torch::Tensor fp8_blockwise_scaled_mm(
...
@@ -275,6 +445,21 @@ torch::Tensor fp8_blockwise_scaled_mm(
}
}
#endif
#endif
#endif
#endif
#if defined(CUTLASS_ARCH_MMA_SM120A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080
if
(
sm_version
==
120
)
{
if
(
out_dtype
==
torch
::
kBFloat16
)
{
sm120_fp8_blockwise_dispatch_shape
<
cutlass
::
bfloat16_t
>
(
out_padded
,
mat_a_padded
,
mat_b
,
scales_a_padded
,
scales_b
);
}
else
{
sm120_fp8_blockwise_dispatch_shape
<
cutlass
::
half_t
>
(
out_padded
,
mat_a_padded
,
mat_b
,
scales_a_padded
,
scales_b
);
}
return
out_padded
.
slice
(
0
,
0
,
original_rows
);
}
#endif
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No implemented fp8_blockwise_scaled_mm for current compute capability: "
,
sm_version
);
false
,
"No implemented fp8_blockwise_scaled_mm for current compute capability: "
,
sm_version
);
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment