// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once

#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <type_traits>

#include "common.h"

namespace cute {

template <typename A_type, typename B_type, typename C_type, int num_warp_m,
          int num_warp_n, int N>
struct DispatchInstruction;

using _X = Underscore;

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
  using MMA_Group = Tile<_X, Int<N / num_warp_n>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
  using MMA_Group = Tile<_X, Int<N / num_warp_n>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m,
                           num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
  using MMA_Group = Tile<_X, Int<N / num_warp_n>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m,
                           num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
  using MMA_Group = Tile<_X, Int<N / num_warp_n>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
  using MMA_Group = Tile<_X, Int<N / num_warp_n>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
  using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>;
};
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n, N> {
  using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
  using MMA_Group = Tile<_X, Int<N / num_warp_n>, _16>;
};
#endif
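// Illustrative sketch (not used by the code below; all values are
// assumptions): for fp16 inputs accumulating in fp32 on SM80, with a
// 128-wide N tile split across 4 warps, the dispatch above resolves to the
// 16x8x16 HMMA atom and gives each warp a 32-column repeat in N:
//
//   using Inst = DispatchInstruction<half_t, half_t, float,
//                                    /*num_warp_m=*/2, /*num_warp_n=*/4,
//                                    /*N=*/128>;
//   // Inst::MMA       is MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>
//   // Inst::MMA_Group is Tile<_X, Int<128 / 4>, _X>, i.e. Tile<_X, _32, _X>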
template <int Bits, int N, int K, bool K_inner, typename Enable = void>
struct OperandTraits {
  // Primary template: fall back to a padded layout and the default copy atom.
  static constexpr int stride = K_inner ? K : N;
  static constexpr int padded =
      stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride;
  using Layout = typename std::conditional<
      K_inner, Layout<Shape<Int<N>, Int<K>>, Shape<Int<padded>, _1>>,
      Layout<Shape<Int<N>, Int<K>>, Shape<_1, Int<padded>>>>::type;
  using Copy = DefaultCopy;
};

// 16-bit, K-major, K a multiple of 32 but not of 64: swizzled 64-byte rows.
template <int N, int K>
struct OperandTraits<16, N, K, true,
                     typename std::enable_if<K % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

// 16-bit, K-major, K a multiple of 64: swizzled 128-byte rows.
template <int N, int K>
struct OperandTraits<16, N, K, true,
                     typename std::enable_if<K % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

// 16-bit, MN-major variants: transposed ldmatrix loads.
template <int N, int K>
struct OperandTraits<16, N, K, false,
                     typename std::enable_if<N % 64 == 32>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 3, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = SM75_U16x8_LDSM_T;
};

template <int N, int K>
struct OperandTraits<16, N, K, false,
                     typename std::enable_if<N % 64 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 3, 3>{}, Layout<Shape<_64, _8>, Stride<_1, _64>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = SM75_U16x8_LDSM_T;
};

// 32-bit (tf32) variants.
template <int N, int K>
struct OperandTraits<32, N, K, true,
                     typename std::enable_if<K % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
struct OperandTraits<32, N, K, true,
                     typename std::enable_if<K % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, Layout<Shape<_8, _16>, Stride<_16, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

// No transposed ldmatrix exists for 32-bit elements; use a plain copy.
template <int N, int K>
struct OperandTraits<32, N, K, false,
                     typename std::enable_if<N % 32 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 2, 3>{}, Layout<Shape<_32, _8>, Stride<_1, _32>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = UniversalCopy<tfloat32_t>;
};

template <int N, int K>
struct OperandTraits<32, N, K, false,
                     typename std::enable_if<N % 32 == 16>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 3>{}, Layout<Shape<_16, _8>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = UniversalCopy<tfloat32_t>;
};

// 8-bit (int8) variants, K-major only.
template <int N, int K>
struct OperandTraits<8, N, K, true,
                     typename std::enable_if<K % 128 == 64>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 4, 3>{}, Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

template <int N, int K>
struct OperandTraits<8, N, K, true,
                     typename std::enable_if<K % 128 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<3, 4, 3>{}, Layout<Shape<_8, _128>, Stride<_128, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = SM75_U32x4_LDSM_N;
};

// 64-bit (fp64) variants.
template <int N, int K>
struct OperandTraits<64, N, K, true,
                     typename std::enable_if<K % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 0, 4>{}, Layout<Shape<_4, _16>, Stride<_16, _1>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{}));
  using Copy = DefaultCopy;
};

template <int N, int K>
struct OperandTraits<64, N, K, false,
                     typename std::enable_if<N % 16 == 0>::type> {
  using LayoutAtom = decltype(composition(
      Swizzle<2, 2, 2>{}, Layout<Shape<_16, _4>, Stride<_1, _16>>{}));
  using Layout = decltype(tile_to_shape(LayoutAtom{}, Shape<Int<N>, Int<K>>{},
                                        Step<_2, _1>{}));
  using Copy = DefaultCopy;
};
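// Illustrative sketch (tile sizes are assumptions): an fp16 A tile of
// M = 128, K = 64 stored K-major satisfies K % 64 == 0, so the traits above
// pick the Swizzle<3, 3, 3> atom over 8x64-element (128-byte) rows and load
// register fragments with ldmatrix:
//
//   using TraitsA = OperandTraits<16, /*N=*/128, /*K=*/64, /*K_inner=*/true>;
//   // TraitsA::Layout XOR-swizzles each 128-byte row so the ldmatrix reads
//   //                 below are free of shared-memory bank conflicts
//   // TraitsA::Copy   is SM75_U32x4_LDSM_N (the ldmatrix.x4 copy atom)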
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type_raw, typename B_type_raw,
          typename C_type_raw>
class GemmTensorOp {
public:
  // float operands are issued to the tensor cores as tf32.
  using A_type =
      typename std::conditional<std::is_same<A_type_raw, float>::value,
                                tfloat32_t, A_type_raw>::type;
  using B_type =
      typename std::conditional<std::is_same<B_type_raw, float>::value,
                                tfloat32_t, B_type_raw>::type;
  using C_type = C_type_raw;

  using Instruction =
      DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;

  using OperandATraits =
      OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A>;
  using OperandBTraits =
      OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B>;
  using SmemLayoutA = typename OperandATraits::Layout;
  using SmemLayoutB = typename OperandBTraits::Layout;
  using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
  using SmemCopyB = Copy_Atom<typename OperandBTraits::Copy, B_type>;

  using TileMma =
      TiledMMA<typename Instruction::MMA,
               Layout<Shape<Int<num_warp_m>, Int<num_warp_n>, _1>>,
               typename Instruction::MMA_Group>;

  template <class... Args>
  static CUTE_DEVICE auto remove_swizzle(Layout<Args...> const &layout) {
    return layout;
  }

  // In fp16, when the layout is KxN, n_warp is 1, and N % 64 == 0, the
  // original swizzled layout fails to compile; dropping to the underlying
  // non-swizzled layout is the current workaround.
  template <class... Args>
  static CUTE_DEVICE auto
  remove_swizzle(ComposedLayout<Args...> const &layout) {
    if constexpr (sizeof(A_type) == 2)
      return layout.layout_b();
    else
      return layout;
  }

  // A and B both staged in shared memory.
  static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB,
                               C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);

    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsA = thr_copy_A.partition_S(sA);
    Tensor tCsB = thr_copy_B.partition_S(sB);
    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));

    // When the layout is KxN and n_warp is 1 there seems to be a bug; use
    // remove_swizzle as a workaround.
    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));

    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
      copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc);
    }
  }

  // A already in registers, B staged in shared memory.
  static CUTE_DEVICE void body_rs(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
                            SmemLayoutB{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
    auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);

    Tensor tCrB = thr_mma.partition_fragment_B(sB);
    Tensor tCsB = thr_copy_B.partition_S(sB);
    Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrA =
        make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
                    partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
    auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));

    // Two-stage software pipeline: prefetch slice k+1 while issuing mma on k.
    copy(tiled_copy_B, tCsB(_, _, 0), tCrB_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));
      }
      gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);
    }
  }

  // A staged in shared memory, B already in registers.
  static CUTE_DEVICE void body_sr(A_type_raw *pA, B_type_raw *pB,
                                  C_type_raw *pC) {
    const int tid = threadIdx.x;
    Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
                            SmemLayoutA{});
    TileMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tid);
    auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
    auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);

    Tensor tCrA = thr_mma.partition_fragment_A(sA);
    Tensor tCsA = thr_copy_A.partition_S(sA);
    Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);

    Tensor acc =
        make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
                    partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
    Tensor tCrB =
        make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
                    partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
    auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));

    copy(tiled_copy_A, tCsA(_, _, 0), tCrA_copy_view(_, _, 0));
    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      if (k < size<2>(tCrA) - 1) {
        copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1));
      }
      gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc);
    }
  }
};

} // namespace cute

namespace tl {

// A in shared memory, B in shared memory.
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, A_type, B_type, C_type>;
  MMA::body(pA, pB, accum);
}

// A in registers, B in shared memory.
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, A_type, B_type, C_type>;
  MMA::body_rs(pA, pB, accum);
}

// A in shared memory, B in registers.
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, typename A_type, typename B_type, typename C_type>
CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
  using MMA = cute::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                 trans_B, A_type, B_type, C_type>;
  MMA::body_sr(pA, pB, accum);
}

} // namespace tl
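// Usage sketch (hypothetical; not part of this header). A kernel generated by
// the TL compiler stages operand tiles in shared memory and then invokes the
// wrapper matching its warp configuration; the tile sizes and buffer names
// below are illustrative only:
//
//   __global__ void example_kernel(/* ... */) {
//     __shared__ cute::half_t smem_A[128 * 32];  // 128x32 A tile, K-major
//     __shared__ cute::half_t smem_B[32 * 128];  // 32x128 B tile
//     float acc[32] = {0.0f};  // per-thread accumulator fragment (size assumed)
//     // ... copy global -> shared for the current k-block, __syncthreads() ...
//     tl::gemm_ss<128, 128, 32, /*num_warp_m=*/2, /*num_warp_n=*/2,
//                 /*trans_A=*/false, /*trans_B=*/false>(smem_A, smem_B, acc);
//   }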