#pragma once #include "../common.h" #include #include #ifndef __CUDACC_RTC__ #include #include #endif namespace tl { #ifndef TL_ALWAYS_FALSE_V_DEFINED #define TL_ALWAYS_FALSE_V_DEFINED template inline constexpr bool always_false_v = false; #endif namespace detail { template struct MmaImplTraits { using DReg = std::remove_extent_t; using AReg = std::remove_extent_t; using BReg = std::remove_extent_t; using CReg = std::remove_extent_t; static constexpr int kDRegs = std::extent_v; static constexpr int kARegs = std::extent_v; static constexpr int kBRegs = std::extent_v; static constexpr int kCRegs = std::extent_v; }; template TL_DEVICE void call_fma_impl(typename MmaImplTraits::DReg *d, const typename MmaImplTraits::AReg *a, const typename MmaImplTraits::BReg *b, const typename MmaImplTraits::CReg *c, std::index_sequence, std::index_sequence, std::index_sequence, std::index_sequence) { Impl::fma(d[DIdx]..., a[AIdx]..., b[BIdx]..., c[CIdx]...); } template TL_DEVICE void call_fma(typename MmaImplTraits::DReg *d, const typename MmaImplTraits::AReg *a, const typename MmaImplTraits::BReg *b, const typename MmaImplTraits::CReg *c) { call_fma_impl(d, a, b, c, std::make_index_sequence::kDRegs>{}, std::make_index_sequence::kARegs>{}, std::make_index_sequence::kBRegs>{}, std::make_index_sequence::kCRegs>{}); } template struct MmaDispatcher { using CRegType = void; using ARegType = void; using BRegType = void; static TL_DEVICE void exec(CRegType *, const ARegType *, const BRegType *, const CRegType *) { static_assert(always_false_v>, "tl::mma_sync: unsupported configuration"); } }; #define TL_DEFINE_MMA_DISPATCHER(ATypeEnum, BTypeEnum, CTypeEnum, MValue, \ NValue, KValue, TransAValue, TransBValue, \ SaturateValue, ImplType) \ template <> \ struct MmaDispatcher { \ using Impl = ImplType; \ using Traits = MmaImplTraits; \ using CRegType = typename Traits::DReg; \ using ARegType = typename Traits::AReg; \ using BRegType = typename Traits::BReg; \ static_assert( \ std::is_same_v, \ "tl::mma_sync requires matching accumulator/output regs"); \ static TL_DEVICE void exec(CRegType *d, const ARegType *a, \ const BRegType *b, const CRegType *c) { \ call_fma(d, a, b, c); \ } \ }; // FP16 inputs (TN layout: A row-major, B column-major) TL_DEFINE_MMA_DISPATCHER(kFloat16, kFloat16, kFloat16, 16, 8, 16, false, true, false, cute::SM80_16x8x16_F16F16F16F16_TN) TL_DEFINE_MMA_DISPATCHER(kFloat16, kFloat16, kFloat32, 16, 8, 16, false, true, false, cute::SM80_16x8x16_F32F16F16F32_TN) // BF16 inputs TL_DEFINE_MMA_DISPATCHER(kBFloat16, kBFloat16, kFloat32, 16, 8, 16, false, true, false, cute::SM80_16x8x16_F32BF16BF16F32_TN) // INT8 inputs (k32) TL_DEFINE_MMA_DISPATCHER(kInt8, kInt8, kInt32, 16, 8, 32, false, true, false, cute::SM80_16x8x32_S32S8S8S32_TN) TL_DEFINE_MMA_DISPATCHER(kUInt8, kUInt8, kInt32, 16, 8, 32, false, true, false, cute::SM80_16x8x32_S32U8U8S32_TN) // INT4 inputs (k32) TL_DEFINE_MMA_DISPATCHER(kInt4, kInt4, kInt32, 16, 8, 32, false, true, false, cute::SM80_16x8x32_S32S4S4S32_TN) TL_DEFINE_MMA_DISPATCHER(kUInt4, kUInt4, kInt32, 16, 8, 32, false, true, false, cute::SM80_16x8x32_S32U4U4S32_TN) // FP8 inputs (k32) TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 16, 8, 32, false, true, false, cute::SM89_16x8x32_F16E4M3E4M3F16_TN) TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 16, 8, 32, false, true, false, cute::SM89_16x8x32_F32E4M3E4M3F32_TN) TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 16, 8, 32, false, true, false, cute::SM89_16x8x32_F16E4M3E5M2F16_TN) TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 16, 8, 32, false, true, false, cute::SM89_16x8x32_F32E4M3E5M2F32_TN) TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 16, 8, 32, false, true, false, cute::SM89_16x8x32_F16E5M2E4M3F16_TN) TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 16, 8, 32, false, true, false, cute::SM89_16x8x32_F32E5M2E4M3F32_TN) TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 16, 8, 32, false, true, false, cute::SM89_16x8x32_F16E5M2E5M2F16_TN) TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 16, 8, 32, false, true, false, cute::SM89_16x8x32_F32E5M2E5M2F32_TN) // TF32 inputs (FP32 math on Tensor Cores) // Support both k=4 and k=8 variants on SM80 TL_DEFINE_MMA_DISPATCHER(kTensorFloat32, kTensorFloat32, kFloat32, 16, 8, 4, false, true, false, cute::SM80_16x8x4_F32TF32TF32F32_TN) TL_DEFINE_MMA_DISPATCHER(kTensorFloat32, kTensorFloat32, kFloat32, 16, 8, 8, false, true, false, cute::SM80_16x8x8_F32TF32TF32F32_TN) // FP64 inputs (DMMA: m8n8k4, TN layout) TL_DEFINE_MMA_DISPATCHER(kFloat64, kFloat64, kFloat64, 8, 8, 4, false, true, false, cute::SM80_8x8x4_F64F64F64F64_TN) #undef TL_DEFINE_MMA_DISPATCHER } // namespace detail template TL_DEVICE void mma_sync( typename detail::MmaDispatcher::CRegType *c, const typename detail::MmaDispatcher::ARegType *a, const typename detail::MmaDispatcher::BRegType *b) { using Dispatcher = detail::MmaDispatcher; static_assert(!std::is_void_v, "tl::mma_sync: unsupported configuration"); Dispatcher::exec(c, a, b, c); } } // namespace tl