"tests/vscode:/vscode.git/clone" did not exist on "2e19ba802571a211b094c1f330215b74479dbb49"
Unverified Commit 9b876889 authored by Qi Yuhang's avatar Qi Yuhang Committed by GitHub
Browse files

Update CUTLASS. Refine KernelSchedule for fp8 (grouped) gemm. (#10491)

parent c0c6f543
......@@ -46,7 +46,7 @@ include(FetchContent)
FetchContent_Declare(
repo-cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
GIT_TAG a49a78ffefc86a87160dfe0ccc3a3a2d1622c918
GIT_TAG 57e3cfb47a2d9e0d46eb6335c3dc411498efa198
GIT_SHALLOW OFF
)
FetchContent_Populate(repo-cutlass)
......
......@@ -72,7 +72,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90AccFetch>;
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8Blockwise;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
......
......@@ -463,7 +463,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
using MmaTileShape = Shape<_128, _32, _128>;
using ClusterShape = Shape<_2, _1, _1>;
// TODO: Check Pingpong or Cooperative
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using ScaleConfig =
cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
......@@ -475,7 +475,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
using ElementA = cutlass::float_e4m3_t;
using MmaTileShape = Shape<_64, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using ScaleConfig =
cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
......@@ -487,7 +487,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
using ElementA = cutlass::float_e4m3_t;
using MmaTileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_1, _2, _1>;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
using ScaleConfig =
cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment