// 2:4 structured-sparse tensor-op GEMM (SM80) built on the CUTLASS warp-level
// SparseMmaTensorOp. TL_DEVICE and the half_t / bfloat16_t aliases are
// expected to be provided by the TileLang common header included before this
// file.
#include <cutlass/arch/mma_sparse_sm80.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/warp/mma_sparse_tensor_op.h>
#include <cutlass/layout/tensor_op_multiplicand_sm75.h>
#include <cutlass/numeric_types.h>

namespace tl {

// 2:4 sparsity: two non-zero values in every group of four.
static int const kSparse = 2;

// Checks that the threadblock shape is compatible with the sparse tensor-op
// path for a given element type.
template <typename T, typename Shape> struct ShapeCheck {
  static constexpr bool value = false;
};

template <typename Shape> struct ShapeCheck<cutlass::half_t, Shape> {
  static constexpr bool value =
      (Shape::kM % 32 == 0) && (Shape::kN % 32 == 0) && (Shape::kK % 32 == 0);
};

template <typename Shape> struct ShapeCheck<cutlass::bfloat16_t, Shape> {
  static constexpr bool value =
      ShapeCheck<cutlass::half_t, Shape>::value; // Same as half
};

template <typename Shape> struct ShapeCheck<int8_t, Shape> {
  static constexpr bool value =
      (Shape::kM % 16 == 0) && (Shape::kN % 16 == 0) && (Shape::kK % 64 == 0);
};

template <typename Shape> struct ShapeCheck<uint8_t, Shape> {
  static constexpr bool value =
      (Shape::kM % 16 == 0) && (Shape::kN % 16 == 0) && (Shape::kK % 64 == 0);
};

// ref:
// https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h
template <typename T> struct DispatchInstructionShape {
  static_assert(!std::is_same_v<T, T>,
                "Unsupported type for DispatchInstructionShape");
};

template <> struct DispatchInstructionShape<cutlass::half_t> {
  using Shape = cutlass::gemm::GemmShape<16, 8, 32>;
  using Operator = cutlass::arch::OpMultiplyAdd;
};

template <> struct DispatchInstructionShape<cutlass::bfloat16_t> {
  using Shape = cutlass::gemm::GemmShape<16, 8, 32>;
  using Operator = cutlass::arch::OpMultiplyAdd;
};

// TODO: Not supported for now
// template <> struct DispatchInstructionShape<cutlass::tfloat32_t> {
//   using Shape = cutlass::gemm::GemmShape<16, 8, 16>;
//   using Operator = cutlass::arch::OpMultiplyAdd;
// };

template <> struct DispatchInstructionShape<int8_t> {
  using Shape = cutlass::gemm::GemmShape<16, 8, 64>;
  using Operator = cutlass::arch::OpMultiplyAddSaturate;
};

template <> struct DispatchInstructionShape<uint8_t> {
  using Shape = cutlass::gemm::GemmShape<16, 8, 64>;
  using Operator = cutlass::arch::OpMultiplyAddSaturate;
};

// TODO: Not supported for now
// template <> struct DispatchInstructionShape<cutlass::int4b_t> {
//   using Shape = cutlass::gemm::GemmShape<16, 8, 128>;
//   using Operator = cutlass::arch::OpMultiplyAddSaturate;
// };

// Shared-memory layout for the compressed A operand (M x K/kSparse).
template <typename T, int M, int K, bool trans_A>
struct DispatchSharedMemoryLayoutA;

// A is K-major (not transposed): crosswise layout over the compressed K.
template <typename T, int M, int K>
struct DispatchSharedMemoryLayoutA<T, M, K, false> {
  using SmemLayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<T>::value, K / kSparse>;
};

// A is M-major (transposed): congruous layout.
template <typename T, int M, int K>
struct DispatchSharedMemoryLayoutA<T, M, K, true> {
  static int const Crosswise_A =
      cutlass::platform::min(int(128 / sizeof(T)), M);
  using SmemLayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous<
      cutlass::sizeof_bits<T>::value, Crosswise_A>;
};

// Shared-memory layout for the B operand (K x N).
template <typename T, int N, int K, bool trans_B>
struct DispatchSharedMemoryLayoutB;

// B is N-major (not transposed): congruous layout.
template <typename T, int N, int K>
struct DispatchSharedMemoryLayoutB<T, N, K, false> {
  static_assert(
      cutlass::sizeof_bits<T>::value != 8,
      "int8, uint8, float8 only support column major layout for matrix B");
  static int const Crosswise_B =
      cutlass::platform::min(int(128 / sizeof(T)), N);
  using SmemLayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous<
      cutlass::sizeof_bits<T>::value, Crosswise_B>;
};

// B is K-major (transposed): crosswise layout.
template <typename T, int N, int K>
struct DispatchSharedMemoryLayoutB<T, N, K, true> {
  static int const kCrosswiseB =
      (K > (1024 / cutlass::sizeof_bits<T>::value))
          ? (1024 / cutlass::sizeof_bits<T>::value)
          : K;
  using SmemLayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<T>::value, kCrosswiseB>;
};

// Maps the element types used by the generated kernels to CUTLASS types.
template <typename T> struct DispatchType {
  static_assert(std::is_same<T, void>::value, "Unsupported dtype");
};

template <> struct DispatchType<half_t> {
  using Type = cutlass::half_t;
};
template <> struct DispatchType<bfloat16_t> {
  using Type = cutlass::bfloat16_t;
};
template <> struct DispatchType<uint8_t> {
  using Type = uint8_t;
};
template <> struct DispatchType<int8_t> {
  using Type = int8_t;
};

template <typename Shape, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type_raw,
          typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
public:
  static_assert(Shape::kM % num_warp_m == 0);
  static_assert(Shape::kN % num_warp_n == 0);

  using ElementA = typename DispatchType<A_type_raw>::Type;
  using ElementB = typename DispatchType<B_type_raw>::Type;
  using ElementC = C_type_raw;
  static_assert(std::is_same_v<ElementA, ElementB>,
                "A and B are not the same type");
  static_assert(ShapeCheck<ElementA, Shape>::value,
                "Invalid shape for ElementA");

  using LayoutA = typename std::conditional_t<
      trans_A, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>;
  using LayoutB = typename std::conditional_t<
      trans_B, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>;
  using LayoutC = cutlass::layout::RowMajor;

  using ThreadblockShape = Shape;
  using SmemLayoutA =
      typename DispatchSharedMemoryLayoutA<ElementA, Shape::kM, Shape::kK,
                                           trans_A>::SmemLayoutA;
  using SmemLayoutB =
      typename DispatchSharedMemoryLayoutB<ElementB, Shape::kN, Shape::kK,
                                           trans_B>::SmemLayoutB;

  using WarpShape = cutlass::gemm::GemmShape<
      Shape::kM / num_warp_m, Shape::kN / num_warp_n, Shape::kK>;
  using InstructionShape = typename DispatchInstructionShape<ElementA>::Shape;
  using Operator = typename DispatchInstructionShape<ElementA>::Operator;
  static_assert(WarpShape::kK % InstructionShape::kK == 0,
                "K dimension must be divisible by instruction shape K.");

  // instruction/warp config
  using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
      cutlass::arch::SparseMma<InstructionShape, 32, ElementA,
                               cutlass::layout::RowMajor, ElementB,
                               cutlass::layout::ColumnMajor, ElementC,
                               cutlass::layout::RowMajor, Operator>,
      cutlass::MatrixShape<1, 1>>;
  using MmaWarp = cutlass::gemm::warp::SparseMmaTensorOp<
      WarpShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, ElementC,
      LayoutC, Policy>;
  static_assert(kSparse == MmaWarp::kSparse, "not 2:4 structured sparse");

  using SmemLayoutE = typename MmaWarp::LayoutE;
  static_assert(std::is_same_v<SmemLayoutE, cutlass::layout::ColumnMajor>,
                "Meta data layout must be ColumnMajor for sparse mma.");

  // other traits
  using FragmentA = typename MmaWarp::FragmentA;
  using FragmentB = typename MmaWarp::FragmentB;
  using FragmentC = typename MmaWarp::FragmentC;
  using FragmentE = typename MmaWarp::FragmentE;
  using IteratorA = typename MmaWarp::IteratorA;
  using IteratorB = typename MmaWarp::IteratorB;
  using IteratorE = typename MmaWarp::IteratorE;
  using TensorRefA = typename IteratorA::TensorRef;
  using TensorRefB = typename IteratorB::TensorRef;
  using TensorRefE = typename IteratorE::TensorRef;
  using ElementE = typename TensorRefE::Element;
  static int const kElementsPerElementE = MmaWarp::kElementsPerElementE;

  // Threadblock-level shared-memory extents: A is stored compressed
  // (K / kSparse), E packs kElementsPerElementE sparse elements per entry.
  using ShapeA = cutlass::MatrixShape<Shape::kM, Shape::kK / kSparse>;
  using ShapeB = cutlass::MatrixShape<Shape::kK, Shape::kN>;
  using ShapeE = cutlass::MatrixShape<Shape::kM,
                                      Shape::kK / kSparse / kElementsPerElementE>;

  static int constexpr kKgroups = WarpShape::kK / InstructionShape::kK;

  template <typename E_type_raw>
  static CUTLASS_DEVICE void body(A_type_raw *pA, E_type_raw *pE,
                                  B_type_raw *pB, FragmentC &accum,
                                  const int warp_idx_m, const int warp_idx_n,
                                  const int lane_id) {
    MmaWarp mma_op;
    FragmentA frag_a;
    FragmentB frag_b;
    FragmentE frag_e;
    const TensorRefA ref_A(
        (ElementA *)pA,
        MmaWarp::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}));
    const TensorRefE ref_E(
        (ElementE *)pE,
        MmaWarp::LayoutE::packed({ShapeE::kRow, ShapeE::kColumn}));
    const TensorRefB ref_B(
        (ElementB *)pB,
        MmaWarp::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}));
    IteratorA iter_A(ref_A, lane_id);
    IteratorE iter_E(ref_E, lane_id);
    IteratorB iter_B(ref_B, lane_id);
    iter_A.add_tile_offset({warp_idx_m, 0});
    iter_E.add_tile_offset({warp_idx_m, 0});
    iter_B.add_tile_offset({0, warp_idx_n});
    if constexpr (clear_accum) {
      accum.clear();
    }
    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < kKgroups; ++k) {
      iter_A.load(frag_a);
      iter_E.load(frag_e);
      iter_B.load(frag_b);
      ++iter_A;
      ++iter_E;
      ++iter_B;
      mma_op(accum, frag_a, frag_b, accum, frag_e);
    }
  }
};

// Entry point: A, B and the 2:4 metadata E are staged in shared memory, the
// accumulator lives in registers as a per-thread FragmentC.
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
          bool trans_B, bool clear_accum, typename A_type, typename B_type,
          typename C_type, typename E_type>
TL_DEVICE void gemm_sp_ss(A_type *pA, B_type *pB, C_type *accum, E_type *pE) {
  using MMA = GemmTensorOp<cutlass::gemm::GemmShape<M, N, K>, num_warp_m,
                           num_warp_n, trans_A, trans_B, clear_accum, A_type,
                           B_type, C_type>;
  using FragmentC = typename MMA::FragmentC;
  int warp_id = threadIdx.x / 32;
  int lane_id = threadIdx.x % 32;
  MMA::body(pA, pE, pB, *(FragmentC *)(accum), warp_id % num_warp_m,
            warp_id / num_warp_m, lane_id);
}

} // namespace tl
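
// Usage sketch (illustrative only; the 64x64x64 tile, the 2x2 warp grid and
// the names smem_A, smem_B, smem_E, acc are assumptions, not part of the
// generated code path). A/B/E must be staged in shared memory with the
// SmemLayoutA/SmemLayoutB/SmemLayoutE layouts this header derives, and `acc`
// points at the per-thread accumulator fragment:
//
//   using MMA = tl::GemmTensorOp<cutlass::gemm::GemmShape<64, 64, 64>,
//                                /*num_warp_m=*/2, /*num_warp_n=*/2,
//                                /*trans_A=*/false, /*trans_B=*/false,
//                                /*clear_accum=*/true, half_t, half_t, float>;
//   // Inside a kernel launched with 2 * 2 * 32 threads per block:
//   tl::gemm_sp_ss<64, 64, 64, 2, 2, false, false, true>(smem_A, smem_B,
//                                                        acc, smem_E);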