"examples/vscode:/vscode.git/clone" did not exist on "041f78ba1f07f49ade6a6e7d53a7fd4002586884"
Unverified commit 5bb0accb, authored by Xiaoyu Zhang, committed by GitHub
Browse files

cutlass 3.9 supported to improve fp8_blockwise_gemm (#5820)

parent 8d463fe3
......@@ -43,7 +43,7 @@ include(FetchContent)
FetchContent_Declare(
repo-cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
GIT_TAG 5e497243f7ad13a2aa842143f9b10bbb23d98292
GIT_TAG e94e888df3551224738bfa505787b515eae8352f
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-cutlass)
......
......@@ -34,12 +34,7 @@
using namespace cute;
template <
typename SchedulerType,
typename OutType,
typename TileShape,
typename ClusterShape,
typename ScaleGranularity>
template <typename SchedulerType, typename OutType, typename TileShape, typename ClusterShape>
void launch_sm90_fp8_blockwise_scaled_mm(
torch::Tensor& out,
const torch::Tensor& a,
......@@ -66,8 +61,10 @@ void launch_sm90_fp8_blockwise_scaled_mm(
using LayoutD = cutlass::layout::RowMajor;
constexpr int AlignmentD = AlignmentC;
static constexpr int ScaleGranularityM = size<0>(ScaleGranularity{});
static constexpr int ScaleGranularityN = size<1>(ScaleGranularity{});
using ScaleTileShape = Shape<_1, _128, _128>;
using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_config(ScaleTileShape{}));
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
......@@ -75,8 +72,7 @@ void launch_sm90_fp8_blockwise_scaled_mm(
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM, ScaleGranularityN>;
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
......@@ -98,10 +94,10 @@ void launch_sm90_fp8_blockwise_scaled_mm(
ArchTag,
OperatorClass,
ElementA,
LayoutA,
cute::tuple<LayoutA, LayoutSFA>,
AlignmentA,
ElementB,
LayoutB,
cute::tuple<LayoutB, LayoutSFB>,
AlignmentB,
ElementAccumulator,
TileShape,
......@@ -140,7 +136,11 @@ void launch_sm90_fp8_blockwise_scaled_mm(
StrideC stride_c;
StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1));
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, stride_a, b_ptr, stride_b, 4, a_s_ptr, b_s_ptr};
LayoutSFA layout_sfa = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
LayoutSFB layout_sfb = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
typename GemmKernel::MainloopArguments mainloop_args{
a_ptr, stride_a, b_ptr, stride_b, 4, a_s_ptr, layout_sfa, b_s_ptr, layout_sfb};
typename GemmKernel::EpilogueArguments epilogue_args{{}, nullptr, stride_d, o_ptr, stride_d};
typename Gemm::Arguments args = {
......@@ -306,24 +306,15 @@ void sm90_fp8_blockwise_dispatch_shape(
const torch::Tensor& scales_b) {
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_1, _2, _1>;
using ScaleGranularity = Shape<_1, _128, _128>;
auto k = a.size(1);
auto n = b.size(1);
if (k > 3 * n) {
launch_sm90_fp8_blockwise_scaled_mm<
cutlass::gemm::StreamKScheduler,
OutType,
TileShape,
ClusterShape,
ScaleGranularity>(out, a, b, scales_a, scales_b);
launch_sm90_fp8_blockwise_scaled_mm<cutlass::gemm::StreamKScheduler, OutType, TileShape, ClusterShape>(
out, a, b, scales_a, scales_b);
} else {
launch_sm90_fp8_blockwise_scaled_mm<
cutlass::gemm::PersistentScheduler,
OutType,
TileShape,
ClusterShape,
ScaleGranularity>(out, a, b, scales_a, scales_b);
launch_sm90_fp8_blockwise_scaled_mm<cutlass::gemm::PersistentScheduler, OutType, TileShape, ClusterShape>(
out, a, b, scales_a, scales_b);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment