#pragma once

#include "../common.h"
#include <cute/arch/mma_sm90_desc.hpp>
#include <cute/arch/mma_sm90_gmma.hpp>

#ifndef __CUDACC_RTC__
#include <type_traits>
#include <utility>
#endif

namespace tl {

#ifndef TL_ALWAYS_FALSE_V_DEFINED
#define TL_ALWAYS_FALSE_V_DEFINED
template <typename... Ts> inline constexpr bool always_false_v = false;
#endif

namespace detail {

// Maps the boolean "MN-major" flag of an operand onto the GMMA major mode.
template <bool IsMnMajor> struct MajorValue {
  static constexpr auto value =
      IsMnMajor ? cute::SM90::GMMA::Major::MN : cute::SM90::GMMA::Major::K;
};

// Maps an integer input scale of +1 / -1 onto the GMMA ScaleIn enumerators.
template <int Scale> struct ScaleInValue {
  static_assert(Scale == 1 || Scale == -1,
                "tl::wgmma requires scale factors of +1 or -1.");
  static constexpr auto value = Scale == 1 ? cute::SM90::GMMA::ScaleIn::One
                                           : cute::SM90::GMMA::ScaleIn::Neg;
};

template <int Scale>
inline constexpr bool IsValidScale = (Scale == 1 || Scale == -1);

// Invokes an SS (shared x shared) GMMA atom, expanding the raw accumulator
// buffer into the atom's individual register arguments.
template <typename Impl> struct CallWgmmaSS {
  using CReg = std::remove_extent_t<typename Impl::CRegisters>;
  static constexpr int kCRegs = std::extent_v<typename Impl::CRegisters>;
  static_assert(sizeof(CReg) == sizeof(uint32_t),
                "tl::wgmma_ss expects 32-bit accumulator registers.");

  template <size_t... Idx>
  TL_DEVICE static void Run(uint64_t desc_a, uint64_t desc_b, CReg *c,
                            cute::SM90::GMMA::ScaleOut scale,
                            std::index_sequence<Idx...>) {
    Impl::fma(desc_a, desc_b, c[Idx]..., scale);
  }

  TL_DEVICE static void exec(uint64_t desc_a, uint64_t desc_b, uint32_t *c_raw,
                             bool scale_out) {
    auto scale = scale_out ? cute::SM90::GMMA::ScaleOut::One
                           : cute::SM90::GMMA::ScaleOut::Zero;
    auto c = reinterpret_cast<CReg *>(c_raw);
    Run(desc_a, desc_b, c, scale, std::make_index_sequence<kCRegs>{});
  }
};

// Invokes an RS (register x shared) GMMA atom, expanding both the packed A
// operand registers and the accumulator registers.
template <typename Impl> struct CallWgmmaRS {
  using AReg = std::remove_extent_t<typename Impl::ARegisters>;
  using CReg = std::remove_extent_t<typename Impl::CRegisters>;
  static constexpr int kARegs = std::extent_v<typename Impl::ARegisters>;
  static constexpr int kCRegs = std::extent_v<typename Impl::CRegisters>;
  static_assert(sizeof(AReg) == sizeof(uint32_t),
                "tl::wgmma_rs expects 32-bit register operands for A.");
  static_assert(sizeof(CReg) == sizeof(uint32_t) ||
                    sizeof(CReg) == sizeof(float),
                "tl::wgmma_rs expects 32-bit accumulator registers.");

  template <size_t... AIdx, size_t... CIdx>
  TL_DEVICE static void Run(const AReg *a, uint64_t desc_b, CReg *c,
                            cute::SM90::GMMA::ScaleOut scale,
                            std::index_sequence<AIdx...>,
                            std::index_sequence<CIdx...>) {
    Impl::fma(a[AIdx]..., desc_b, c[CIdx]..., scale);
  }

  TL_DEVICE static void exec(const uint32_t *a_raw, uint64_t desc_b,
                             uint32_t *c_raw, bool scale_out) {
    auto scale = scale_out ? cute::SM90::GMMA::ScaleOut::One
                           : cute::SM90::GMMA::ScaleOut::Zero;
    auto a = reinterpret_cast<const AReg *>(a_raw);
    auto c = reinterpret_cast<CReg *>(c_raw);
    Run(a, desc_b, c, scale, std::make_index_sequence<kARegs>{},
        std::make_index_sequence<kCRegs>{});
  }
};

} // namespace detail
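// Dispatch scheme: the primary WgmmaSSImpl / WgmmaRSImpl templates below fail
// with "unsupported configuration" as soon as execute() is instantiated, and
// the TL_WGMMA_DEFINE_* macros generate one partial specialization per
// supported (A, B, C, M, N, K) shape that forwards to a CUTE SM90 GMMA atom.
// For example, a 64x64x16 f16*f16+f32 SS tile with K-major operands and +1
// input scaling resolves to
//
//   detail::CallWgmmaSS<cute::SM90::GMMA::MMA_64x64x16_F32F16F16_SS<
//       cute::SM90::GMMA::Major::K, cute::SM90::GMMA::Major::K,
//       cute::SM90::GMMA::ScaleIn::One, cute::SM90::GMMA::ScaleIn::One>>
//
// via the specialization produced by TL_WGMMA_DEFINE_F16_F16_F32_SS(64).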
// Primary templates: any type/shape combination without an explicit
// specialization below is rejected when execute() is instantiated.
template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
          int K, bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl {
  static_assert(detail::IsValidScale<scaleA>, "tl::wgmma_ss: invalid scaleA");
  static_assert(detail::IsValidScale<scaleB>, "tl::wgmma_ss: invalid scaleB");
  TL_DEVICE static void execute(uint64_t, uint64_t, uint32_t *, bool) {
    static_assert(always_false_v<std::integral_constant<int, M>>,
                  "tl::wgmma_ss: unsupported configuration");
  }
};

template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
          int K, bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaRSImpl {
  static_assert(detail::IsValidScale<scaleA>, "tl::wgmma_rs: invalid scaleA");
  static_assert(detail::IsValidScale<scaleB>, "tl::wgmma_rs: invalid scaleB");
  TL_DEVICE static void execute(const uint32_t *, uint64_t, uint32_t *, bool) {
    static_assert(always_false_v<std::integral_constant<int, M>>,
                  "tl::wgmma_rs: unsupported configuration");
  }
};

// Specialization generators. The *_GENERAL flavor forwards both operand
// majors and both input scales to the CUTE atom; *_TN covers atoms that are
// only defined for K-major (TN) operands; *_TN_FIXED_SCALE additionally
// requires +1 input scaling, since the integer WGMMA shapes have no ScaleIn
// operand.
#define TL_WGMMA_DEFINE_SS_GENERAL(AType, BType, CType, M, N, K, ImplName) \
  template <bool tnspA, bool tnspB, int scaleA, int scaleB> \
  struct WgmmaSSImpl<AType, BType, CType, M, N, K, tnspA, tnspB, scaleA, \
                     scaleB> { \
    static_assert(detail::IsValidScale<scaleA>, \
                  "tl::wgmma_ss: invalid scaleA"); \
    static_assert(detail::IsValidScale<scaleB>, \
                  "tl::wgmma_ss: invalid scaleB"); \
    using Impl = \
        cute::SM90::GMMA::ImplName<detail::MajorValue<tnspA>::value, \
                                   detail::MajorValue<tnspB>::value, \
                                   detail::ScaleInValue<scaleA>::value, \
                                   detail::ScaleInValue<scaleB>::value>; \
    TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \
                                  uint32_t *c, bool scale_out) { \
      detail::CallWgmmaSS<Impl>::exec(desc_a, desc_b, c, scale_out); \
    } \
  };

#define TL_WGMMA_DEFINE_SS_TN(AType, BType, CType, M, N, K, ImplName) \
  template <int scaleA, int scaleB> \
  struct WgmmaSSImpl<AType, BType, CType, M, N, K, false, false, scaleA, \
                     scaleB> { \
    static_assert(detail::IsValidScale<scaleA>, \
                  "tl::wgmma_ss: invalid scaleA"); \
    static_assert(detail::IsValidScale<scaleB>, \
                  "tl::wgmma_ss: invalid scaleB"); \
    using Impl = \
        cute::SM90::GMMA::ImplName<detail::ScaleInValue<scaleA>::value, \
                                   detail::ScaleInValue<scaleB>::value>; \
    TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \
                                  uint32_t *c, bool scale_out) { \
      detail::CallWgmmaSS<Impl>::exec(desc_a, desc_b, c, scale_out); \
    } \
  };

#define TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(AType, BType, CType, M, N, K, \
                                          ImplName) \
  template <int scaleA, int scaleB> \
  struct WgmmaSSImpl<AType, BType, CType, M, N, K, false, false, scaleA, \
                     scaleB> { \
    static_assert(detail::IsValidScale<scaleA>, \
                  "tl::wgmma_ss: invalid scaleA"); \
    static_assert(detail::IsValidScale<scaleB>, \
                  "tl::wgmma_ss: invalid scaleB"); \
    static_assert(scaleA == 1 && scaleB == 1, \
                  "tl::wgmma_ss: only +1 scaling supported for this WGMMA"); \
    using Impl = cute::SM90::GMMA::ImplName; \
    TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \
                                  uint32_t *c, bool scale_out) { \
      detail::CallWgmmaSS<Impl>::exec(desc_a, desc_b, c, scale_out); \
    } \
  };

#define TL_WGMMA_DEFINE_RS_GENERAL(AType, BType, CType, M, N, K, ImplName) \
  template <bool tnspA, bool tnspB, int scaleA, int scaleB> \
  struct WgmmaRSImpl<AType, BType, CType, M, N, K, tnspA, tnspB, scaleA, \
                     scaleB> { \
    static_assert(!tnspA, "tl::wgmma_rs: operand A must be K-major"); \
    static_assert(detail::IsValidScale<scaleA>, \
                  "tl::wgmma_rs: invalid scaleA"); \
    static_assert(detail::IsValidScale<scaleB>, \
                  "tl::wgmma_rs: invalid scaleB"); \
    using Impl = \
        cute::SM90::GMMA::ImplName<detail::MajorValue<tnspA>::value, \
                                   detail::MajorValue<tnspB>::value, \
                                   detail::ScaleInValue<scaleA>::value, \
                                   detail::ScaleInValue<scaleB>::value>; \
    TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \
                                  uint32_t *c, bool scale_out) { \
      detail::CallWgmmaRS<Impl>::exec(a, desc_b, c, scale_out); \
    } \
  };

#define TL_WGMMA_DEFINE_RS_TN(AType, BType, CType, M, N, K, ImplName) \
  template <int scaleA, int scaleB> \
  struct WgmmaRSImpl<AType, BType, CType, M, N, K, false, false, scaleA, \
                     scaleB> { \
    static_assert(detail::IsValidScale<scaleA>, \
                  "tl::wgmma_rs: invalid scaleA"); \
    static_assert(detail::IsValidScale<scaleB>, \
                  "tl::wgmma_rs: invalid scaleB"); \
    using Impl = \
        cute::SM90::GMMA::ImplName<detail::ScaleInValue<scaleA>::value, \
                                   detail::ScaleInValue<scaleB>::value>; \
    TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \
                                  uint32_t *c, bool scale_out) { \
      detail::CallWgmmaRS<Impl>::exec(a, desc_b, c, scale_out); \
    } \
  };

#define TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(AType, BType, CType, M, N, K, \
                                          ImplName) \
  template <int scaleA, int scaleB> \
  struct WgmmaRSImpl<AType, BType, CType, M, N, K, false, false, scaleA, \
                     scaleB> { \
    static_assert(detail::IsValidScale<scaleA>, \
                  "tl::wgmma_rs: invalid scaleA"); \
    static_assert(detail::IsValidScale<scaleB>, \
                  "tl::wgmma_rs: invalid scaleB"); \
    static_assert(scaleA == 1 && scaleB == 1, \
                  "tl::wgmma_rs: only +1 scaling supported for this WGMMA"); \
    using Impl = cute::SM90::GMMA::ImplName; \
    TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \
                                  uint32_t *c, bool scale_out) { \
      detail::CallWgmmaRS<Impl>::exec(a, desc_b, c, scale_out); \
    } \
  };
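// X-macro lists of the tile widths N instantiated below. The floating-point
// WGMMA shapes cover every multiple of 8 from 8 to 256, while the
// INT32-accumulator (s8/u8) list is narrower, matching the N values available
// for the integer WGMMA shapes. For example, TL_WGMMA_FOREACH_N_INT32_MUL8(F)
// expands to F(8) F(16) F(24) F(32) F(48) ... F(256).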
#define TL_WGMMA_FOREACH_N_FLOAT_MUL8(OP) \
  OP(8) \
  OP(16) \
  OP(24) \
  OP(32) \
  OP(40) \
  OP(48) \
  OP(56) \
  OP(64) \
  OP(72) \
  OP(80) \
  OP(88) \
  OP(96) \
  OP(104) \
  OP(112) \
  OP(120) \
  OP(128) \
  OP(136) \
  OP(144) \
  OP(152) \
  OP(160) \
  OP(168) \
  OP(176) \
  OP(184) \
  OP(192) \
  OP(200) \
  OP(208) \
  OP(216) \
  OP(224) \
  OP(232) \
  OP(240) \
  OP(248) \
  OP(256)

#define TL_WGMMA_FOREACH_N_INT32_MUL8(OP) \
  OP(8) \
  OP(16) \
  OP(24) \
  OP(32) \
  OP(48) \
  OP(64) \
  OP(80) \
  OP(96) \
  OP(112) \
  OP(128) \
  OP(144) \
  OP(160) \
  OP(176) \
  OP(192) \
  OP(208) \
  OP(224) \
  OP(240) \
  OP(256)

#define TL_WGMMA_DEFINE_F16_F16_F16_SS(N) \
  TL_WGMMA_DEFINE_SS_GENERAL(kFloat16, kFloat16, kFloat16, 64, N, 16, \
                             MMA_64x##N##x16_F16F16F16_SS)

#define TL_WGMMA_DEFINE_F16_F16_F32_SS(N) \
  TL_WGMMA_DEFINE_SS_GENERAL(kFloat16, kFloat16, kFloat32, 64, N, 16, \
                             MMA_64x##N##x16_F32F16F16_SS)

#define TL_WGMMA_DEFINE_BF16_BF16_F32_SS(N) \
  TL_WGMMA_DEFINE_SS_GENERAL(kBFloat16, kBFloat16, kFloat32, 64, N, 16, \
                             MMA_64x##N##x16_F32BF16BF16_SS)

#define TL_WGMMA_DEFINE_F32_TF32_SS_TN(N) \
  TL_WGMMA_DEFINE_SS_TN(kTensorFloat32, kTensorFloat32, kFloat32, 64, N, 8, \
                        MMA_64x##N##x8_F32TF32TF32_SS_TN)

#define TL_WGMMA_DEFINE_S32_S8S8_SS_TN(N) \
  TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kInt8, kInt8, kInt32, 64, N, 32, \
                                    MMA_64x##N##x32_S32S8S8_SS_TN)

#define TL_WGMMA_DEFINE_S32_S8U8_SS_TN(N) \
  TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kInt8, kUInt8, kInt32, 64, N, 32, \
                                    MMA_64x##N##x32_S32S8U8_SS_TN)

#define TL_WGMMA_DEFINE_S32_U8S8_SS_TN(N) \
  TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kUInt8, kInt8, kInt32, 64, N, 32, \
                                    MMA_64x##N##x32_S32U8S8_SS_TN)

#define TL_WGMMA_DEFINE_S32_U8U8_SS_TN(N) \
  TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kUInt8, kUInt8, kInt32, 64, N, 32, \
                                    MMA_64x##N##x32_S32U8U8_SS_TN)

#define TL_WGMMA_DEFINE_F16_E4M3E4M3_SS_TN(N) \
  TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 64, N, 32, \
                        MMA_64x##N##x32_F16E4M3E4M3_SS_TN)

#define TL_WGMMA_DEFINE_F32_E4M3E4M3_SS_TN(N) \
  TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 64, N, 32, \
                        MMA_64x##N##x32_F32E4M3E4M3_SS_TN)

#define TL_WGMMA_DEFINE_F16_E4M3E5M2_SS_TN(N) \
  TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 64, N, 32, \
                        MMA_64x##N##x32_F16E4M3E5M2_SS_TN)

#define TL_WGMMA_DEFINE_F32_E4M3E5M2_SS_TN(N) \
  TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 64, N, 32, \
                        MMA_64x##N##x32_F32E4M3E5M2_SS_TN)

#define TL_WGMMA_DEFINE_F16_E5M2E4M3_SS_TN(N) \
  TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 64, N, 32, \
                        MMA_64x##N##x32_F16E5M2E4M3_SS_TN)

#define TL_WGMMA_DEFINE_F32_E5M2E4M3_SS_TN(N) \
  TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 64, N, 32, \
                        MMA_64x##N##x32_F32E5M2E4M3_SS_TN)

#define TL_WGMMA_DEFINE_F16_E5M2E5M2_SS_TN(N) \
  TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 64, N, 32, \
                        MMA_64x##N##x32_F16E5M2E5M2_SS_TN)

#define TL_WGMMA_DEFINE_F32_E5M2E5M2_SS_TN(N) \
  TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 64, N, 32, \
                        MMA_64x##N##x32_F32E5M2E5M2_SS_TN)
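// Instantiate every supported SS (shared x shared) specialization. Each
// invocation below stamps out one WgmmaSSImpl partial specialization per N in
// the corresponding list; e.g. TL_WGMMA_DEFINE_F16_F16_F32_SS(128) defines
// the 64x128x16 f16*f16+f32 tile backed by
// cute::SM90::GMMA::MMA_64x128x16_F32F16F16_SS.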
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F16_SS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F32_SS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_BF16_BF16_F32_SS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_TF32_SS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8S8_SS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8U8_SS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8S8_SS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8U8_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E4M3_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E4M3_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E5M2_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E5M2_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E4M3_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E4M3_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E5M2_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E5M2_SS_TN);

#define TL_WGMMA_DEFINE_F16_F16_F16_RS(N) \
  TL_WGMMA_DEFINE_RS_GENERAL(kFloat16, kFloat16, kFloat16, 64, N, 16, \
                             MMA_64x##N##x16_F16F16F16_RS)

#define TL_WGMMA_DEFINE_F16_F16_F32_RS(N) \
  TL_WGMMA_DEFINE_RS_GENERAL(kFloat16, kFloat16, kFloat32, 64, N, 16, \
                             MMA_64x##N##x16_F32F16F16_RS)

#define TL_WGMMA_DEFINE_BF16_BF16_F32_RS(N) \
  TL_WGMMA_DEFINE_RS_GENERAL(kBFloat16, kBFloat16, kFloat32, 64, N, 16, \
                             MMA_64x##N##x16_F32BF16BF16_RS)

#define TL_WGMMA_DEFINE_F32_TF32_RS_TN(N) \
  TL_WGMMA_DEFINE_RS_TN(kTensorFloat32, kTensorFloat32, kFloat32, 64, N, 8, \
                        MMA_64x##N##x8_F32TF32TF32_RS_TN)

#define TL_WGMMA_DEFINE_S32_S8S8_RS_TN(N) \
  TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kInt8, kInt8, kInt32, 64, N, 32, \
                                    MMA_64x##N##x32_S32S8S8_RS_TN)

#define TL_WGMMA_DEFINE_S32_S8U8_RS_TN(N) \
  TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kInt8, kUInt8, kInt32, 64, N, 32, \
                                    MMA_64x##N##x32_S32S8U8_RS_TN)

#define TL_WGMMA_DEFINE_S32_U8S8_RS_TN(N) \
  TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kUInt8, kInt8, kInt32, 64, N, 32, \
                                    MMA_64x##N##x32_S32U8S8_RS_TN)

#define TL_WGMMA_DEFINE_S32_U8U8_RS_TN(N) \
  TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kUInt8, kUInt8, kInt32, 64, N, 32, \
                                    MMA_64x##N##x32_S32U8U8_RS_TN)

#define TL_WGMMA_DEFINE_F16_E4M3E4M3_RS_TN(N) \
  TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 64, N, 32, \
                        MMA_64x##N##x32_F16E4M3E4M3_RS_TN)

#define TL_WGMMA_DEFINE_F32_E4M3E4M3_RS_TN(N) \
  TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 64, N, 32, \
                        MMA_64x##N##x32_F32E4M3E4M3_RS_TN)

#define TL_WGMMA_DEFINE_F16_E4M3E5M2_RS_TN(N) \
  TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 64, N, 32, \
                        MMA_64x##N##x32_F16E4M3E5M2_RS_TN)

#define TL_WGMMA_DEFINE_F32_E4M3E5M2_RS_TN(N) \
  TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 64, N, 32, \
                        MMA_64x##N##x32_F32E4M3E5M2_RS_TN)

#define TL_WGMMA_DEFINE_F16_E5M2E4M3_RS_TN(N) \
  TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 64, N, 32, \
                        MMA_64x##N##x32_F16E5M2E4M3_RS_TN)

#define TL_WGMMA_DEFINE_F32_E5M2E4M3_RS_TN(N) \
  TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 64, N, 32, \
                        MMA_64x##N##x32_F32E5M2E4M3_RS_TN)

#define TL_WGMMA_DEFINE_F16_E5M2E5M2_RS_TN(N) \
  TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 64, N, 32, \
                        MMA_64x##N##x32_F16E5M2E5M2_RS_TN)

#define TL_WGMMA_DEFINE_F32_E5M2E5M2_RS_TN(N) \
  TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 64, N, 32, \
                        MMA_64x##N##x32_F32E5M2E5M2_RS_TN)
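// Instantiate every supported RS (register x shared) specialization. These
// mirror the SS set above, but operand A is consumed from registers as packed
// 32-bit words (see detail::CallWgmmaRS) and must be K-major, while operand B
// is still addressed through a shared-memory descriptor. For instance, the
// MMA_64x64x16_F32F16F16_RS atom carries ARegisters = uint32_t[4] and
// CRegisters = float[32], so CallWgmmaRS expands four A words and 32
// accumulator registers per thread.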
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F16_RS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F32_RS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_BF16_BF16_F32_RS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_TF32_RS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8S8_RS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8U8_RS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8S8_RS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8U8_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E4M3_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E4M3_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E5M2_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E5M2_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E4M3_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E4M3_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E5M2_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E5M2_RS_TN);

#undef TL_WGMMA_DEFINE_F16_F16_F16_SS
#undef TL_WGMMA_DEFINE_F16_F16_F32_SS
#undef TL_WGMMA_DEFINE_BF16_BF16_F32_SS
#undef TL_WGMMA_DEFINE_F32_TF32_SS_TN
#undef TL_WGMMA_DEFINE_S32_S8S8_SS_TN
#undef TL_WGMMA_DEFINE_S32_S8U8_SS_TN
#undef TL_WGMMA_DEFINE_S32_U8S8_SS_TN
#undef TL_WGMMA_DEFINE_S32_U8U8_SS_TN
#undef TL_WGMMA_DEFINE_F16_E4M3E4M3_SS_TN
#undef TL_WGMMA_DEFINE_F32_E4M3E4M3_SS_TN
#undef TL_WGMMA_DEFINE_F16_E4M3E5M2_SS_TN
#undef TL_WGMMA_DEFINE_F32_E4M3E5M2_SS_TN
#undef TL_WGMMA_DEFINE_F16_E5M2E4M3_SS_TN
#undef TL_WGMMA_DEFINE_F32_E5M2E4M3_SS_TN
#undef TL_WGMMA_DEFINE_F16_E5M2E5M2_SS_TN
#undef TL_WGMMA_DEFINE_F32_E5M2E5M2_SS_TN
#undef TL_WGMMA_DEFINE_F16_F16_F16_RS
#undef TL_WGMMA_DEFINE_F16_F16_F32_RS
#undef TL_WGMMA_DEFINE_BF16_BF16_F32_RS
#undef TL_WGMMA_DEFINE_F32_TF32_RS_TN
#undef TL_WGMMA_DEFINE_S32_S8S8_RS_TN
#undef TL_WGMMA_DEFINE_S32_S8U8_RS_TN
#undef TL_WGMMA_DEFINE_S32_U8S8_RS_TN
#undef TL_WGMMA_DEFINE_S32_U8U8_RS_TN
#undef TL_WGMMA_DEFINE_F16_E4M3E4M3_RS_TN
#undef TL_WGMMA_DEFINE_F32_E4M3E4M3_RS_TN
#undef TL_WGMMA_DEFINE_F16_E4M3E5M2_RS_TN
#undef TL_WGMMA_DEFINE_F32_E4M3E5M2_RS_TN
#undef TL_WGMMA_DEFINE_F16_E5M2E4M3_RS_TN
#undef TL_WGMMA_DEFINE_F32_E5M2E4M3_RS_TN
#undef TL_WGMMA_DEFINE_F16_E5M2E5M2_RS_TN
#undef TL_WGMMA_DEFINE_F32_E5M2E5M2_RS_TN
#undef TL_WGMMA_FOREACH_N_FLOAT_MUL8
#undef TL_WGMMA_FOREACH_N_INT32_MUL8
#undef TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE
#undef TL_WGMMA_DEFINE_SS_GENERAL
#undef TL_WGMMA_DEFINE_SS_TN
#undef TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE
#undef TL_WGMMA_DEFINE_RS_GENERAL
#undef TL_WGMMA_DEFINE_RS_TN

// Public entry points used by generated kernels: issue one wgmma.mma_async
// tile for the given type/shape configuration. `scale_out == true`
// accumulates into `c` (ScaleOut::One); `false` overwrites it (ScaleOut::Zero).
template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
          int K, bool tnspA, bool tnspB, int scaleA, int scaleB>
TL_DEVICE void wgmma_ss(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
                        bool scale_out) {
  WgmmaSSImpl<A_type, B_type, C_type, M, N, K, tnspA, tnspB, scaleA,
              scaleB>::execute(desc_a, desc_b, c, scale_out);
}

template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
          int K, bool tnspA, bool tnspB, int scaleA, int scaleB>
TL_DEVICE void wgmma_rs(const uint32_t *a, uint64_t desc_b, uint32_t *c,
                        bool scale_out) {
  WgmmaRSImpl<A_type, B_type, C_type, M, N, K, tnspA, tnspB, scaleA,
              scaleB>::execute(a, desc_b, c, scale_out);
}

} // namespace tl
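// Usage sketch (illustrative only): the caller is expected to build the GMMA
// shared-memory descriptors and the accumulator fragment itself and to wrap
// the call in the usual warpgroup fence/commit/wait sequence. The names
// `desc_a` / `desc_b` and the fragment size below are assumptions made for
// the example, not part of this header.
//
//   uint32_t acc[64] = {0}; // f32 accumulator fragment for a 64x128 tile
//   cute::warpgroup_arrive();
//   tl::wgmma_ss<tl::kFloat16, tl::kFloat16, tl::kFloat32, 64, 128, 16,
//                /*tnspA=*/false, /*tnspB=*/false, /*scaleA=*/1,
//                /*scaleB=*/1>(desc_a, desc_b, acc,
//                              /*scale_out=*/true); // true -> accumulate
//   cute::warpgroup_commit_batch();
//   cute::warpgroup_wait<0>();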