change / sglang · Commits

Commit ad4e58bf (unverified)
Authored Mar 20, 2025 by Shu Wang; committed by GitHub, Mar 20, 2025

Support fp8 gemm for blackwell (#4558)

Parent: bfb03c61
Showing 1 changed file with 287 additions and 0 deletions.

sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu (+287, -0)
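What the new SM100 (Blackwell) path computes, read off the epilogue construction below (a summary, not text from the commit): the FP8 e4m3 operands are multiplied with FP32 accumulation, and the epilogue applies one dequantization scale per row of A and one per column of B, plus an optional per-column bias:

    D[i, j] = scales_a[i] * scales_b[j] * sum_k A[i, k] * B[k, j]   (+ bias[j] when given)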
@@ -792,6 +792,282 @@ void sm90_fp8_dispatch_shape(
}
#endif
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
template <
    typename ElementType,
    typename OutElementType,
    typename AccumElementType,
    typename CTAShape,
    typename ClusterShape,
    typename MainloopScheduleType,
    typename EpilogueScheduleType,
    typename TileSchedulerType = void,
    bool WithBias = false>
struct DeviceGemmFp8RowwiseSm100 {
  static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");

  using TileShape = CTAShape;

  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  using ElementComputeEpilogue = float;
  // Per-row (M) dequant scale for A, broadcast along columns.
  using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcast<
      0,
      TileShape,
      ElementComputeEpilogue,
      ElementComputeEpilogue,
      cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

  // Per-column (N) dequant scale for B, broadcast along rows.
  using ScaleB = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0,
      TileShape,
      ElementComputeEpilogue,
      ElementComputeEpilogue,
      cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  // Optional per-column (N) bias, also row-broadcast.
  using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0,
      TileShape,
      OutElementType,
      OutElementType,
      cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

  // EVTCompute0 := ScaleB * Accum
  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies,
      float,
      float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using LayoutA = cutlass::layout::RowMajor;
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementType>::value;

  using LayoutB = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementType>::value;

  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<OutElementType>::value;

  using LayoutD = cutlass::layout::RowMajor;
  static constexpr int AlignmentD = AlignmentC;

  using Compute1MulAdd = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add,
      OutElementType,
      float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using Compute1Mul = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies,
      OutElementType,
      float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  // EVTCompute := ScaleA * EVTCompute0 + Bias when WithBias, else ScaleA * EVTCompute0.
  using EVTCompute = typename std::conditional_t<
      WithBias,
      cutlass::epilogue::fusion::Sm90EVT<Compute1MulAdd, ScaleA, EVTCompute0, Bias>,
      cutlass::epilogue::fusion::Sm90EVT<Compute1Mul, ScaleA, EVTCompute0>>;

  using ArgumentType = typename EVTCompute::Arguments;

  // MMA type
  using ElementAccumulator = AccumElementType;

  // Epilogue types
  using ElementCompute = float;
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm100,
      cutlass::arch::OpClassTensorOp,
      TileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementCompute,
      ElementC,
      LayoutC,
      AlignmentC,
      OutElementType,
      LayoutD,
      AlignmentD,
      EpilogueScheduleType,
      EVTCompute>::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm100,
      cutlass::arch::OpClassTensorOp,
      ElementType,
      LayoutA,
      AlignmentA,
      ElementType,
      LayoutB,
      AlignmentB,
      ElementAccumulator,
      TileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      MainloopScheduleType>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::
      GemmUniversal<Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  template <typename Descriptor, typename T>
  static auto args_from_tensor(torch::Tensor const& tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
    static_assert(
        std::is_same_v<Descriptor, ScaleA> || std::is_same_v<Descriptor, ScaleB> ||
        std::is_same_v<Descriptor, Bias>);
    return Arguments{data_ptr};
  }

 public:
  static ArgumentType prepare_args(
      torch::Tensor const& a_scales,
      torch::Tensor const& b_scales,
      std::optional<torch::Tensor> const& bias = std::nullopt) {
    auto a_args = args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = args_from_tensor<ScaleB, float>(b_scales);

    typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
    if constexpr (WithBias) {
      auto bias_args = args_from_tensor<Bias, OutElementType>(bias.value());
      return ArgumentType{a_args, evt0_args, bias_args, {}};
    } else {
      return ArgumentType{a_args, evt0_args, {}};
    }
  }
};
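The EVT tree above collapses to one formula per output element. A minimal host-side reference of those semantics, handy for validating the kernel against a CPU result (this helper is illustrative and not part of the commit):

#include <cstddef>
#include <vector>

// Mirrors Sm90EVT<Compute1MulAdd, ScaleA, EVTCompute0, Bias> with
// EVTCompute0 = Sm90EVT<Compute0, ScaleB, Accum>:
//   D[i, j] = s_a[i] * (s_b[j] * acc[i, j]) + bias[j]
template <typename OutT>
void reference_epilogue(
    const std::vector<float>& acc,  // m*n FP32 accumulators, row-major
    const std::vector<float>& s_a,  // m per-row scales (ScaleA, column-broadcast)
    const std::vector<float>& s_b,  // n per-column scales (ScaleB, row-broadcast)
    const std::vector<OutT>* bias,  // optional n per-column bias, may be nullptr
    std::vector<OutT>& d,
    std::size_t m,
    std::size_t n) {
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      float v = s_a[i] * (s_b[j] * acc[i * n + j]);  // Compute0, then Compute1
      if (bias != nullptr) {
        v += static_cast<float>((*bias)[j]);  // Bias branch (WithBias = true)
      }
      d[i * n + j] = static_cast<OutT>(v);  // round-to-nearest on store
    }
  }
}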
template <typename GemmType, bool WithBias>
typename GemmType::Gemm::Arguments prepare_sm100_fp8_args(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  using Gemm = typename GemmType::Gemm;
  using ElementT = typename Gemm::ElementA;
  using ElementC = typename Gemm::ElementC;
  using ElementOutput = typename Gemm::ElementD;
  using ElementComputeEpilogue = float;
  using GemmKernel = typename Gemm::GemmKernel;

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = StrideC;
  using StrideAux = StrideC;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);

  ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
  ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());

  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  StrideC stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
  StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1));
  StrideAux aux_stride = stride_d;

  typename GemmKernel::MainloopArguments mainloop_args{ptr_a, stride_a, ptr_b, stride_b};
  typename GemmKernel::ProblemShape prob_shape = {m, n, k, 1};

  cutlass::KernelHardwareInfo hw_info;
  typename GemmKernel::TileSchedulerArguments scheduler = {};

  auto ptr_c = static_cast<ElementOutput*>(out.data_ptr());
  auto prepare_epilogue_args = [&](const c10::optional<torch::Tensor>& bias = c10::nullopt) {
    if constexpr (WithBias) {
      TORCH_CHECK(bias.has_value(), "Bias tensor is required but not provided.");
      return typename GemmKernel::EpilogueArguments{
          GemmType::prepare_args(scales_a, scales_b, bias.value()), ptr_c, stride_c, ptr_c, stride_d};
    } else {
      return typename GemmKernel::EpilogueArguments{
          GemmType::prepare_args(scales_a, scales_b), ptr_c, stride_c, ptr_c, stride_d};
    }
  };

  typename GemmKernel::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGemm,
      prob_shape,
      mainloop_args,
      prepare_epilogue_args(bias),
      hw_info,
      scheduler};

  return args;
}
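Two details worth noting here. First, prepare_epilogue_args passes ptr_c for both the C and D operands; since DeviceGemmFp8RowwiseSm100 sets ElementC to void, the source-C operand is unused and out is write-only. Second, for the shapes used above, make_cute_packed_stride produces packed strides of roughly this form (my reading of CuTe's helper; the exact mode types come from the kernel's Stride aliases):

    A, row-major [m, k]:     stride_a = (k, _1, m*k)
    B, column-major [k, n]:  stride_b = (k, _1, n*k)   -- built from shape (n, k, 1)
    C/D, row-major [m, n]:   stride_c = stride_d = (n, _1, m*n)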
template <typename Gemm, bool WithBias>
void launch_sm100_fp8_scaled_mm(
    torch::Tensor& out,
    torch::Tensor const& a,
    torch::Tensor const& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  auto args = prepare_sm100_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);

  typename Gemm::Gemm gemm_op;

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);
  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess);

  auto status = gemm_op.run(args, workspace.data_ptr(), stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess);
}
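The two TORCH_CHECK calls above pass no message, so a failure surfaces without context. A variant that reports the CUTLASS status string (cutlass::cutlassGetStatusString is part of the CUTLASS public API) might read:

  TORCH_CHECK(
      can_implement == cutlass::Status::kSuccess,
      "sm100 fp8 gemm cannot implement the given problem: ",
      cutlass::cutlassGetStatusString(can_implement));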
template <typename OutType>
void sm100_fp8_dispatch_bias(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  // A single CTA/cluster configuration is used for all shapes on SM100.
  using CTAShape = Shape<_256, _128, _64>;
  using ClusterShape = Shape<_2, _2, _1>;
  using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
  using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
  using TileSchedulerType = void;
  using ElementInput = cutlass::float_e4m3_t;
  using ElementOutput = OutType;
  using AccumElementType = float;

  // The two branches differ only in the WithBias flag.
  if (bias) {
    using Gemm = DeviceGemmFp8RowwiseSm100<
        ElementInput,
        ElementOutput,
        AccumElementType,
        CTAShape,
        ClusterShape,
        MainloopScheduleType,
        EpilogueScheduleType,
        TileSchedulerType,
        true>;
    return launch_sm100_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
  } else {
    using Gemm = DeviceGemmFp8RowwiseSm100<
        ElementInput,
        ElementOutput,
        AccumElementType,
        CTAShape,
        ClusterShape,
        MainloopScheduleType,
        EpilogueScheduleType,
        TileSchedulerType,
        false>;
    return launch_sm100_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
  }
}
template <typename OutType>
void sm100_fp8_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const c10::optional<torch::Tensor>& bias) {
  return sm100_fp8_dispatch_bias<OutType>(out, a, b, scales_a, scales_b, bias);
}
#endif

torch::Tensor fp8_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    ...
@@ -833,6 +1109,17 @@ torch::Tensor fp8_scaled_mm(
  auto sm_version = getSMVersion();

#if defined CUDA_VERSION && CUDA_VERSION >= 12080
  if (sm_version >= 100) {
    if (out_dtype == torch::kBFloat16) {
      sm100_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    } else {
      sm100_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
    }
    return out;
  }
#endif

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
  if (sm_version >= 90) {
    if (out_dtype == torch::kBFloat16) {
      ...
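For orientation, a hypothetical call into this path (the full fp8_scaled_mm signature is elided by this diff, so the argument order shown is assumed, not confirmed by the commit):

  // mat_a: [M, K] float8_e4m3, row-major; mat_b: [K, N] float8_e4m3, column-major
  // scales_a: M float32 values (one per row); scales_b: N float32 values (one per column)
  // torch::Tensor out = fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b,
  //                                   torch::kBFloat16, bias);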