sglang / Commits / 7c866711

Commit 7c866711 (unverified), authored Mar 12, 2025 by Elfie Guo, committed via GitHub on Mar 12, 2025
Support Blackwell Block Scale FP8 Gemm (#4278)
Parent: 10b544ae
Showing 3 changed files, with 207 additions and 2 deletions:

sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu   +181  -1
sgl-kernel/include/utils.h                          +20   -0
sgl-kernel/setup.py                                 +6    -1
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
@@ -17,6 +17,8 @@
#include <cutlass/matrix_coord.h>
#include <cutlass/numeric_types.h>
#include <cutlass/tensor_ref.h>
#include <cutlass/util/host_tensor.h>
#include <cutlass/util/tensor_view_io.h>
#include <torch/all.h>
#include <cute/tensor.hpp>
@@ -154,6 +156,141 @@ void launch_sm90_fp8_blockwise_scaled_mm(
  TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status))
}

template <
    typename OutType,
    typename MmaTileShape,
    typename PerSmTileShape,
    typename EpilogueTileShape,
    typename ScalesPerTile,
    int TileSizeM_ = 128,
    class ClusterShape = Shape<_1, _1, _1>>
void launch_sm100_fp8_blockwise_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
  static constexpr int ScaleGranularityM = size<0>(MmaTileShape{}) / ScaleMsPerTile;
  static constexpr int ScaleGranularityN = size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
  static constexpr int ScaleGranularityK = size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});

  using ElementAB = cutlass::float_e4m3_t;
  using ElementA = ElementAB;
  using ElementB = ElementAB;
  using ElementC = void;
  using ElementD = OutType;

  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutD = cutlass::layout::RowMajor;
  using LayoutC = LayoutD;

  // This means both SFA and SFB are column-major.
  using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
      ScaleGranularityM,
      ScaleGranularityN,
      ScaleGranularityK,
      cute::UMMA::Major::MN,
      cute::UMMA::Major::K>;
  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());

  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
  static constexpr int AlignmentC = AlignmentD;

  using ElementAccumulator = float;
  using ElementBlockScale = float;
  using ElementCompute = float;
  using ArchTag = cutlass::arch::Sm100;
  using OperatorClass = cutlass::arch::OpClassTensorOp;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      cutlass::arch::OpClassTensorOp,
      PerSmTileShape,
      ClusterShape,
      EpilogueTileShape,
      ElementAccumulator,
      ElementCompute,
      ElementC,
      LayoutC,
      AlignmentC,
      ElementD,
      LayoutD,
      AlignmentD,
      cutlass::epilogue::TmaWarpSpecialized1Sm>::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementA,
      cute::tuple<LayoutA, LayoutSFA>,
      AlignmentA,
      ElementB,
      cute::tuple<LayoutB, LayoutSFB>,
      AlignmentB,
      ElementAccumulator,
      MmaTileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,
      CollectiveMainloop,
      CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  Gemm gemm_op;

  int m = a.size(0);
  int k = a.size(1);
  int n = b.size(1);

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
  auto scales_a_ptr = static_cast<float*>(scales_a.data_ptr());
  auto scales_b_ptr = static_cast<float*>(scales_b.data_ptr());
  auto c_ptr = static_cast<ElementD*>(out.data_ptr());

  using StrideA = typename GemmKernel::StrideA;
  using StrideB = typename GemmKernel::StrideB;
  using StrideD = typename GemmKernel::StrideD;
  using StrideC = typename GemmKernel::StrideD;

  StrideA a_stride = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  StrideB b_stride = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  StrideC c_stride = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
  LayoutSFA layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
  LayoutSFB layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));

  typename GemmKernel::MainloopArguments mainloop_args{
      a_ptr, a_stride, b_ptr, b_stride, scales_a_ptr, layout_SFA, scales_b_ptr, layout_SFB};
  typename GemmKernel::EpilogueArguments epilogue_args{{}, c_ptr, c_stride, c_ptr, c_stride};
  epilogue_args.thread.alpha = 1.0f;
  typename GemmKernel::Arguments args = {
      cutlass::gemm::GemmUniversalMode::kGemm, {m, n, k, 1}, mainloop_args, epilogue_args};

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess, cutlassGetStatusString(can_implement))

  size_t workspace_size = gemm_op.get_workspace_size(args);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
  auto init_status = gemm_op.initialize(args, workspace.get());
  TORCH_CHECK(init_status == cutlass::Status::kSuccess, cutlassGetStatusString(init_status));

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
  auto status = gemm_op.run(stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status))
}

template <typename OutType>
void sm90_fp8_blockwise_dispatch_shape(
    torch::Tensor& out,
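As an aside on the alignment constants above: AlignmentA/B/D express a 128-bit (16-byte) access requirement in element counts, so the FP8 inputs need 16-element alignment and a 16-bit output needs 8-element alignment. A small Python sketch of that arithmetic; alignment_elems is an illustrative helper, not part of this commit.

# Sketch only: reproduces the "128 / sizeof_bits" arithmetic used for AlignmentA/B/D above.
def alignment_elems(element_bits: int) -> int:
    # 128 is bits per aligned access; the result is an element count, not bytes.
    return 128 // element_bits

assert alignment_elems(8) == 16   # float_e4m3_t inputs (A, B): 16-element alignment
assert alignment_elems(16) == 8   # bfloat16_t / half_t output (D): 8-element alignment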
@@ -166,6 +303,30 @@ void sm90_fp8_blockwise_dispatch_shape(
  launch_sm90_fp8_blockwise_scaled_mm<OutType, TileShape, ClusterShape>(out, a, b, scales_a, scales_b);
}

template <typename OutType>
void sm100_fp8_blockwise_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  if (a.size(0) <= 128) {
    using MmaTileShape = Shape<_64, _128, _128>;
    using PerSmTileShape = Shape<_64, _128, _128>;
    using EpilogueTileShape = Shape<_64, _64>;
    using ScalesPerTile = Shape<_64, _1, _1>;
    launch_sm100_fp8_blockwise_scaled_mm<OutType, MmaTileShape, PerSmTileShape, EpilogueTileShape, ScalesPerTile>(
        out, a, b, scales_a, scales_b);
  } else {
    using MmaTileShape = Shape<_128, _128, _128>;
    using PerSmTileShape = Shape<_128, _128, _128>;
    using EpilogueTileShape = Shape<_128, _64>;
    using ScalesPerTile = Shape<_128, _1, _1>;
    launch_sm100_fp8_blockwise_scaled_mm<OutType, MmaTileShape, PerSmTileShape, EpilogueTileShape, ScalesPerTile>(
        out, a, b, scales_a, scales_b);
  }
}

torch::Tensor fp8_blockwise_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
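For both dispatch branches above, the ScaleGranularity{M,N,K} constants computed in launch_sm100_fp8_blockwise_scaled_mm work out the same: one scale per row of A within each 128-wide K block, and one scale per 128x128 block of B; the two tile shapes only change how many M scales a single MMA tile covers. A quick Python check of that arithmetic; the granularity helper is illustrative only, not part of the commit.

# Sketch only: replays the ScaleGranularity{M,N,K} arithmetic for the two branches above.
def granularity(mma_tile, scales_per_tile):
    m, n, k = mma_tile
    sm, sn, sk = scales_per_tile
    return (m // sm, n // sn, k // sk)

# a.size(0) <= 128 branch: MmaTileShape (64, 128, 128), ScalesPerTile (64, 1, 1)
assert granularity((64, 128, 128), (64, 1, 1)) == (1, 128, 128)
# a.size(0) > 128 branch: MmaTileShape (128, 128, 128), ScalesPerTile (128, 1, 1)
assert granularity((128, 128, 128), (128, 1, 1)) == (1, 128, 128)
# i.e. one scale per (1 x 128) slice of A and per (128 x 128) block of B in both cases.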
@@ -210,7 +371,7 @@ torch::Tensor fp8_blockwise_scaled_mm(
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
-  if (sm_version >= 90) {
+  if (sm_version == 90) {
    if (out_dtype == torch::kBFloat16) {
      sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b);
    } else {
@@ -221,6 +382,25 @@ torch::Tensor fp8_blockwise_scaled_mm(
#endif
#endif
#if defined(CUTLASS_ARCH_MMA_SM100A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
  if (sm_version == 100) {
    int64_t original_rows = mat_a.size(0);
    torch::Tensor mat_a_padded = pad_tensor(mat_a, /*alignment=*/4);
    torch::Tensor scales_a_padded = pad_tensor(scales_a, /*alignment=*/4, /*col_major=*/true);
    torch::Tensor out_padded = torch::empty({mat_a_padded.size(0), mat_b.size(1)}, out.options());
    if (out_dtype == torch::kBFloat16) {
      sm100_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b);
    } else {
      sm100_fp8_blockwise_dispatch_shape<cutlass::half_t>(out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b);
    }
    return out_padded.slice(0, 0, original_rows);
  }
#endif
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
}
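The SM100 branch above pads mat_a and scales_a to a row count that is a multiple of 4, runs the GEMM on the padded tensors, and slices the result back to the caller's row count. A minimal Python illustration of that flow, with an ordinary matmul standing in for the kernel; pad_rows is an illustrative stand-in for the pad_tensor helper added in utils.h below.

import torch

# Sketch only: mirrors the pad-then-slice flow of the SM100 branch; does not call the real kernel.
def pad_rows(t: torch.Tensor, alignment: int = 4) -> torch.Tensor:
    pad = (alignment - t.size(0) % alignment) % alignment
    if pad == 0:
        return t
    return torch.cat([t, torch.zeros(pad, t.size(1), dtype=t.dtype, device=t.device)], dim=0)

a = torch.randn(6, 256)              # 6 rows -> padded to 8
a_padded = pad_rows(a)
b = torch.randn(256, 512)
out_padded = a_padded @ b            # stand-in for sm100_fp8_blockwise_dispatch_shape
out = out_padded[:a.size(0)]         # equivalent of out_padded.slice(0, 0, original_rows)
assert out.shape == (6, 512)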
sgl-kernel/include/utils.h
@@ -141,3 +141,23 @@ __device__ __forceinline__ float blockReduceMax(float max_value) {
  return max_value;
}
#endif

// Pads to a multiple of `alignment` rows.
inline torch::Tensor pad_tensor(const torch::Tensor& tensor, int64_t alignment = 4, bool is_column_major = false) {
  int64_t rows = tensor.size(0);
  int64_t cols = tensor.size(1);
  int64_t pad_rows = (alignment - (rows % alignment)) % alignment;  // Compute padding size

  if (pad_rows == 0) {
    return tensor;  // Already aligned
  }

  torch::Tensor padding = torch::zeros({pad_rows, cols}, tensor.options());
  torch::Tensor tensor_padded = torch::cat({tensor, padding}, 0);  // Pad along rows

  // Ensure column-major layout
  if (is_column_major) {
    return tensor_padded.t().contiguous().t();
  }
  return tensor_padded;
}
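The is_column_major branch relies on the t().contiguous().t() idiom: the returned tensor keeps its logical shape but is stored column-major, in line with the "both SFA and SFB are column-major" note in the kernel. A short Python demonstration of what that idiom does to the strides; it is an illustration, not code from this commit.

import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)   # row-major: strides (4, 1)
col_major = x.t().contiguous().t()                         # same idiom as pad_tensor above

assert col_major.shape == (3, 4)
assert col_major.stride() == (1, 3)   # column-major (Fortran-order) strides
assert torch.equal(col_major, x)      # same values, different memory layout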
sgl-kernel/setup.py
@@ -122,7 +122,6 @@ nvcc_flags = [
    "-gencode=arch=compute_89,code=sm_89",
    "-gencode=arch=compute_90,code=sm_90",
    "-std=c++17",
    "-use_fast_math",
    "-DFLASHINFER_ENABLE_F16",
    "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
    "-DCUTLASS_VERSIONS_GENERATED",
@@ -169,12 +168,16 @@ sources = [
enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1"
enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1"
enable_sm100a = os.getenv("SGL_KERNEL_ENABLE_SM100A", "0") == "1"
cuda_version = _get_cuda_version()
sm_version = _get_device_sm()

if torch.cuda.is_available():
    if cuda_version >= (12, 0) and sm_version >= 90:
        nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
    if cuda_version >= (12, 8) and sm_version >= 100:
        nvcc_flags.append("-gencode=arch=compute_100,code=sm_100")
        nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
    if sm_version >= 90:
        nvcc_flags.extend(nvcc_flags_fp8)
    if sm_version >= 80:
@@ -183,6 +186,8 @@ else:
    # compilation environment without GPU
    if enable_sm90a:
        nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
    if enable_sm100a:
        nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
    if enable_fp8:
        nvcc_flags.extend(nvcc_flags_fp8)
    if enable_bf16:
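With these setup.py changes, a machine with a Blackwell GPU and CUDA 12.8 or newer picks up the SM100 gencode targets in addition to SM90a, and GPU-less build environments can opt in via SGL_KERNEL_ENABLE_SM100A=1. A Python sketch of the gencode selection in isolation; select_gencode_flags is an illustrative helper, not part of setup.py.

# Sketch only: extracts the arch-flag selection added above into a pure function.
def select_gencode_flags(cuda_version, sm_version):
    flags = []
    if cuda_version >= (12, 0) and sm_version >= 90:
        flags.append("-gencode=arch=compute_90a,code=sm_90a")
    if cuda_version >= (12, 8) and sm_version >= 100:
        flags.append("-gencode=arch=compute_100,code=sm_100")
        flags.append("-gencode=arch=compute_100a,code=sm_100a")
    return flags

# A Blackwell (SM100) GPU with CUDA 12.8 gets both the SM90a and SM100/SM100a targets.
assert select_gencode_flags((12, 8), 100) == [
    "-gencode=arch=compute_90a,code=sm_90a",
    "-gencode=arch=compute_100,code=sm_100",
    "-gencode=arch=compute_100a,code=sm_100a",
]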