sglang · Commit 7c866711 (unverified)
Authored Mar 12, 2025 by Elfie Guo; committed by GitHub on Mar 12, 2025

Support Blackwell Block Scale FP8 Gemm (#4278)

parent 10b544ae
Changes: 3 changed files with 207 additions and 2 deletions (+207, -2)

  sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu   (+181, -1)
  sgl-kernel/include/utils.h                          (+20, -0)
  sgl-kernel/setup.py                                 (+6, -1)
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
@@ -17,6 +17,8 @@
 #include <cutlass/matrix_coord.h>
 #include <cutlass/numeric_types.h>
 #include <cutlass/tensor_ref.h>
+#include <cutlass/util/host_tensor.h>
+#include <cutlass/util/tensor_view_io.h>
 #include <torch/all.h>
 #include <cute/tensor.hpp>
@@ -154,6 +156,141 @@ void launch_sm90_fp8_blockwise_scaled_mm(
   TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status))
 }
 
+template <
+    typename OutType,
+    typename MmaTileShape,
+    typename PerSmTileShape,
+    typename EpilogueTileShape,
+    typename ScalesPerTile,
+    int TileSizeM_ = 128,
+    class ClusterShape = Shape<_1, _1, _1>>
+void launch_sm100_fp8_blockwise_scaled_mm(
+    torch::Tensor& out,
+    const torch::Tensor& a,
+    const torch::Tensor& b,
+    const torch::Tensor& scales_a,
+    const torch::Tensor& scales_b) {
+  static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
+  static constexpr int ScaleGranularityM = size<0>(MmaTileShape{}) / ScaleMsPerTile;
+  static constexpr int ScaleGranularityN = size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
+  static constexpr int ScaleGranularityK = size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});
+
+  using ElementAB = cutlass::float_e4m3_t;
+  using ElementA = ElementAB;
+  using ElementB = ElementAB;
+  using ElementC = void;
+  using ElementD = OutType;
+
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutD = cutlass::layout::RowMajor;
+  using LayoutC = LayoutD;
+
+  // This means both SFA and SFB are column-major.
+  using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
+      ScaleGranularityM,
+      ScaleGranularityN,
+      ScaleGranularityK,
+      cute::UMMA::Major::MN,
+      cute::UMMA::Major::K>;
+
+  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
+  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
+
+  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
+  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
+  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+  static constexpr int AlignmentC = AlignmentD;
+
+  using ElementAccumulator = float;
+  using ElementBlockScale = float;
+  using ElementCompute = float;
+  using ArchTag = cutlass::arch::Sm100;
+  using OperatorClass = cutlass::arch::OpClassTensorOp;
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag,
+      cutlass::arch::OpClassTensorOp,
+      PerSmTileShape,
+      ClusterShape,
+      EpilogueTileShape,
+      ElementAccumulator,
+      ElementCompute,
+      ElementC,
+      LayoutC,
+      AlignmentC,
+      ElementD,
+      LayoutD,
+      AlignmentD,
+      cutlass::epilogue::TmaWarpSpecialized1Sm>::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag,
+      OperatorClass,
+      ElementA,
+      cute::tuple<LayoutA, LayoutSFA>,
+      AlignmentA,
+      ElementB,
+      cute::tuple<LayoutB, LayoutSFB>,
+      AlignmentB,
+      ElementAccumulator,
+      MmaTileShape,
+      ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+          sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int, int, int, int>,
+      CollectiveMainloop,
+      CollectiveEpilogue,
+      cutlass::gemm::PersistentScheduler>;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  Gemm gemm_op;
+
+  int m = a.size(0);
+  int k = a.size(1);
+  int n = b.size(1);
+
+  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
+  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
+  auto scales_a_ptr = static_cast<float*>(scales_a.data_ptr());
+  auto scales_b_ptr = static_cast<float*>(scales_b.data_ptr());
+  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
+
+  using StrideA = typename GemmKernel::StrideA;
+  using StrideB = typename GemmKernel::StrideB;
+  using StrideD = typename GemmKernel::StrideD;
+  using StrideC = typename GemmKernel::StrideD;
+
+  StrideA a_stride = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
+  StrideB b_stride = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
+  StrideC c_stride = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
+
+  LayoutSFA layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
+  LayoutSFB layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
+
+  typename GemmKernel::MainloopArguments mainloop_args{
+      a_ptr, a_stride, b_ptr, b_stride, scales_a_ptr, layout_SFA, scales_b_ptr, layout_SFB};
+  typename GemmKernel::EpilogueArguments epilogue_args{{}, c_ptr, c_stride, c_ptr, c_stride};
+  epilogue_args.thread.alpha = 1.0f;
+
+  typename GemmKernel::Arguments args = {
+      cutlass::gemm::GemmUniversalMode::kGemm, {m, n, k, 1}, mainloop_args, epilogue_args};
+
+  auto can_implement = gemm_op.can_implement(args);
+  TORCH_CHECK(can_implement == cutlass::Status::kSuccess, cutlassGetStatusString(can_implement))
+
+  size_t workspace_size = gemm_op.get_workspace_size(args);
+  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+
+  auto init_status = gemm_op.initialize(args, workspace.get());
+  TORCH_CHECK(init_status == cutlass::Status::kSuccess, cutlassGetStatusString(init_status));
+
+  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
+  auto status = gemm_op.run(stream);
+  TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status))
+}
+
 template <typename OutType>
 void sm90_fp8_blockwise_dispatch_shape(
     torch::Tensor& out,
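A note on the granularity arithmetic above: each ScaleGranularity constant is the MMA tile extent divided by the number of scale factors carried per tile along that mode. The following standalone Python sketch (illustrative, not part of the commit) evaluates it for the two configurations that sm100_fp8_blockwise_dispatch_shape selects below; both come out to (1, 128, 128), i.e. one scale per row per 128-wide K slice of A and one scale per 128x128 block of B.

# Hypothetical sketch of the ScaleGranularity arithmetic; not repository code.
def scale_granularity(mma_tile, scales_per_tile):
    # Divide each MMA tile extent by the scale count along the same mode.
    return tuple(t // s for t, s in zip(mma_tile, scales_per_tile))

print(scale_granularity((64, 128, 128), (64, 1, 1)))    # (1, 128, 128): m <= 128 config
print(scale_granularity((128, 128, 128), (128, 1, 1)))  # (1, 128, 128): large-m config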
@@ -166,6 +303,30 @@ void sm90_fp8_blockwise_dispatch_shape(
   launch_sm90_fp8_blockwise_scaled_mm<OutType, TileShape, ClusterShape>(out, a, b, scales_a, scales_b);
 }
 
+template <typename OutType>
+void sm100_fp8_blockwise_dispatch_shape(
+    torch::Tensor& out,
+    const torch::Tensor& a,
+    const torch::Tensor& b,
+    const torch::Tensor& scales_a,
+    const torch::Tensor& scales_b) {
+  if (a.size(0) <= 128) {
+    using MmaTileShape = Shape<_64, _128, _128>;
+    using PerSmTileShape = Shape<_64, _128, _128>;
+    using EpilogueTileShape = Shape<_64, _64>;
+    using ScalesPerTile = Shape<_64, _1, _1>;
+    launch_sm100_fp8_blockwise_scaled_mm<OutType, MmaTileShape, PerSmTileShape, EpilogueTileShape, ScalesPerTile>(
+        out, a, b, scales_a, scales_b);
+  } else {
+    using MmaTileShape = Shape<_128, _128, _128>;
+    using PerSmTileShape = Shape<_128, _128, _128>;
+    using EpilogueTileShape = Shape<_128, _64>;
+    using ScalesPerTile = Shape<_128, _1, _1>;
+    launch_sm100_fp8_blockwise_scaled_mm<OutType, MmaTileShape, PerSmTileShape, EpilogueTileShape, ScalesPerTile>(
+        out, a, b, scales_a, scales_b);
+  }
+}
+
 torch::Tensor fp8_blockwise_scaled_mm(
     const torch::Tensor& mat_a,
     const torch::Tensor& mat_b,
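For intuition about what the kernel computes, the blockwise-scaling semantics implied by those granularities can be written as a plain dequantize-then-matmul reference in torch. This is a sketch under assumed shapes (scales_a of shape (m, k/128), scales_b of shape (k/128, n/128)); it is not the repository's test code.

import torch

def ref_blockwise_fp8_mm(a, b, scales_a, scales_b, block=128, out_dtype=torch.bfloat16):
    # Assumed shapes: a is (m, k) fp8 with scales_a (m, k // block) fp32,
    # b is (k, n) fp8 with scales_b (k // block, n // block) fp32.
    # Broadcast each scale over its block, dequantize, then matmul in fp32.
    a_deq = a.float() * scales_a.repeat_interleave(block, dim=1)
    b_deq = b.float() * scales_b.repeat_interleave(block, dim=0).repeat_interleave(block, dim=1)
    return (a_deq @ b_deq).to(out_dtype)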
@@ -210,7 +371,7 @@ torch::Tensor fp8_blockwise_scaled_mm(
 #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
 #if defined CUDA_VERSION && CUDA_VERSION >= 12000
-  if (sm_version >= 90) {
+  if (sm_version == 90) {
     if (out_dtype == torch::kBFloat16) {
       sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b);
     } else {

With the dedicated SM100 path added below, the Hopper branch is narrowed from sm_version >= 90 to sm_version == 90, so Blackwell devices no longer fall through to the SM90 kernel.
@@ -221,6 +382,25 @@ torch::Tensor fp8_blockwise_scaled_mm(
 #endif
 #endif
+#if defined(CUTLASS_ARCH_MMA_SM100A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+#if defined CUDA_VERSION && CUDA_VERSION >= 12080
+  if (sm_version == 100) {
+    int64_t original_rows = mat_a.size(0);
+    torch::Tensor mat_a_padded = pad_tensor(mat_a, /*alignment=*/4);
+    torch::Tensor scales_a_padded = pad_tensor(scales_a, /*alignment=*/4, /*col_major=*/true);
+    torch::Tensor out_padded = torch::empty({mat_a_padded.size(0), mat_b.size(1)}, out.options());
+    if (out_dtype == torch::kBFloat16) {
+      sm100_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(
+          out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b);
+    } else {
+      sm100_fp8_blockwise_dispatch_shape<cutlass::half_t>(
+          out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b);
+    }
+    return out_padded.slice(0, 0, original_rows);
+  }
+#endif
+#endif
   TORCH_CHECK_NOT_IMPLEMENTED(
       false, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
 }
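The SM100 path pads the M dimension of mat_a and scales_a to a multiple of four rows, runs the kernel on the padded buffers, and slices the result back to the original row count. The same round trip in torch, as a minimal sketch (pad_rows is an illustrative stand-in for the pad_tensor helper added in utils.h, and the kernel call itself is stubbed out):

import torch

def pad_rows(t, alignment=4):
    # Zero-pad dim 0 up to a multiple of `alignment`, like pad_tensor below.
    pad = (alignment - t.size(0) % alignment) % alignment
    if pad == 0:
        return t
    return torch.cat([t, torch.zeros(pad, t.size(1), dtype=t.dtype, device=t.device)], dim=0)

m, k, n = 130, 512, 256                # m is not a multiple of 4
mat_a = torch.randn(m, k)              # stands in for the fp8 operand
mat_a_padded = pad_rows(mat_a)         # 132 rows after padding
out_padded = torch.empty(mat_a_padded.size(0), n)
# ... the SM100 kernel would write into out_padded here ...
out = out_padded[:m]                   # mirrors out_padded.slice(0, 0, original_rows)
assert out.shape == (m, n)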
sgl-kernel/include/utils.h
@@ -141,3 +141,23 @@ __device__ __forceinline__ float blockReduceMax(float max_value) {
   return max_value;
 }
 #endif
+
+// Pads to a multiple of `alignment` rows.
+inline torch::Tensor pad_tensor(const torch::Tensor& tensor, int64_t alignment = 4, bool is_column_major = false) {
+  int64_t rows = tensor.size(0);
+  int64_t cols = tensor.size(1);
+  int64_t pad_rows = (alignment - (rows % alignment)) % alignment;  // Compute padding size
+
+  if (pad_rows == 0) {
+    return tensor;  // Already aligned
+  }
+
+  torch::Tensor padding = torch::zeros({pad_rows, cols}, tensor.options());
+  torch::Tensor tensor_padded = torch::cat({tensor, padding}, 0);  // Pad along rows
+
+  // Ensure column-major layout
+  if (is_column_major) {
+    return tensor_padded.t().contiguous().t();
+  }
+  return tensor_padded;
+}
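The is_column_major branch re-lays the padded tensor out via t().contiguous().t(): the logical (rows, cols) shape is unchanged, but the storage becomes column-major, which is presumably what the column-major scale-factor layout on the SM100 path requires. A quick torch check of the resulting strides:

import torch

x = torch.zeros(6, 3)                  # row-major storage: strides (3, 1)
col_major = x.t().contiguous().t()     # same shape, column-major storage
print(x.stride(), col_major.stride())  # (3, 1) (1, 6)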
sgl-kernel/setup.py
@@ -122,7 +122,6 @@ nvcc_flags = [
     "-gencode=arch=compute_89,code=sm_89",
     "-gencode=arch=compute_90,code=sm_90",
     "-std=c++17",
-    "-use_fast_math",
     "-DFLASHINFER_ENABLE_F16",
     "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
     "-DCUTLASS_VERSIONS_GENERATED",
@@ -169,12 +168,16 @@ sources = [
 enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
 enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1"
 enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1"
+enable_sm100a = os.getenv("SGL_KERNEL_ENABLE_SM100A", "0") == "1"
 cuda_version = _get_cuda_version()
 sm_version = _get_device_sm()
 
 if torch.cuda.is_available():
     if cuda_version >= (12, 0) and sm_version >= 90:
         nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
+    if cuda_version >= (12, 8) and sm_version >= 100:
+        nvcc_flags.append("-gencode=arch=compute_100,code=sm_100")
+        nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
     if sm_version >= 90:
         nvcc_flags.extend(nvcc_flags_fp8)
     if sm_version >= 80:
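With a GPU present, the Blackwell code objects are emitted only when both the toolkit and the device are new enough (CUDA 12.8+ and SM 100+). A standalone Python sketch of the added gating, with an illustrative function name:

def blackwell_gencode_flags(cuda_version, sm_version):
    # Mirrors the new branch above: compute_100 and compute_100a require
    # CUDA >= 12.8 and a device of SM 100 or newer.
    flags = []
    if cuda_version >= (12, 8) and sm_version >= 100:
        flags.append("-gencode=arch=compute_100,code=sm_100")
        flags.append("-gencode=arch=compute_100a,code=sm_100a")
    return flags

print(blackwell_gencode_flags((12, 8), 100))  # both Blackwell flags
print(blackwell_gencode_flags((12, 4), 100))  # [] because the toolkit is too old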
@@ -183,6 +186,8 @@ else:
     # compilation environment without GPU
     if enable_sm90a:
         nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
+    if enable_sm100a:
+        nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
     if enable_fp8:
         nvcc_flags.extend(nvcc_flags_fp8)
     if enable_bf16:

In a build environment without a GPU, the new SGL_KERNEL_ENABLE_SM100A=1 environment variable opts in to the compute_100a target, mirroring the existing SGL_KERNEL_ENABLE_SM90A switch.