"docs/source/vscode:/vscode.git/clone" did not exist on "8a07ab77376a99b7114d0850ff99331ed88a648e"
Unverified Commit 5bb0accb authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

cutlass 3.9 supported to improve fp8_blockwise_gemm (#5820)

parent 8d463fe3
...@@ -43,7 +43,7 @@ include(FetchContent) ...@@ -43,7 +43,7 @@ include(FetchContent)
FetchContent_Declare( FetchContent_Declare(
repo-cutlass repo-cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass GIT_REPOSITORY https://github.com/NVIDIA/cutlass
GIT_TAG 5e497243f7ad13a2aa842143f9b10bbb23d98292 GIT_TAG e94e888df3551224738bfa505787b515eae8352f
GIT_SHALLOW OFF GIT_SHALLOW OFF
) )
FetchContent_Populate(repo-cutlass) FetchContent_Populate(repo-cutlass)
......
...@@ -34,12 +34,7 @@ ...@@ -34,12 +34,7 @@
using namespace cute; using namespace cute;
template < template <typename SchedulerType, typename OutType, typename TileShape, typename ClusterShape>
typename SchedulerType,
typename OutType,
typename TileShape,
typename ClusterShape,
typename ScaleGranularity>
void launch_sm90_fp8_blockwise_scaled_mm( void launch_sm90_fp8_blockwise_scaled_mm(
torch::Tensor& out, torch::Tensor& out,
const torch::Tensor& a, const torch::Tensor& a,
...@@ -66,8 +61,10 @@ void launch_sm90_fp8_blockwise_scaled_mm( ...@@ -66,8 +61,10 @@ void launch_sm90_fp8_blockwise_scaled_mm(
using LayoutD = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor;
constexpr int AlignmentD = AlignmentC; constexpr int AlignmentD = AlignmentC;
static constexpr int ScaleGranularityM = size<0>(ScaleGranularity{}); using ScaleTileShape = Shape<_1, _128, _128>;
static constexpr int ScaleGranularityN = size<1>(ScaleGranularity{}); using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_config(ScaleTileShape{}));
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
using ArchTag = cutlass::arch::Sm90; using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp; using OperatorClass = cutlass::arch::OpClassTensorOp;
...@@ -75,8 +72,7 @@ void launch_sm90_fp8_blockwise_scaled_mm( ...@@ -75,8 +72,7 @@ void launch_sm90_fp8_blockwise_scaled_mm(
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>; using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
using KernelSchedule = using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM, ScaleGranularityN>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, ArchTag,
OperatorClass, OperatorClass,
...@@ -98,10 +94,10 @@ void launch_sm90_fp8_blockwise_scaled_mm( ...@@ -98,10 +94,10 @@ void launch_sm90_fp8_blockwise_scaled_mm(
ArchTag, ArchTag,
OperatorClass, OperatorClass,
ElementA, ElementA,
LayoutA, cute::tuple<LayoutA, LayoutSFA>,
AlignmentA, AlignmentA,
ElementB, ElementB,
LayoutB, cute::tuple<LayoutB, LayoutSFB>,
AlignmentB, AlignmentB,
ElementAccumulator, ElementAccumulator,
TileShape, TileShape,
...@@ -140,7 +136,11 @@ void launch_sm90_fp8_blockwise_scaled_mm( ...@@ -140,7 +136,11 @@ void launch_sm90_fp8_blockwise_scaled_mm(
StrideC stride_c; StrideC stride_c;
StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1)); StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1));
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, stride_a, b_ptr, stride_b, 4, a_s_ptr, b_s_ptr}; LayoutSFA layout_sfa = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
LayoutSFB layout_sfb = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
typename GemmKernel::MainloopArguments mainloop_args{
a_ptr, stride_a, b_ptr, stride_b, 4, a_s_ptr, layout_sfa, b_s_ptr, layout_sfb};
typename GemmKernel::EpilogueArguments epilogue_args{{}, nullptr, stride_d, o_ptr, stride_d}; typename GemmKernel::EpilogueArguments epilogue_args{{}, nullptr, stride_d, o_ptr, stride_d};
typename Gemm::Arguments args = { typename Gemm::Arguments args = {
...@@ -306,24 +306,15 @@ void sm90_fp8_blockwise_dispatch_shape( ...@@ -306,24 +306,15 @@ void sm90_fp8_blockwise_dispatch_shape(
const torch::Tensor& scales_b) { const torch::Tensor& scales_b) {
using TileShape = Shape<_128, _128, _128>; using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_1, _2, _1>; using ClusterShape = Shape<_1, _2, _1>;
using ScaleGranularity = Shape<_1, _128, _128>;
auto k = a.size(1); auto k = a.size(1);
auto n = b.size(1); auto n = b.size(1);
if (k > 3 * n) { if (k > 3 * n) {
launch_sm90_fp8_blockwise_scaled_mm< launch_sm90_fp8_blockwise_scaled_mm<cutlass::gemm::StreamKScheduler, OutType, TileShape, ClusterShape>(
cutlass::gemm::StreamKScheduler, out, a, b, scales_a, scales_b);
OutType,
TileShape,
ClusterShape,
ScaleGranularity>(out, a, b, scales_a, scales_b);
} else { } else {
launch_sm90_fp8_blockwise_scaled_mm< launch_sm90_fp8_blockwise_scaled_mm<cutlass::gemm::PersistentScheduler, OutType, TileShape, ClusterShape>(
cutlass::gemm::PersistentScheduler, out, a, b, scales_a, scales_b);
OutType,
TileShape,
ClusterShape,
ScaleGranularity>(out, a, b, scales_a, scales_b);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment