Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
ad4e58bf
"examples/pytorch/ogb/line/utils.py" did not exist on "90d86fcbe3d87c6d26bc724db68bc891a3fa56cb"
Unverified
Commit
ad4e58bf
authored
Mar 20, 2025
by
Shu Wang
Committed by
GitHub
Mar 20, 2025
Browse files
Support fp8 gemm for blackwell (#4558)
parent
bfb03c61
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
287 additions
and
0 deletions
+287
-0
sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
+287
-0
No files found.
sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
View file @
ad4e58bf
...
...
@@ -792,6 +792,282 @@ void sm90_fp8_dispatch_shape(
}
#endif
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
template
<
typename
ElementType
,
typename
OutElementType
,
typename
AccumElementType
,
typename
CTAShape
,
typename
ClusterShape
,
typename
MainloopScheduleType
,
typename
EpilogueScheduleType
,
typename
TileSchedulerType
=
void
,
bool
WithBias
=
false
>
struct
DeviceGemmFp8RowwiseSm100
{
static_assert
(
std
::
is_same_v
<
ElementType
,
cutlass
::
float_e4m3_t
>
,
"ElementType must be FP8(e4m3)"
);
using
TileShape
=
CTAShape
;
using
Accum
=
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
;
using
ElementComputeEpilogue
=
float
;
using
ScaleA
=
cutlass
::
epilogue
::
fusion
::
Sm90ColBroadcast
<
0
,
TileShape
,
ElementComputeEpilogue
,
ElementComputeEpilogue
,
cute
::
Stride
<
cute
::
Int
<
1
>
,
cute
::
Int
<
0
>
,
cute
::
Int
<
0
>>>
;
using
ScaleB
=
cutlass
::
epilogue
::
fusion
::
Sm90RowBroadcast
<
0
,
TileShape
,
ElementComputeEpilogue
,
ElementComputeEpilogue
,
cute
::
Stride
<
cute
::
Int
<
0
>
,
cute
::
Int
<
1
>
,
cute
::
Int
<
0
>>>
;
using
Bias
=
cutlass
::
epilogue
::
fusion
::
Sm90RowBroadcast
<
0
,
TileShape
,
OutElementType
,
OutElementType
,
cute
::
Stride
<
cute
::
Int
<
0
>
,
cute
::
Int
<
1
>
,
cute
::
Int
<
0
>>>
;
using
Compute0
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
LayoutA
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
AlignmentA
=
128
/
cutlass
::
sizeof_bits
<
ElementType
>::
value
;
using
LayoutB
=
cutlass
::
layout
::
ColumnMajor
;
static
constexpr
int
AlignmentB
=
128
/
cutlass
::
sizeof_bits
<
ElementType
>::
value
;
using
ElementC
=
void
;
using
LayoutC
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
AlignmentC
=
128
/
cutlass
::
sizeof_bits
<
OutElementType
>::
value
;
using
LayoutD
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
AlignmentD
=
AlignmentC
;
using
Compute1MulAdd
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiply_add
,
OutElementType
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
Compute1Mul
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
OutElementType
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute
=
typename
std
::
conditional_t
<
WithBias
,
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
Compute1MulAdd
,
ScaleA
,
EVTCompute0
,
Bias
>
,
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
Compute1Mul
,
ScaleA
,
EVTCompute0
>>
;
using
ArgumentType
=
typename
EVTCompute
::
Arguments
;
// MMA type
using
ElementAccumulator
=
AccumElementType
;
// Epilogue types
using
ElementCompute
=
float
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm100
,
cutlass
::
arch
::
OpClassTensorOp
,
TileShape
,
ClusterShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementAccumulator
,
ElementCompute
,
ElementC
,
LayoutC
,
AlignmentC
,
OutElementType
,
LayoutD
,
AlignmentD
,
EpilogueScheduleType
,
EVTCompute
>::
CollectiveOp
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm100
,
cutlass
::
arch
::
OpClassTensorOp
,
ElementType
,
LayoutA
,
AlignmentA
,
ElementType
,
LayoutB
,
AlignmentB
,
ElementAccumulator
,
TileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
MainloopScheduleType
>::
CollectiveOp
;
using
GemmKernel
=
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
,
void
>
;
using
Gemm
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernel
>
;
template
<
typename
Descriptor
,
typename
T
>
static
auto
args_from_tensor
(
torch
::
Tensor
const
&
tensor
)
{
using
Arguments
=
typename
Descriptor
::
Arguments
;
auto
*
data_ptr
=
static_cast
<
T
*>
(
tensor
.
data_ptr
());
static_assert
(
std
::
is_same_v
<
Descriptor
,
ScaleA
>
||
std
::
is_same_v
<
Descriptor
,
ScaleB
>
||
std
::
is_same_v
<
Descriptor
,
Bias
>
);
return
Arguments
{
data_ptr
};
}
public:
static
ArgumentType
prepare_args
(
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
=
std
::
nullopt
)
{
auto
a_args
=
args_from_tensor
<
ScaleA
,
float
>
(
a_scales
);
auto
b_args
=
args_from_tensor
<
ScaleB
,
float
>
(
b_scales
);
typename
EVTCompute0
::
Arguments
evt0_args
{
b_args
,
{},
{}};
if
constexpr
(
WithBias
)
{
auto
bias_args
=
args_from_tensor
<
Bias
,
OutElementType
>
(
bias
.
value
());
return
ArgumentType
{
a_args
,
evt0_args
,
bias_args
,
{}};
}
else
{
return
ArgumentType
{
a_args
,
evt0_args
,
{}};
}
}
};
template
<
typename
GemmType
,
bool
WithBias
>
typename
GemmType
::
Gemm
::
Arguments
prepare_sm100_fp8_args
(
torch
::
Tensor
&
out
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
scales_a
,
const
torch
::
Tensor
&
scales_b
,
const
c10
::
optional
<
torch
::
Tensor
>&
bias
)
{
using
Gemm
=
typename
GemmType
::
Gemm
;
using
ElementT
=
typename
Gemm
::
ElementA
;
using
ElementC
=
typename
Gemm
::
ElementC
;
using
ElementOutput
=
typename
Gemm
::
ElementD
;
using
ElementComputeEpilogue
=
float
;
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
using
StrideA
=
typename
Gemm
::
GemmKernel
::
StrideA
;
using
StrideB
=
typename
Gemm
::
GemmKernel
::
StrideB
;
using
StrideC
=
typename
Gemm
::
GemmKernel
::
StrideC
;
using
StrideD
=
StrideC
;
using
StrideAux
=
StrideC
;
int32_t
m
=
a
.
size
(
0
);
int32_t
n
=
b
.
size
(
1
);
int32_t
k
=
a
.
size
(
1
);
ElementT
const
*
ptr_a
=
reinterpret_cast
<
ElementT
const
*>
(
a
.
data_ptr
());
ElementT
const
*
ptr_b
=
reinterpret_cast
<
ElementT
const
*>
(
b
.
data_ptr
());
StrideA
stride_a
=
cutlass
::
make_cute_packed_stride
(
StrideA
{},
cute
::
make_shape
(
m
,
k
,
1
));
StrideB
stride_b
=
cutlass
::
make_cute_packed_stride
(
StrideB
{},
cute
::
make_shape
(
n
,
k
,
1
));
StrideC
stride_c
=
cutlass
::
make_cute_packed_stride
(
StrideC
{},
cute
::
make_shape
(
m
,
n
,
1
));
StrideD
stride_d
=
cutlass
::
make_cute_packed_stride
(
StrideD
{},
cute
::
make_shape
(
m
,
n
,
1
));
StrideAux
aux_stride
=
stride_d
;
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
ptr_a
,
stride_a
,
ptr_b
,
stride_b
};
typename
GemmKernel
::
ProblemShape
prob_shape
=
{
m
,
n
,
k
,
1
};
cutlass
::
KernelHardwareInfo
hw_info
;
typename
GemmKernel
::
TileSchedulerArguments
scheduler
=
{};
auto
ptr_c
=
static_cast
<
ElementOutput
*>
(
out
.
data_ptr
());
auto
prepare_epilogue_args
=
[
&
](
const
c10
::
optional
<
torch
::
Tensor
>&
bias
=
c10
::
nullopt
)
{
if
constexpr
(
WithBias
)
{
TORCH_CHECK
(
bias
.
has_value
(),
"Bias tensor is required but not provided."
);
return
typename
GemmKernel
::
EpilogueArguments
{
GemmType
::
prepare_args
(
scales_a
,
scales_b
,
bias
.
value
()),
ptr_c
,
stride_c
,
ptr_c
,
stride_d
};
}
else
{
return
typename
GemmKernel
::
EpilogueArguments
{
GemmType
::
prepare_args
(
scales_a
,
scales_b
),
ptr_c
,
stride_c
,
ptr_c
,
stride_d
};
}
};
typename
GemmKernel
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGemm
,
prob_shape
,
mainloop_args
,
prepare_epilogue_args
(
bias
),
hw_info
,
scheduler
};
return
args
;
}
template
<
typename
Gemm
,
bool
WithBias
>
void
launch_sm100_fp8_scaled_mm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
const
torch
::
Tensor
&
scales_a
,
const
torch
::
Tensor
&
scales_b
,
const
c10
::
optional
<
torch
::
Tensor
>&
bias
)
{
auto
args
=
prepare_sm100_fp8_args
<
Gemm
,
WithBias
>
(
out
,
a
,
b
,
scales_a
,
scales_b
,
bias
);
typename
Gemm
::
Gemm
gemm_op
;
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
auto
can_implement
=
gemm_op
.
can_implement
(
args
);
TORCH_CHECK
(
can_implement
==
cutlass
::
Status
::
kSuccess
)
auto
status
=
gemm_op
.
run
(
args
,
workspace
.
data_ptr
(),
stream
);
TORCH_CHECK
(
status
==
cutlass
::
Status
::
kSuccess
)
}
template
<
typename
OutType
>
void
sm100_fp8_dispatch_bias
(
torch
::
Tensor
&
out
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
scales_a
,
const
torch
::
Tensor
&
scales_b
,
const
c10
::
optional
<
torch
::
Tensor
>&
bias
)
{
using
CTAShape
=
Shape
<
_256
,
_128
,
_64
>
;
using
ClusterShape
=
Shape
<
_2
,
_2
,
_1
>
;
using
MainloopScheduleType
=
cutlass
::
gemm
::
collective
::
KernelScheduleAuto
;
using
EpilogueScheduleType
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
TileSchedulerType
=
void
;
using
ElementInput
=
cutlass
::
float_e4m3_t
;
using
ElementOutput
=
OutType
;
using
AccumElementType
=
float
;
if
(
bias
)
{
using
Gemm
=
DeviceGemmFp8RowwiseSm100
<
ElementInput
,
ElementOutput
,
AccumElementType
,
CTAShape
,
ClusterShape
,
MainloopScheduleType
,
EpilogueScheduleType
,
TileSchedulerType
,
true
>
;
return
launch_sm100_fp8_scaled_mm
<
Gemm
,
true
>
(
out
,
a
,
b
,
scales_a
,
scales_b
,
bias
);
}
else
{
using
Gemm
=
DeviceGemmFp8RowwiseSm100
<
ElementInput
,
ElementOutput
,
AccumElementType
,
CTAShape
,
ClusterShape
,
MainloopScheduleType
,
EpilogueScheduleType
,
TileSchedulerType
,
false
>
;
return
launch_sm100_fp8_scaled_mm
<
Gemm
,
false
>
(
out
,
a
,
b
,
scales_a
,
scales_b
,
bias
);
}
}
template
<
typename
OutType
>
void
sm100_fp8_dispatch_shape
(
torch
::
Tensor
&
out
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b
,
const
torch
::
Tensor
&
scales_a
,
const
torch
::
Tensor
&
scales_b
,
const
c10
::
optional
<
torch
::
Tensor
>&
bias
)
{
return
sm100_fp8_dispatch_bias
<
OutType
>
(
out
,
a
,
b
,
scales_a
,
scales_b
,
bias
);
}
#endif
torch
::
Tensor
fp8_scaled_mm
(
const
torch
::
Tensor
&
mat_a
,
const
torch
::
Tensor
&
mat_b
,
...
...
@@ -833,6 +1109,17 @@ torch::Tensor fp8_scaled_mm(
auto
sm_version
=
getSMVersion
();
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
if
(
sm_version
>=
100
)
{
if
(
out_dtype
==
torch
::
kBFloat16
)
{
sm100_fp8_dispatch_shape
<
cutlass
::
bfloat16_t
>
(
out
,
mat_a
,
mat_b
,
scales_a
,
scales_b
,
bias
);
}
else
{
sm100_fp8_dispatch_shape
<
cutlass
::
half_t
>
(
out
,
mat_a
,
mat_b
,
scales_a
,
scales_b
,
bias
);
}
return
out
;
}
#endif
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
if
(
sm_version
>=
90
)
{
if
(
out_dtype
==
torch
::
kBFloat16
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment