#pragma once

#include <type_traits>

#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>
#include <cute/arch/copy_sm75.hpp>
#include <cute/arch/mma_sm75.hpp>
#include <cute/arch/mma_sm80.hpp>
#include <cute/atom/copy_atom.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/tensor.hpp>

#include "common.h"
#include "cuda_fp8.h"

namespace cute {

template <typename A_type, typename B_type, typename C_type, int num_warp_m,
          int num_warp_n>
struct DispatchInstruction;

using _X = Underscore;

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 890))
// Each specialization pairs an MMA atom with the group tile used to widen the
// N (and, for fp64, M) extent covered by one tiled-MMA repetition.
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n> {
  using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> {
  using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m,
                           num_warp_n> {
  using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m,
                           num_warp_n> {
  using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n> {
  using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, num_warp_m, num_warp_n> {
  using MMA = MMA_Atom<SM89_16x8x32_F32E4M3E4M3F32_TN>;
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, num_warp_m, num_warp_n> {
  using MMA = MMA_Atom<SM89_16x8x32_F32E5M2E5M2F32_TN>;
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n> {
  using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
  using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>;
};
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> {
  using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _16>;
};
#endif

// Select the widest ldmatrix variant that divides the per-warp slab of the
// operand; fall back to a plain copy when ldmatrix cannot be used.
template <int N, int num_warp_n, bool transpose>
struct SelectCopy {
  static constexpr int remainder = (N / num_warp_n) % 16;
  using type = std::conditional_t<
      remainder == 4 || remainder == 8 || remainder == 0,
      std::conditional_t<
          transpose,
          std::conditional_t<
              remainder == 4, SM75_U32x1_LDSM_N,
              std::conditional_t<remainder == 8, SM75_U32x2_LDSM_N,
                                 SM75_U32x4_LDSM_N>>,
          std::conditional_t<
              remainder == 4, SM75_U16x2_LDSM_T,
              std::conditional_t<remainder == 8, SM75_U16x4_LDSM_T,
                                 SM75_U16x8_LDSM_T>>>,
      DefaultCopy>;
};
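// Illustrative sanity checks (hypothetical tile configurations, not part of
// the dispatch logic): with N = 128 and num_warp_n = 2 each warp owns a
// 64-column slab, so remainder == 0 and the four-fragment ldmatrix is chosen;
// with N = 64 and num_warp_n = 8 the slab is 8 columns wide and the
// two-fragment variant is selected instead.
static_assert(std::is_same<typename SelectCopy<128, 2, true>::type,
                           SM75_U32x4_LDSM_N>::value,
              "64-column warp slab selects the x4 ldmatrix");
static_assert(std::is_same<typename SelectCopy<64, 8, true>::type,
                           SM75_U32x2_LDSM_N>::value,
              "8-column warp slab selects the x2 ldmatrix");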
template <int Bits, int N, int K, bool K_inner, int num_warp_n,
          int leading_dim, typename Enable = void>
struct OperandTraits {
  // Primary template: pad the leading dimension when it would alias shared
  // memory banks, and use the default copy atom.
  static constexpr int stride = leading_dim;
  static constexpr int padded =
      stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride;
  using Layout = typename std::conditional<
      K_inner, Layout<Shape<Int<N>, Int<K>>, Stride<Int<padded>, _1>>,
      Layout<Shape<Int<N>, Int<K>>, Stride<_1, Int<padded>>>>::type;
  using Copy = DefaultCopy;
};

template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<16, N, K, true, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};

template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<16, N, K, true, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};

template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<16, N, K, false, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = typename SelectCopy<N, num_warp_n, false>::type;
};

template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<16, N, K, false, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = typename SelectCopy<N, num_warp_n, false>::type;
};

template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<32, N, K, true, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};

template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<32, N, K, true, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};

template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<32, N, K, false, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = UniversalCopy<uint32_t>;
};

template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<32, N, K, false, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = UniversalCopy<uint32_t>;
};

template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<8, N, K, true, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 128 == 64>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};

template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<8, N, K, true, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 128 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = typename SelectCopy<N, num_warp_n, true>::type;
};

template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<64, N, K, true, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = DefaultCopy;
};

template <int N, int K, int num_warp_n, int leading_dim>
struct OperandTraits<64, N, K, false, num_warp_n, leading_dim,
                     typename std::enable_if<leading_dim % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = DefaultCopy;
};
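// Worked example (hypothetical shapes, illustrative only): a 16-bit operand
// tile with K innermost, N = 128, K = 64, two warps along N, and a leading
// dimension of 64 matches the enable_if<leading_dim % 64 == 0> specialization
// above, so shared memory uses the Swizzle<3, 3, 3> atom tiled from (8, 64)
// up to (128, 64), and SelectCopy resolves to the x4 ldmatrix:
//
//   using ExampleTraits = OperandTraits<16, 128, 64, true, 2, 64>;
//   // ExampleTraits::Copy == SM75_U32x4_LDSM_N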
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
          int offset_b, typename A_type_raw, typename B_type_raw,
          typename C_type_raw>
class GemmTensorOp {
public:
  // float operands are issued to the tensor cores as tf32.
  using A_type =
      typename std::conditional<std::is_same<A_type_raw, float>::value,
                                tfloat32_t, A_type_raw>::type;
  using B_type =
      typename std::conditional<std::is_same<B_type_raw, float>::value,
                                tfloat32_t, B_type_raw>::type;
  using C_type = C_type_raw;

  using Instruction =
      DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n>;

  using OperandATraits = OperandTraits<sizeof_bits<A_type>::value, M, K,
                                       !trans_A, num_warp_m, lda>;
  using OperandBTraits = OperandTraits<sizeof_bits<B_type>::value, N, K,
                                       trans_B, num_warp_n, ldb>;
  using SmemLayoutA = typename OperandATraits::Layout;
  using SmemLayoutB = typename OperandBTraits::Layout;
  using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
  using SmemCopyB = Copy_Atom<typename OperandBTraits::Copy, B_type>;
  using TileMma =
      TiledMMA<typename Instruction::MMA,
               Layout<Shape<Int<num_warp_m>, Int<num_warp_n>, _1>>,
               typename Instruction::MMA_Group>;

  template <class... Args>
  static CUTE_DEVICE auto remove_swizzle(Layout<Args...> const &layout) {
    return layout;
  }

  // In fp16, when the layout is KxN, n_warp is 1, and N % 64 == 0, the
  // swizzled layout fails to compile; currently stripping the swizzle is
  // used as a workaround.
  template <class... Args>
  static CUTE_DEVICE auto
  remove_swizzle(ComposedLayout<Args...> const &layout) {
    if constexpr (sizeof(A_type) == 2)
      return layout.layout_b();
    else
      return layout;
  }

  // Slice the (possibly larger) shared-memory allocation down to the
  // (NN, KK) region this GEMM reads, honoring a static element offset.
  template <bool trans, int NN, int KK, int offset, class Engine, class Lay>
  static CUTE_DEVICE auto get_region_tensor(Tensor<Engine, Lay> &sa) {
    if constexpr (offset == 0) {
      return composition(
          sa, Layout<Shape<Int<NN>, Int<KK>>,
                     Stride<_1, typename std::conditional<trans, Int<NN>,
                                                          Int<KK>>::type>>{});
    } else {
      if constexpr (trans) {
        static_assert(offset % KK == 0, "Offset must be a multiple of K");
        constexpr int offset_n = offset / KK;
        return flat_divide(sa, Shape<Int<NN>, Int<KK>>{})(_, _, _0{},
                                                          Int<offset_n>{});
      } else {
        static_assert(offset % NN == 0, "Offset must be a multiple of N");
        constexpr int offset_n = offset / NN;
        return flat_divide(sa, Shape<Int<NN>, Int<KK>>{})(_, _, Int<offset_n>{},
                                                          _0{});
      }
    }
  }

  static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB,
                               C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA_all = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                                SmemLayoutA{});
    Tensor sB_all = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                                SmemLayoutB{});
    Tensor sA = get_region_tensor<trans_A, M, K, offset_a>(sA_all);
    Tensor sB = get_region_tensor<trans_B, N, K, offset_b>(sB_all);
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);
    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsA = thr_copy_A.partition_S(sA);
    Tensor tCsB = thr_copy_B.partition_S(sB);
    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    if constexpr (clear_accum) {
      clear(acc);
    }
    // When the layout is KxN and n_warp is 1, there seems to be a bug; strip
    // the swizzle from the fragment views as a workaround.
    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
      copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc);
    }
  }
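  // The register-shared variants below keep one operand resident in registers
  // and stream the other from shared memory, prefetching slice k + 1 before
  // issuing the MMA for slice k so that ldmatrix latency overlaps with
  // tensor-core work.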
  static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sB_all = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                                SmemLayoutB{});
    Tensor sB = get_region_tensor<trans_B, N, K, offset_b>(sB_all);
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);
    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsB = thr_copy_B.partition_S(sB);
    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrA =
        make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                    partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
    if constexpr (clear_accum) {
      clear(acc);
    }
    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
    copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));
      }
      gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);
    }
  }

  static CUTE_DEVICE void body_sr(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA_all = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                                SmemLayoutA{});
    Tensor sA = get_region_tensor<trans_A, M, K, offset_a>(sA_all);
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);
    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCsA = thr_copy_A.partition_S(sA);
    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrB =
        make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
                    partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
    if constexpr (clear_accum) {
      clear(acc);
    }
    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1));
      }
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc);
    }
  }
};

} // namespace cute

namespace tl {

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
          int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using MMA =
      cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B,
                         clear_accum, lda, ldb, offset_a, offset_b, A_type,
                         B_type, C_type>;
  MMA::body(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
          int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using MMA =
      cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B,
                         clear_accum, lda, ldb, offset_a, offset_b, A_type,
                         B_type, C_type>;
  MMA::body_rs(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
          int offset_b, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
  using MMA =
      cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B,
                         clear_accum, lda, ldb, offset_a, offset_b, A_type,
                         B_type, C_type>;
  MMA::body_sr(pA, pB, accum);
}

} // namespace tl
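// Usage sketch (hypothetical kernel, illustrative only; tile sizes, warp
// counts, leading dimensions, and the fragment size below are assumptions,
// not values prescribed by this header): issue one block-level 64x64x32 fp16
// GEMM from shared-memory tiles into a per-thread register accumulator.
//
//   __global__ void example_kernel(/* ... */) {
//     extern __shared__ uint8_t smem[];
//     cute::half_t *sA = reinterpret_cast<cute::half_t *>(smem);
//     cute::half_t *sB = sA + 64 * 32;
//     float acc[16] = {0.f}; // fragment size depends on the tiled MMA
//     // M=64, N=64, K=32, 2x2 warps, A with K innermost (lda = 32), B with K
//     // innermost (ldb = 32), no sub-tile offsets, clear the accumulator.
//     tl::gemm_ss<64, 64, 32, 2, 2, false, true, true, 32, 32, 0, 0>(
//         sA, sB, acc);
//   }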