#pragma once

#include <cute/atom/copy_atom.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/atom/mma_traits_sm75.hpp>
#include <cute/atom/mma_traits_sm80.hpp>
#include <cute/atom/mma_traits_sm90_gmma.hpp>
#include <cutlass/arch/barrier.h>
#include <cutlass/gemm/collective/collective_builder.hpp>

#include "common.h"

namespace cute {

using namespace SM90;

namespace tl_wgmma {

using namespace cutlass::gemm::collective::detail; // ss_smem_selector

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type_raw,
          typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
public:
  // float operands are issued to the tensor cores as tf32
  using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
                               tfloat32_t, A_type_raw>;
  using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
                               tfloat32_t, B_type_raw>;
  using C_type = C_type_raw;

  static constexpr GMMA::Major GmmaMajorA =
      trans_A ? GMMA::Major::MN : GMMA::Major::K;
  static constexpr GMMA::Major GmmaMajorB =
      trans_B ? GMMA::Major::K : GMMA::Major::MN;

  using SmemLayoutAtomA =
      decltype(ss_smem_selector<GmmaMajorA, A_type, Int<M>, Int<K>>());
  using SmemLayoutAtomB =
      decltype(ss_smem_selector<GmmaMajorB, B_type, Int<N>, Int<K>>());

  using SmemLayoutA = decltype(tile_to_shape(
      SmemLayoutAtomA{}, Shape<Int<M>, Int<K>>{},
      conditional_t<trans_A, Step<_2, _1>, Step<_1, _2>>{}));
  using SmemLayoutB = decltype(tile_to_shape(
      SmemLayoutAtomB{}, Shape<Int<N>, Int<K>>{},
      conditional_t<trans_B, Step<_1, _2>, Step<_2, _1>>{}));

  static_assert(num_warp_m % 4 == 0, "num_warp_m must be a multiple of 4");

  template <int wg_wait = 0>
  static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB,
                               C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});

    auto tiled_mma = make_tiled_mma(
        GMMA::ss_op_selector<
            A_type, B_type, C_type,
            Shape<Int<M / (num_warp_m / 4)>, Int<N / num_warp_n>, Int<K>>,
            GmmaMajorA, GmmaMajorB>(),
        Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
    auto thr_mma = tiled_mma.get_thread_slice(tid);

    // Allocate registers for pipelining
    Tensor tCsA = thr_mma.partition_A(sA);       // (MMA,MMA_M,MMA_K)
    Tensor tCsB = thr_mma.partition_B(sB);       // (MMA,MMA_N,MMA_K)
    Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K)
    Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));

    warpgroup_fence_operand(acc);
    warpgroup_arrive();
    if constexpr (clear_accum) {
      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    }
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      // warpgroup_arrive();
      // (V,M) x (V,N) => (V,M,N)
      gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
    warpgroup_commit_batch();
    if constexpr (wg_wait >= 0) {
      warpgroup_wait<wg_wait>();
    }
    warpgroup_fence_operand(acc);

    // Alternative: issue the whole k-loop as a single batched GMMA.
    // warpgroup_fence_operand(acc);
    // warpgroup_arrive();
    // gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc);
    // warpgroup_commit_batch();
    // if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
    // warpgroup_fence_operand(acc);
  }
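  // Note on the `wg_wait` parameter (illustrative, matching the code above):
  // the k-loop commits one GMMA batch; with wg_wait >= 0 the warpgroup then
  // blocks in warpgroup_wait<wg_wait>() until at most `wg_wait` batches remain
  // in flight, while wg_wait < 0 skips the wait entirely and the caller is
  // expected to wait later (e.g. via tl::wait_wgmma<N>()). For example:
  //
  //   GemmTensorOp<...>::body<0>(pA, pB, pC);   // issue and wait for all
  //   GemmTensorOp<...>::body<-1>(pA, pB, pC);  // issue only; wait elsewhere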
  template <int wg_wait = 0>
  static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    // TODO: Move bar.sync out of body_rs
    // asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(num_warp_m * num_warp_n *
    // 32));
    const int tid = threadIdx.x;
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});

    auto tiled_mma = make_tiled_mma(
        GMMA::rs_op_selector<
            A_type, B_type, C_type,
            Shape<Int<M / (num_warp_m / 4)>, Int<N / num_warp_n>, Int<K>>,
            GmmaMajorA, GmmaMajorB>(),
        Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{});
    auto thr_mma = tiled_mma.get_thread_slice(tid);

    // Allocate registers for pipelining
    Tensor tCsB = thr_mma.partition_B(sB);       // (MMA,MMA_N,MMA_K)
    Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K)
    Tensor tCrA =
        make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                    partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));

    warpgroup_fence_operand(tCrA);
    warpgroup_fence_operand(acc);
    warpgroup_arrive();
    if constexpr (clear_accum) {
      tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    }
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      // warpgroup_arrive();
      // (V,M) x (V,N) => (V,M,N)
      gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), acc);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
    warpgroup_commit_batch();
    if constexpr (wg_wait >= 0) {
      warpgroup_wait<wg_wait>();
    }
    warpgroup_fence_operand(acc);
    warpgroup_fence_operand(tCrA);
    // Alternative: issue the whole k-loop as a single batched GMMA.
    // warpgroup_fence_operand(acc);
    // warpgroup_arrive();
    // gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc);
    // warpgroup_commit_batch();
    // if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
    // warpgroup_fence_operand(acc);
  }
};

} // namespace tl_wgmma

namespace tl_mma {

// Maps an (A, B, C) element-type triple to an MMA atom and a tiling "group"
// that widens the per-warp tile covered by each atom.
template <typename A_type, typename B_type, typename C_type, int num_warp_m,
          int num_warp_n>
struct DispatchInstruction;

using _X = Underscore;

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n> {
  using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> {
  using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m,
                           num_warp_n> {
  using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m,
                           num_warp_n> {
  using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<int8_t, int8_t, int32_t, num_warp_m, num_warp_n> {
  using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n> {
  using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
  using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>;
};
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> {
  using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _16>;
};
#endif
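// Illustrative example only (assumes the dispatch table above): for fp16
// inputs with fp32 accumulation and a 2x2 warp grid,
//
//   using D = DispatchInstruction<half_t, half_t, float, 2, 2>;
//   // D::MMA       == MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>
//   // D::MMA_Group == Tile<_X, Int<32>, _X>  // each warp covers 16 columns
//
// so the TiledMMA built from it below spans a 32x32x16 tile per k-step.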
// Shared-memory layout and smem->register copy selection for one operand.
// Bits is the operand width in bits, N/K are the tile extents, K_inner picks a
// K-major (true) or MN-major (false) layout, and num_warp_n is the number of
// warps tiled along the operand's MN mode.
template <int Bits, int N, int K, bool K_inner, int num_warp_n,
          typename Enable = void>
struct OperandTraits {
  // Primary template, use padded layout and default copy
  static constexpr int stride = K_inner ? K : N;
  static constexpr int padded =
      stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride;
  using Layout = typename std::conditional<
      K_inner, Layout<Shape<Int<N>, Int<K>>, Stride<Int<padded>, _1>>,
      Layout<Shape<Int<N>, Int<K>>, Stride<_1, Int<padded>>>>::type;
  using Copy = DefaultCopy;
};

// 16-bit operands, K-major: 64B- and 128B-swizzled layouts, ldmatrix copies.
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
                     typename std::enable_if<K % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
  using Layout =
      decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy =
      typename std::conditional<N % (num_warp_n * 16) == 0, SM75_U32x4_LDSM_N,
                                SM75_U32x2_LDSM_N>::type;
};

template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, true, num_warp_n,
                     typename std::enable_if<K % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
  using Layout =
      decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy =
      typename std::conditional<N % (num_warp_n * 16) == 0, SM75_U32x4_LDSM_N,
                                SM75_U32x2_LDSM_N>::type;
};

// 16-bit operands, MN-major: transposed ldmatrix copies.
template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
                     typename std::enable_if<N % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy =
      typename std::conditional<N % (num_warp_n * 16) == 0, SM75_U16x8_LDSM_T,
                                SM75_U16x4_LDSM_T>::type;
};

template <int N, int K, int num_warp_n>
struct OperandTraits<16, N, K, false, num_warp_n,
                     typename std::enable_if<N % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy =
      typename std::conditional<N % (num_warp_n * 16) == 0, SM75_U16x8_LDSM_T,
                                SM75_U16x4_LDSM_T>::type;
};

// 32-bit operands (tf32), K-major.
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
                     typename std::enable_if<K % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
  using Layout =
      decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy =
      typename std::conditional<N % (num_warp_n * 16) == 0, SM75_U32x4_LDSM_N,
                                SM75_U32x2_LDSM_N>::type;
};

template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, true, num_warp_n,
                     typename std::enable_if<K % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
  using Layout =
      decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy =
      typename std::conditional<N % (num_warp_n * 16) == 0, SM75_U32x4_LDSM_N,
                                SM75_U32x2_LDSM_N>::type;
};

// 32-bit operands (tf32), MN-major: plain LDS copies (no transposed ldmatrix).
template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
                     typename std::enable_if<N % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = UniversalCopy<tfloat32_t>;
};

template <int N, int K, int num_warp_n>
struct OperandTraits<32, N, K, false, num_warp_n,
                     typename std::enable_if<N % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = UniversalCopy<tfloat32_t>;
};

// 8-bit operands, K-major only.
template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
                     typename std::enable_if<K % 128 == 64>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
  using Layout =
      decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy =
      typename std::conditional<N % (num_warp_n * 16) == 0, SM75_U32x4_LDSM_N,
                                SM75_U32x2_LDSM_N>::type;
};

template <int N, int K, int num_warp_n>
struct OperandTraits<8, N, K, true, num_warp_n,
                     typename std::enable_if<K % 128 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
  using Layout =
      decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy =
      typename std::conditional<N % (num_warp_n * 16) == 0, SM75_U32x4_LDSM_N,
                                SM75_U32x2_LDSM_N>::type;
};

// 64-bit operands (fp64).
template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, true, num_warp_n,
                     typename std::enable_if<K % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
  using Layout =
      decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = DefaultCopy;
};

template <int N, int K, int num_warp_n>
struct OperandTraits<64, N, K, false, num_warp_n,
                     typename std::enable_if<N % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = DefaultCopy;
};
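// Illustrative example only (assumes the specializations above): a 16-bit,
// K-major operand tile with N = 128, K = 64 and two warps along N picks the
// K % 64 == 0 case, roughly
//
//   using T = OperandTraits<16, 128, 64, true, 2>;
//   // T::Layout: Swizzle<3, 3, 3> atom (128B rows) tiled to Shape<_128, _64>
//   // T::Copy:   SM75_U32x4_LDSM_N (ldmatrix x4), since 128 % (2 * 16) == 0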
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type_raw,
          typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
public:
  // float operands are issued to the tensor cores as tf32
  using A_type =
      typename std::conditional<std::is_same<A_type_raw, float>::value,
                                tfloat32_t, A_type_raw>::type;
  using B_type =
      typename std::conditional<std::is_same<B_type_raw, float>::value,
                                tfloat32_t, B_type_raw>::type;
  using C_type = C_type_raw;

  using Instruction =
      DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n>;

  using OperandATraits =
      OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A, num_warp_m>;
  using OperandBTraits =
      OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B, num_warp_n>;
  using SmemLayoutA = typename OperandATraits::Layout;
  using SmemLayoutB = typename OperandBTraits::Layout;
  using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
  using SmemCopyB = Copy_Atom<typename OperandBTraits::Copy, B_type>;

  using TileMma =
      TiledMMA<typename Instruction::MMA,
               Layout<Shape<Int<num_warp_m>, Int<num_warp_n>, _1>>,
               typename Instruction::MMA_Group>;

  template <class LayoutShape, class LayoutStride>
  static CUTE_DEVICE auto
  remove_swizzle(Layout<LayoutShape, LayoutStride> const &layout) {
    return layout;
  }

  // In fp16, when the layout is KxN, n_warp is 1, and N % 64 == 0, the
  // original swizzled layout fails to compile; currently strip the swizzle as
  // a workaround.
  template <class SwizzleFn, class Offset, class LayoutB>
  static CUTE_DEVICE auto
  remove_swizzle(ComposedLayout<SwizzleFn, Offset, LayoutB> const &layout) {
    if constexpr (sizeof(A_type) == 2)
      return layout.layout_b();
    else
      return layout;
  }

  static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB,
                               C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);

    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsA = thr_copy_A.partition_S(sA);
    Tensor tCsB = thr_copy_B.partition_S(sB);
    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));

    // when the layout is KxN and n_warp is 1 there seems to be a bug; use
    // remove_swizzle as a workaround
    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));

    if constexpr (clear_accum) {
      clear(acc);
    }
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
      copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc);
    }
  }
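  // Assumed calling convention (sketch, not part of the original interface):
  // pA/pB point to shared-memory tiles written through SmemLayoutA/SmemLayoutB
  // (M*K and N*K elements), and pC points to this thread's accumulator
  // fragment, shaped like partition_shape_C(TileMma{}, Shape<Int<M>, Int<N>>{}).
  //
  //   __shared__ half_t smem_A[M * K], smem_B[N * K];
  //   float C_frag[/* size of the per-thread C fragment */];
  //   GemmTensorOp<M, N, K, 2, 2, false, false, true,
  //                half_t, half_t, float>::body(smem_A, smem_B, C_frag);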
  static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);

    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsB = thr_copy_B.partition_S(sB);
    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrA =
        make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                    partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));

    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
    if constexpr (clear_accum) {
      clear(acc);
    }
    // Prefetch the next k-slice of B from shared memory while the current
    // slice feeds the tensor cores.
    copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));
      }
      gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);
    }
  }

  static CUTE_DEVICE void body_sr(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);

    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCsA = thr_copy_A.partition_S(sA);
    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrB =
        make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
                    partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));

    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    if constexpr (clear_accum) {
      clear(acc);
    }
    // Prefetch the next k-slice of A from shared memory while the current
    // slice feeds the tensor cores.
    copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1));
      }
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc);
    }
  }
};

} // namespace tl_mma
} // namespace cute

namespace tl {
namespace tl_mma {

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using MMA =
      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, clear_accum, A_type, B_type, C_type>;
  MMA::body(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using MMA =
      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, clear_accum, A_type, B_type, C_type>;
  MMA::body_rs(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
  using MMA =
      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, clear_accum, A_type, B_type, C_type>;
  MMA::body_sr(pA, pB, accum);
}

} // namespace tl_mma

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, bool use_wgmma, typename A_type,
          typename B_type, typename C_type>
TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  if constexpr (use_wgmma) {
    using MMA =
        cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                     trans_B, clear_accum, A_type, B_type,
                                     C_type>;
    MMA::body(pA, pB, accum);
  } else {
    using MMA =
        cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                   trans_B, clear_accum, A_type, B_type,
                                   C_type>;
    MMA::body(pA, pB, accum);
  }
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, bool use_wgmma, typename A_type,
          typename B_type, typename C_type>
TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  if constexpr (use_wgmma) {
    using MMA =
        cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                     trans_B, clear_accum, A_type, B_type,
                                     C_type>;
    MMA::body_rs(pA, pB, accum);
  } else {
    using MMA =
        cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                   trans_B, clear_accum, A_type, B_type,
                                   C_type>;
    MMA::body_rs(pA, pB, accum);
  }
}
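// Illustrative usage (hypothetical call sites; buffer names are placeholders
// and the exact wrapper signature may differ between TileLang versions):
//
//   // 128x128x32 tile, 2x2 warps, classic MMA path (use_wgmma = false):
//   tl::gemm_ss<128, 128, 32, 2, 2, false, false, false, false>(
//       A_shared, B_shared, C_local);
//
//   // Hopper warpgroup path (use_wgmma = true), A already in registers:
//   tl::gemm_rs<128, 128, 64, 4, 1, false, false, false, true>(
//       A_local, B_shared, C_local);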
template <int num_mma>
TL_DEVICE void wait_wgmma() {
  cute::warpgroup_wait<num_mma>();
}

template <int NumMmaThreads>
TL_DEVICE void warp_scheduler_barrier_sync() {
  cutlass::arch::NamedBarrier::sync(NumMmaThreads,
                                    cutlass::canonical_warp_group_idx() /*id*/);
}

template <int NumMmaThreads>
TL_DEVICE void warp_scheduler_barrier_arrive() {
  static_assert(NumMmaThreads == 256 || NumMmaThreads == 384);
  if constexpr (NumMmaThreads == 256) {
    cutlass::arch::NamedBarrier::arrive(
        NumMmaThreads, (1 - cutlass::canonical_warp_group_idx()) /*id*/);
  } else {
    cutlass::arch::NamedBarrier::arrive(
        NumMmaThreads,
        (cutlass::canonical_warp_group_idx() <= 1
             ? cutlass::canonical_warp_group_idx() + 1
             : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
    cutlass::arch::NamedBarrier::arrive(
        NumMmaThreads,
        (cutlass::canonical_warp_group_idx() <= 0
             ? cutlass::canonical_warp_group_idx() + 2
             : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
  }
}

template <int NumMmaThreads>
TL_DEVICE void mma_init() {
  static_assert(NumMmaThreads == 256 || NumMmaThreads == 384);
  if (cutlass::canonical_warp_group_idx() > 0) {
    cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 0 /*id*/);
  }
  if constexpr (NumMmaThreads == 384) {
    if (cutlass::canonical_warp_group_idx() > 1) {
      cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 1 /*id*/);
    }
  }
}

} // namespace tl
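// Illustrative sketch (hypothetical kernel structure, not part of this header)
// of how the barriers above compose into a two-warpgroup ping-pong mainloop
// with NumMmaThreads == 256:
//
//   tl::mma_init<256>();                        // lets warpgroup 0 run first
//   for (int k_tile = 0; k_tile < num_k_tiles; ++k_tile) {
//     tl::warp_scheduler_barrier_sync<256>();   // wait for our turn
//     tl::gemm_ss<...>(A_shared, B_shared, C_local);
//     tl::warp_scheduler_barrier_arrive<256>(); // release the other warpgroup
//   }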