#pragma once

#include "common.h"
#include "cuda_fp8.h"

#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>
#include <cute/atom/copy_atom.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/tensor.hpp>

namespace cute {

template <typename A_type, typename B_type, typename C_type, int num_warp_m,
          int num_warp_n, int N>
struct DispatchInstruction;

using _X = Underscore;

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 890))
struct SM89_16x8x32_F32F8F8F32_E4M3_TN {
  using DRegisters = float[4];
  using ARegisters = uint32_t[4];
  using BRegisters = uint32_t[2];
  using CRegisters = float[4];

  CUTE_HOST_DEVICE static void
  fma(float &d0, float &d1, float &d2, float &d3, uint32_t const &a0,
      uint32_t const &a1, uint32_t const &a2, uint32_t const &a3,
      uint32_t const &b0, uint32_t const &b1, float const &c0, float const &c1,
      float const &c2, float const &c3) {
    asm volatile("mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
                 "{%0, %1, %2, %3},"
                 "{%4, %5, %6, %7},"
                 "{%8, %9},"
                 "{%10, %11, %12, %13};\n"
                 : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
                 : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1),
                   "f"(c0), "f"(c1), "f"(c2), "f"(c3));
  }
};

struct SM89_16x8x32_F32F8F8F32_E5M2_TN {
  using DRegisters = float[4];
  using ARegisters = uint32_t[4];
  using BRegisters = uint32_t[2];
  using CRegisters = float[4];

  CUTE_HOST_DEVICE static void
  fma(float &d0, float &d1, float &d2, float &d3, uint32_t const &a0,
      uint32_t const &a1, uint32_t const &a2, uint32_t const &a3,
      uint32_t const &b0, uint32_t const &b1, float const &c0, float const &c1,
      float const &c2, float const &c3) {
    asm volatile("mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 "
                 "{%0, %1, %2, %3},"
                 "{%4, %5, %6, %7},"
                 "{%8, %9},"
                 "{%10, %11, %12, %13};\n"
                 : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
                 : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1),
                   "f"(c0), "f"(c1), "f"(c2), "f"(c3));
  }
};

// (T32,V1) -> (M8,N8)
using SM80_8x4 = Layout<Shape<Shape<_4, _8>, _1>, Stride<Stride<_8, _1>, _0>>;
// (T32,V2) -> (M8,N8)
using SM80_8x8_Row =
    Layout<Shape<Shape<_4, _8>, _2>, Stride<Stride<_16, _1>, _8>>;
// (T32,V4) -> (M8,N16)
using SM80_8x16_Row =
    Layout<Shape<Shape<_4, _8>, _4>, Stride<Stride<_32, _1>, _8>>;
// (T32,V4) -> (M16,N8)
using SM80_16x8_Row = Layout<Shape<Shape<_4, _8>, Shape<_2, _2>>,
                             Stride<Stride<_32, _1>, Stride<_16, _8>>>;

template <> struct MMA_Traits<SM89_16x8x32_F32F8F8F32_E4M3_TN> {
  using ValTypeD = float;
  using ValTypeA = fp8_e4_t;
  using ValTypeB = fp8_e4_t;
  using ValTypeC = float;

  using Shape_MNK = Shape<_16, _8, _32>;
  using ThrID = Layout<_32>;
  using ALayout = Layout<Shape<Shape<_4, _8>, Shape<_4, _2, _2>>,
                         Stride<Stride<_64, _1>, Stride<_16, _8, _256>>>;
  using BLayout = Layout<Shape<Shape<_4, _8>, Shape<_4, _2>>,
                         Stride<Stride<_32, _1>, Stride<_8, _128>>>;
  using CLayout = SM80_16x8_Row;
};

template <> struct MMA_Traits<SM89_16x8x32_F32F8F8F32_E5M2_TN> {
  using ValTypeD = float;
  using ValTypeA = fp8_e5_t;
  using ValTypeB = fp8_e5_t;
  using ValTypeC = float;

  using Shape_MNK = Shape<_16, _8, _32>;
  using ThrID = Layout<_32>;
  using ALayout = Layout<Shape<Shape<_4, _8>, Shape<_4, _2, _2>>,
                         Stride<Stride<_64, _1>, Stride<_16, _8, _256>>>;
  using BLayout = Layout<Shape<Shape<_4, _8>, Shape<_4, _2>>,
                         Stride<Stride<_32, _1>, Stride<_8, _128>>>;
  using CLayout = SM80_16x8_Row;
};

template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m,
                           num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m,
                           num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, num_warp_m, num_warp_n,
                           N> {
  using MMA = MMA_Atom<SM89_16x8x32_F32F8F8F32_E4M3_TN>;
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, num_warp_m, num_warp_n,
                           N> {
  using MMA = MMA_Atom<SM89_16x8x32_F32F8F8F32_E5M2_TN>;
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
  using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>;
};
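// Illustration (a sketch, not part of the dispatch machinery above; the warp
// counts and tile width are made-up values, and the group extents follow the
// 16-columns-per-warp convention assumed in the specializations above): for
// an fp16-input, fp32-accumulator tile computed by a 2x2 warp grid, the
// dispatcher resolves to the SM80 16x8x16 tensor-core atom, and MMA_Group
// widens each warp's slice of N so one warp owns a contiguous column group:
//
//   using Inst = DispatchInstruction<half_t, half_t, float,
//                                    /*num_warp_m=*/2, /*num_warp_n=*/2,
//                                    /*N=*/128>;
//   using Atom = Inst::MMA; // MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>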
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _16>;
};
#endif

template <int Bits, int N, int K, bool K_inner, typename Enable = void>
struct OperandTraits {
  // Primary template: use a padded layout and the default copy.
  static constexpr int stride = K_inner ? K : N;
  static constexpr int padded =
      stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride;
  using Layout = typename std::conditional<
      K_inner, Layout<Shape<Int<N>, Int<K>>, Shape<Int<padded>, _1>>,
      Layout<Shape<Int<N>, Int<K>>, Shape<_1, Int<padded>>>>::type;
  using Copy = DefaultCopy;
};

template <int N, int K>
struct OperandTraits<16, N, K, true,
                     typename std::enable_if<K % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
  using Layout =
      decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
struct OperandTraits<16, N, K, true,
                     typename std::enable_if<K % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
  using Layout =
      decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
struct OperandTraits<16, N, K, false,
                     typename std::enable_if<N % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = SM75_U16x8_LDSM_T;
};

template <int N, int K>
struct OperandTraits<16, N, K, false,
                     typename std::enable_if<N % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = SM75_U16x8_LDSM_T;
};

template <int N, int K>
struct OperandTraits<32, N, K, true,
                     typename std::enable_if<K % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
  using Layout =
      decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
struct OperandTraits<32, N, K, true,
                     typename std::enable_if<K % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
  using Layout =
      decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
struct OperandTraits<32, N, K, false,
                     typename std::enable_if<N % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = UniversalCopy<tfloat32_t>;
};

template <int N, int K>
struct OperandTraits<32, N, K, false,
                     typename std::enable_if<N % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = UniversalCopy<tfloat32_t>;
};

template <int N, int K>
struct OperandTraits<8, N, K, true,
                     typename std::enable_if<K % 128 == 64>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
  using Layout =
      decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
struct OperandTraits<8, N, K, true,
                     typename std::enable_if<K % 128 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
  using Layout =
      decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};
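// Worked example (illustrative; the tile sizes are arbitrary): an fp16 A tile
// with M = 128, K = 64 and K contiguous selects the Swizzle<3, 3, 3>
// specialization above. The 8x64 layout atom XOR-swizzles each 128-byte row
// across shared-memory banks so ldmatrix reads avoid bank conflicts, and
// tile_to_shape replicates the atom over the full tile:
//
//   using TraitsA = OperandTraits<16, 128, 64, true>; // 16-bit, K-inner
//   using SmemA   = TraitsA::Layout; // swizzled 128x64 smem layout
//   using CopyA   = TraitsA::Copy;   // SM75_U32x4_LDSM_N (ldmatrix.x4)
//
// Shapes that match no specialization fall back to the padded primary
// template, which trades bank-conflict-free ldmatrix for a simple stride pad.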
template <int N, int K>
struct OperandTraits<64, N, K, true,
                     typename std::enable_if<K % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
  using Layout =
      decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = DefaultCopy;
};

template <int N, int K>
struct OperandTraits<64, N, K, false,
                     typename std::enable_if<N % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = DefaultCopy;
};

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type_raw,
          typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
public:
  // float operands are routed through the tf32 tensor-core path.
  using A_type =
      typename std::conditional<std::is_same<A_type_raw, float>::value,
                                tfloat32_t, A_type_raw>::type;
  using B_type =
      typename std::conditional<std::is_same<B_type_raw, float>::value,
                                tfloat32_t, B_type_raw>::type;
  using C_type = C_type_raw;

  using Instruction =
      DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;

  using OperandATraits =
      OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A>;
  using OperandBTraits =
      OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B>;
  using SmemLayoutA = typename OperandATraits::Layout;
  using SmemLayoutB = typename OperandBTraits::Layout;
  using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
  using SmemCopyB = Copy_Atom<typename OperandBTraits::Copy, B_type>;

  using TileMma =
      TiledMMA<typename Instruction::MMA,
               Layout<Shape<Int<num_warp_m>, Int<num_warp_n>, _1>>,
               typename Instruction::MMA_Group>;

  template <class... Args>
  static CUTE_DEVICE auto remove_swizzle(Layout<Args...> const &layout) {
    return layout;
  }

  // In fp16, when the layout is KxN, n_warp is 1, and N % 64 == 0, the
  // original layout fails to compile; strip the swizzle as a workaround.
  template <class... Args>
  static CUTE_DEVICE auto
  remove_swizzle(ComposedLayout<Args...> const &layout) {
    if constexpr (sizeof(A_type) == 2)
      return layout.layout_b();
    else
      return layout;
  }

  static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB,
                               C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);

    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsA = thr_copy_A.partition_S(sA);
    Tensor tCsB = thr_copy_B.partition_S(sB);
    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    if constexpr (clear_accum) {
      clear(acc);
    }
    // When the layout is KxN and n_warp is 1, there seems to be a bug; use
    // remove_swizzle as a workaround.
    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));

    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
      copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc);
    }
  }
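  // body() above issues the smem-to-register copies for both operands inside
  // the k-loop; CUTE_UNROLL lets the compiler overlap the ldmatrix traffic of
  // iteration k+1 with the MMAs of iteration k. body_rs and body_sr below
  // handle the mixed cases where one operand already lives in registers (for
  // example, the output of a preceding computation): only the shared-memory
  // operand is staged, with an explicit one-slice prefetch ahead of each MMA.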
  static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);

    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsB = thr_copy_B.partition_S(sB);
    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrA =
        make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                    partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
    if constexpr (clear_accum) {
      clear(acc);
    }
    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
    // Prefetch the first B slice, then overlap each next copy with the MMA.
    copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));
      }
      gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);
    }
  }

  static CUTE_DEVICE void body_sr(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);

    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCsA = thr_copy_A.partition_S(sA);
    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrB =
        make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
                    partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
    if constexpr (clear_accum) {
      clear(acc);
    }
    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    // Prefetch the first A slice, then overlap each next copy with the MMA.
    copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1));
      }
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc);
    }
  }
};

} // namespace cute

namespace tl {

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, clear_accum, A_type, B_type, C_type>;
  MMA::body(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, clear_accum, A_type, B_type, C_type>;
  MMA::body_rs(pA, pB, accum);
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, clear_accum, A_type, B_type, C_type>;
  MMA::body_sr(pA, pB, accum);
}

} // namespace tl
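// Usage sketch (illustrative, not part of this header; the function name,
// tile sizes, and warp layout are hypothetical). Both operand tiles are
// assumed to be staged in shared memory, with the accumulator fragment held
// in registers by the caller:
//
//   __device__ void tile_gemm(cute::half_t *sA, cute::half_t *sB,
//                             float *acc_frag) {
//     // M=64, N=64, K=32, 2x2 warps (128 threads). trans_A/trans_B select
//     // which dimension of each operand is contiguous in shared memory;
//     // clear_accum=true zeroes the accumulator fragment before the k-loop.
//     // acc_frag must hold M * N / 128 = 32 floats per thread.
//     tl::gemm_ss<64, 64, 32, 2, 2, false, true, true>(sA, sB, acc_frag);
//   }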