Unverified Commit aef0a6bb authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Language] Expose `T.warpgroup_fence_operand` for nvcc code motion (#986)



* remove debug print

* pipeline fix

* use the correct buffer access scope

* rs support

* warp warpgroup_fence_operand

* fix

* fp8 dtype ptx enhance

* mma fix

* TCGEN05 Interface

* tcgen05 support

* rebase

* update

* Enhance TCGEN05 support by adding new intrinsic operations and descriptors. Introduced `ptx_tcgen05_mma_ts` for tensor-memory to shared-memory instructions and `tcgen05_mma_arrive` for signaling barrier completion. Updated existing descriptors and code generation logic to accommodate these changes, ensuring compatibility with new instruction sets. Refactored related allocation functions and improved handling of shared memory descriptors.

* lint fix

* Refactor buffer reference handling in CUDA code generation and update test execution in tilelang. Ensure default annotations for unrolling are set correctly in TIR IR module.

* wgmma fix

---------
Co-authored-by: default avatarZhiwen Mo <zm125@ic.ac.uk>
parent c85bb3ac
Subproject commit 0f1ebab7b66732f34b652ce807c9ff0748cd473c Subproject commit 1815c3e0b6ec4ead36370bbd1562025d8529017c
...@@ -17,7 +17,7 @@ The pass is conservative: unknown extern calls are treated as async so that the ...@@ -17,7 +17,7 @@ The pass is conservative: unknown extern calls are treated as async so that the
### Timeline View ### Timeline View
``` ```
generic initialize_descriptor → generic shared-store → async wgmma generic initialize_wgmma_descriptor → generic shared-store → async wgmma
│ │ │ │ │ │
└─ generic proxy ┴─ generic proxy ┴─ async proxy └─ generic proxy ┴─ generic proxy ┴─ async proxy
│ fence inserted here ↑ │ fence inserted here ↑
...@@ -53,7 +53,7 @@ def kernel(): ...@@ -53,7 +53,7 @@ def kernel():
with T.Kernel(1): with T.Kernel(1):
desc = T.decl_buffer((1,), "uint64", scope="local.descriptor") desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
smem = T.decl_buffer((128,), "float16", scope="shared") smem = T.decl_buffer((128,), "float16", scope="shared")
T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32) T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
smem[0] = T.float16(0) smem[0] = T.float16(0)
T.ptx_wgmma_ss( T.ptx_wgmma_ss(
"float16", "float16",
...@@ -83,7 +83,7 @@ def kernel(): ...@@ -83,7 +83,7 @@ def kernel():
with T.Kernel(1): with T.Kernel(1):
desc = T.decl_buffer((1,), "uint64", scope="local.descriptor") desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
smem = T.decl_buffer((128,), "float16", scope="shared") smem = T.decl_buffer((128,), "float16", scope="shared")
T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32) T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
smem[0] = T.float16(0) smem[0] = T.float16(0)
T.fence_proxy_async() T.fence_proxy_async()
T.ptx_wgmma_ss( T.ptx_wgmma_ss(
......
...@@ -546,6 +546,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { ...@@ -546,6 +546,12 @@ TVM_FFI_STATIC_INIT_BLOCK() {
return makeGemmABLayoutHopper(stride, mat_continuous, continuity, return makeGemmABLayoutHopper(stride, mat_continuous, continuity,
element_size, k_inner); element_size, k_inner);
}) })
.def("tl.make_tcgen05mma_swizzled_layout",
[](int stride, int mat_continuous, int continuity, int element_size,
bool k_inner) {
return makeGemmABLayoutSm100(stride, mat_continuous, continuity,
element_size, k_inner);
})
.def("tl.make_full_bank_swizzled_layout", .def("tl.make_full_bank_swizzled_layout",
[](int stride, int continuous, int element_size) { [](int stride, int continuous, int element_size) {
return makeFullBankSwizzleLayout(stride, continuous, element_size); return makeFullBankSwizzleLayout(stride, continuous, element_size);
......
...@@ -155,6 +155,16 @@ TIR_DEFINE_TL_BUILTIN(ptx_wgmma_rs) ...@@ -155,6 +155,16 @@ TIR_DEFINE_TL_BUILTIN(ptx_wgmma_rs)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ss)
.set_num_inputs(14)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ts)
.set_num_inputs(13)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_init_tensor_memory) TIR_DEFINE_TL_BUILTIN(ptx_init_tensor_memory)
.set_num_inputs(2) .set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
...@@ -219,6 +229,11 @@ TIR_DEFINE_TL_BUILTIN(warpgroup_wait) ...@@ -219,6 +229,11 @@ TIR_DEFINE_TL_BUILTIN(warpgroup_wait)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(warpgroup_fence_operand)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(get_lane_idx) TIR_DEFINE_TL_BUILTIN(get_lane_idx)
.set_num_inputs(-1) .set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
...@@ -286,11 +301,16 @@ TIR_DEFINE_TL_BUILTIN(tl_shuffle_elect) ...@@ -286,11 +301,16 @@ TIR_DEFINE_TL_BUILTIN(tl_shuffle_elect)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure)); Integer(CallEffectKind::kPure));
TIR_DEFINE_TL_BUILTIN(initialize_descriptor) TIR_DEFINE_TL_BUILTIN(initialize_wgmma_descriptor)
.set_num_inputs(5) .set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(initialize_tcgen05_descriptor)
.set_num_inputs(7)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset) TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset)
.set_num_inputs(2) .set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
...@@ -311,5 +331,10 @@ TIR_DEFINE_TL_BUILTIN(device_assert_with_msg) ...@@ -311,5 +331,10 @@ TIR_DEFINE_TL_BUILTIN(device_assert_with_msg)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(tcgen05_mma_arrive)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -241,14 +241,24 @@ TVM_DLL const Op &ptx_wgmma_ss(); ...@@ -241,14 +241,24 @@ TVM_DLL const Op &ptx_wgmma_ss();
/*! /*!
* \brief tvm intrinsics for ptx tensor core wgmma instructions. * \brief tvm intrinsics for ptx tensor core wgmma instructions.
* *
* void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, bool * void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix,
* a_is_k_major, bool b_is_k_major, StringImm a_dtype_abbrv, StringImm * bool b_is_k_major, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv,
* b_dtype_abbrv, StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr * StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr A_offset, Var
* A_offset, Var B_descriptor, Var B_offset, Var C_data, Var C_offset, bool * B_descriptor, Var B_offset, Var C_data, Var C_offset, bool scale_out,
* scale_out, bool scale_in_a, bool scale_in_b); * bool scale_in_a, bool scale_in_b);
*/ */
TVM_DLL const Op &ptx_wgmma_rs(); TVM_DLL const Op &ptx_wgmma_rs();
/*!
* \brief tvm intrinsic for tcgen05 mma shared-shared instructions.
*/
TVM_DLL const Op &ptx_tcgen05_mma_ss();
/*!
* \brief tvm intrinsic for tcgen05 mma tensor-shared instructions.
*/
TVM_DLL const Op &ptx_tcgen05_mma_ts();
/*! /*!
* \brief tvm intrinsics for initializing tensor memory * \brief tvm intrinsics for initializing tensor memory
* *
...@@ -361,6 +371,14 @@ TVM_DLL const Op &warpgroup_commit_batch(); ...@@ -361,6 +371,14 @@ TVM_DLL const Op &warpgroup_commit_batch();
*/ */
TVM_DLL const Op &warpgroup_wait(); TVM_DLL const Op &warpgroup_wait();
/*!
* \brief Fence accumulator operand registers for upcoming WGMMA operations
*
* warpgroup_fence_operand(dtype, ptr, offset, num_regs)
*
*/
TVM_DLL const Op &warpgroup_fence_operand();
/*! /*!
* \brief Return the canonical lane index for the calling thread. * \brief Return the canonical lane index for the calling thread.
* *
...@@ -494,7 +512,21 @@ TVM_DLL const Op &tl_shuffle_elect(); ...@@ -494,7 +512,21 @@ TVM_DLL const Op &tl_shuffle_elect();
* This op is used to represent a descriptor initialization operation in * This op is used to represent a descriptor initialization operation in
* tilelang. * tilelang.
*/ */
TVM_DLL const Op &initialize_descriptor(); TVM_DLL const Op &initialize_wgmma_descriptor();
/*!
* \brief tilelang intrinsic for initializing a descriptor buffer for
* tcgen05 mma.
*/
TVM_DLL const Op &initialize_tcgen05_descriptor();
/*!
* \brief tilelang intrinsic for committing UMMA (TCGEN05) barrier arrive.
*
* This op wraps the device-side arrive used to signal completion of MMA work
* to a shared-memory mbarrier. It mirrors CUTLASS's umma_arrive.
*/
TVM_DLL const Op &tcgen05_mma_arrive();
/*! /*!
* \brief tilelang intrinsic for setting the start address of a descriptor * \brief tilelang intrinsic for setting the start address of a descriptor
......
...@@ -12,77 +12,13 @@ ...@@ -12,77 +12,13 @@
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include "../target/utils.h" #include "../target/utils.h"
#include "tcgen5_meta.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
struct TCGEN5MMAMeta {
int atom_m, atom_n, atom_k;
};
// Return {is_success, meta}
static inline std::pair<bool, TCGEN5MMAMeta>
GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
#define FAIL \
return { false, TCGEN5MMAMeta{0, 0, 0} }
#define SUCCESS(atom_m, atom_n, atom_k) \
return { \
true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \
}
std::vector<int> ws_valid_atom_ns = {256, 128, 64};
if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
(c_dtype.is_float() && c_dtype.bits() == 32)) {
if (K % 16 != 0)
FAIL;
if (M % 128 == 0) {
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 16);
FAIL;
} else if (M % 64 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(64, atom_n, 16);
FAIL;
} else if (M % 32 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(32, atom_n, 16);
FAIL;
} else {
FAIL;
}
} else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) &&
(c_dtype.is_float() && c_dtype.bits() == 32)) {
if (K % 32 != 0)
FAIL;
if (M % 128 == 0) {
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 32);
FAIL;
} else if (M % 64 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(64, atom_n, 32);
FAIL;
} else if (M % 32 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(32, atom_n, 32);
FAIL;
} else {
FAIL;
}
}
FAIL;
#undef FAIL
#undef SUCCESS
}
/** /**
* @brief Construct a Gemm operator from serialized TL arguments and a buffer * @brief Construct a Gemm operator from serialized TL arguments and a buffer
* map. * map.
...@@ -186,6 +122,8 @@ bool GemmNode::AllowWGMMA(int block_size, Target target) const { ...@@ -186,6 +122,8 @@ bool GemmNode::AllowWGMMA(int block_size, Target target) const {
GemmInst GemmNode::GetGemmInst(int block_size, Target target) const { GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
bool allow_tcgen5mma = AllowTCGEN5MMA(target); bool allow_tcgen5mma = AllowTCGEN5MMA(target);
bool allow_wgmma = AllowWGMMA(block_size, target); bool allow_wgmma = AllowWGMMA(block_size, target);
LOG(INFO) << "allow_tcgen5mma: " << allow_tcgen5mma
<< ", allow_wgmma: " << allow_wgmma;
if (allow_tcgen5mma) { if (allow_tcgen5mma) {
return GemmInst::kTCGEN5MMA; return GemmInst::kTCGEN5MMA;
} else if (allow_wgmma) { } else if (allow_wgmma) {
...@@ -195,7 +133,7 @@ GemmInst GemmNode::GetGemmInst(int block_size, Target target) const { ...@@ -195,7 +133,7 @@ GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
} else if (TargetIsCuda(target)) { } else if (TargetIsCuda(target)) {
return GemmInst::kMMA; return GemmInst::kMMA;
} else { } else {
ICHECK(0) << "Unsupported target for gemm: " << target->str(); ICHECK(0) << "Unsupported target for gemm: " << target;
} }
} }
...@@ -578,6 +516,8 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -578,6 +516,8 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (A.scope() == "local.fragment") { if (A.scope() == "local.fragment") {
ICHECK(B.scope() != "local.fragment"); ICHECK(B.scope() != "local.fragment");
ICHECK(!trans_A)
<< "gemm_rs requires the A operand to be in non-transposed layout.";
op_name = "tl::gemm_rs"; op_name = "tl::gemm_rs";
} else if (B.scope() == "local.fragment") { } else if (B.scope() == "local.fragment") {
op_name = "tl::gemm_sr"; op_name = "tl::gemm_sr";
......
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
#include "../support/ffi_aliases.h" #include "../support/ffi_aliases.h"
#include "../target/utils.h" #include "../target/utils.h"
#include "tcgen5_meta.h"
#include "tvm/ffi/string.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -49,7 +51,6 @@ using namespace tir; ...@@ -49,7 +51,6 @@ using namespace tir;
*/ */
GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) { GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmPyNode> node = tvm::ffi::make_object<GemmPyNode>(); ObjectPtr<GemmPyNode> node = tvm::ffi::make_object<GemmPyNode>();
node->Aptr = args[0]; node->Aptr = args[0];
node->Bptr = args[1]; node->Bptr = args[1];
node->Cptr = args[2]; node->Cptr = args[2];
...@@ -76,6 +77,19 @@ GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) { ...@@ -76,6 +77,19 @@ GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
if (args.size() > 15) { if (args.size() > 15) {
node->wg_wait = args[15].as<IntImm>().value()->value; node->wg_wait = args[15].as<IntImm>().value()->value;
} }
if (args.size() > 16) {
node->mbarptr = args[16];
} else {
node->mbarptr = IntImm(DataType::UInt(32), 0);
}
if (args.size() > 18) {
node->C_coords = Array<PrimExpr>({args[17], args[18]});
} else if (args.size() > 17) {
node->C_coords = Array<PrimExpr>({args[17], IntImm(DataType::Int(32), 0)});
} else {
node->C_coords = Array<PrimExpr>(
{IntImm(DataType::Int(32), 0), IntImm(DataType::Int(32), 0)});
}
data_ = std::move(node); data_ = std::move(node);
} }
...@@ -92,16 +106,37 @@ TileOperator GemmPyNode::Clone() const { ...@@ -92,16 +106,37 @@ TileOperator GemmPyNode::Clone() const {
return GemmPy(op); return GemmPy(op);
} }
GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const { bool GemmPyNode::AllowTCGEN5MMA(Target target) const {
return TargetIsSm100(target) &&
((A.scope() == "shared.dyn" || A.scope() == "shared" ||
A.scope() == "shared.tmem") &&
(B.scope() == "shared.dyn" || B.scope() == "shared") &&
C.scope() == "shared.tmem") &&
GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype).first;
}
bool GemmPyNode::AllowWGMMA(int block_size, Target target) const {
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
int warp_size = TargetGetWarpSize(target); int warp_size = TargetGetWarpSize(target);
int num_warps = block_size / warp_size; int num_warps = block_size / warp_size;
bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) && return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
(num_warps % 4 == 0) && CheckWGMMA(); TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) &&
if (allow_wgmma) { CheckWGMMA();
}
GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const {
bool allow_tcgen5mma = AllowTCGEN5MMA(target);
bool allow_wgmma = AllowWGMMA(block_size, target);
if (allow_tcgen5mma) {
return GemmInst::kTCGEN5MMA;
} else if (allow_wgmma) {
return GemmInst::kWGMMA; return GemmInst::kWGMMA;
} else if (TargetIsCDNA(target)) { } else if (TargetIsCDNA(target)) {
return GemmInst::kMFMA; return GemmInst::kMFMA;
} else if (TargetIsCuda(target)) { } else if (TargetIsVolta(target) || TargetIsAmpere(target) ||
TargetIsTuring(target) || TargetIsHopper(target) ||
TargetIsSm100(target)) {
return GemmInst::kMMA; return GemmInst::kMMA;
} else { } else {
ICHECK(0) << "Unsupported target for gemm: " << target->str(); ICHECK(0) << "Unsupported target for gemm: " << target->str();
...@@ -290,5 +325,31 @@ TVM_FFI_STATIC_INIT_BLOCK() { ...@@ -290,5 +325,31 @@ TVM_FFI_STATIC_INIT_BLOCK() {
}); });
} }
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def(
"tl.get_tcgen5_mma_meta",
[](int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
auto [success, meta] = GetTCGEN5MMAMeta(M, N, K, ab_dtype, c_dtype);
Array<Integer> result;
if (success) {
result.push_back(Integer(meta.atom_m));
result.push_back(Integer(meta.atom_n));
result.push_back(Integer(meta.atom_k));
}
return result;
});
refl::GlobalDef().def(
"tl.get_tcgen5_instr_desc",
[](int atom_m, int atom_n, int atom_k, DataType ab_dtype,
DataType c_dtype, bool a_is_k_major, bool b_is_k_major, int scale_in_a,
int scale_in_b) {
uint32_t desc = GetTCGEN5InstrDesc(atom_m, atom_n, atom_k, ab_dtype,
c_dtype, a_is_k_major, b_is_k_major,
scale_in_a, scale_in_b);
return Integer(static_cast<int64_t>(desc));
});
}
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -19,6 +19,8 @@ using namespace tir; ...@@ -19,6 +19,8 @@ using namespace tir;
class GemmPyNode : public TileOperatorNode { class GemmPyNode : public TileOperatorNode {
public: public:
bool CheckWGMMA() const; bool CheckWGMMA() const;
bool AllowTCGEN5MMA(Target target) const;
bool AllowWGMMA(int block_size, Target target) const;
tir::Buffer A, B, C; tir::Buffer A, B, C;
// pointer to the A, B, C // pointer to the A, B, C
PrimExpr Aptr, Bptr, Cptr; PrimExpr Aptr, Bptr, Cptr;
...@@ -27,6 +29,8 @@ public: ...@@ -27,6 +29,8 @@ public:
int stride_A, stride_B; int stride_A, stride_B;
int offset_A, offset_B; int offset_A, offset_B;
PrimExpr clear_accum = const_false(); PrimExpr clear_accum = const_false();
PrimExpr mbarptr;
Array<PrimExpr> C_coords;
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// only will be enabled under cdna mfma instructions // only will be enabled under cdna mfma instructions
int kPack = 1; int kPack = 1;
...@@ -54,6 +58,8 @@ public: ...@@ -54,6 +58,8 @@ public:
.def_ro("offset_A", &GemmPyNode::offset_A) .def_ro("offset_A", &GemmPyNode::offset_A)
.def_ro("offset_B", &GemmPyNode::offset_B) .def_ro("offset_B", &GemmPyNode::offset_B)
.def_ro("clear_accum", &GemmPyNode::clear_accum) .def_ro("clear_accum", &GemmPyNode::clear_accum)
.def_ro("mbarptr", &GemmPyNode::mbarptr)
.def_ro("C_coords", &GemmPyNode::C_coords)
.def_ro("kPack", &GemmPyNode::kPack) .def_ro("kPack", &GemmPyNode::kPack)
.def_ro("wg_wait", &GemmPyNode::wg_wait) .def_ro("wg_wait", &GemmPyNode::wg_wait)
.def_ro("policy", &GemmPyNode::policy); .def_ro("policy", &GemmPyNode::policy);
......
#ifndef TVM_TL_OP_TCGEN5_META_H_
#define TVM_TL_OP_TCGEN5_META_H_
#include <cstdint>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/logging.h>
#include <utility>
#include <vector>
namespace tvm {
namespace tl {
using runtime::DataType;
struct TCGEN5MMAMeta {
int atom_m, atom_n, atom_k;
};
inline std::pair<bool, TCGEN5MMAMeta>
GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
#define FAIL \
return { false, TCGEN5MMAMeta{0, 0, 0} }
#define SUCCESS(atom_m, atom_n, atom_k) \
return { \
true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \
}
std::vector<int> ws_valid_atom_ns = {256, 128, 64};
if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
(c_dtype.is_float() && c_dtype.bits() == 32)) {
if (K % 16 != 0)
FAIL;
if (M % 128 == 0) {
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 16);
FAIL;
} else if (M % 64 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(64, atom_n, 16);
FAIL;
} else if (M % 32 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(32, atom_n, 16);
FAIL;
} else {
FAIL;
}
} else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) &&
(c_dtype.is_float() && c_dtype.bits() == 32)) {
if (K % 32 != 0)
FAIL;
if (M % 128 == 0) {
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 32);
FAIL;
} else if (M % 64 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(64, atom_n, 32);
FAIL;
} else if (M % 32 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(32, atom_n, 32);
FAIL;
} else {
FAIL;
}
}
FAIL;
#undef FAIL
#undef SUCCESS
}
inline uint32_t GetTCGEN5InstrDesc(int atom_m, int atom_n, int atom_k,
DataType ab_dtype, DataType c_dtype,
bool a_is_k_major, bool b_is_k_major,
int scale_in_a, int scale_in_b) {
ICHECK(atom_m % 16 == 0) << "atom_m must be divisible by 16";
ICHECK(atom_n % 8 == 0) << "atom_n must be divisible by 8";
ICHECK(atom_k == 16 || atom_k == 32)
<< "Unsupported atom_k for TCGEN5MMA descriptor: " << atom_k;
ICHECK(scale_in_a == 1 || scale_in_a == -1)
<< "scale_in_a must be +/-1 for TCGEN5MMA";
ICHECK(scale_in_b == 1 || scale_in_b == -1)
<< "scale_in_b must be +/-1 for TCGEN5MMA";
auto encode_dtype = [&](DataType dtype) -> uint32_t {
if (dtype.is_float16()) {
return static_cast<uint32_t>(0);
} else if (dtype.is_bfloat16()) {
return static_cast<uint32_t>(1);
} else if (dtype.is_float8_e4m3fn() || dtype.is_float8_e4m3fnuz() ||
dtype.is_float8_e4m3()) {
return static_cast<uint32_t>(0);
} else if (dtype.is_float8_e5m2fnuz() || dtype.is_float8_e5m2()) {
return static_cast<uint32_t>(1);
}
LOG(FATAL) << "Unsupported dtype for TCGEN5MMA descriptor: " << dtype;
return 0u;
};
uint32_t a_format = encode_dtype(ab_dtype);
uint32_t b_format = a_format;
uint32_t c_format = 0;
if (c_dtype.is_float16()) {
c_format = 0;
} else if (c_dtype.is_float()) {
c_format = 1;
} else if (c_dtype.is_int()) {
c_format = 2;
} else {
LOG(FATAL) << "Unsupported accumulator dtype for TCGEN5MMA descriptor: "
<< c_dtype;
}
auto set_bits = [](uint32_t value, int start, int width) -> uint32_t {
uint32_t mask = (width == 32) ? 0xFFFFFFFFu : ((1u << width) - 1);
return (value & mask) << start;
};
uint32_t desc = 0;
desc |= set_bits(0, 0, 2); // sparse_id2
desc |= set_bits(0, 2, 1); // sparse_flag
desc |= set_bits(0, 3, 1); // saturate
desc |= set_bits(c_format, 4, 2);
desc |= set_bits(a_format, 7, 3);
desc |= set_bits(b_format, 10, 3);
uint32_t a_neg = (scale_in_a == -1) ? 1u : 0u;
uint32_t b_neg = (scale_in_b == -1) ? 1u : 0u;
desc |= set_bits(a_neg, 13, 1);
desc |= set_bits(b_neg, 14, 1);
uint32_t a_major = a_is_k_major ? 0u : 1u;
uint32_t b_major = b_is_k_major ? 0u : 1u;
desc |= set_bits(a_major, 15, 1);
desc |= set_bits(b_major, 16, 1);
uint32_t n_dim = static_cast<uint32_t>(atom_n >> 3);
uint32_t m_dim = static_cast<uint32_t>(atom_m >> 4);
desc |= set_bits(n_dim, 17, 6);
desc |= set_bits(0, 23, 1);
desc |= set_bits(m_dim, 24, 5);
desc |= set_bits(0, 29, 1);
uint32_t max_shift = 0u;
desc |= set_bits(max_shift, 30, 2);
return desc;
}
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_TCGEN5_META_H_
...@@ -260,6 +260,18 @@ std::string CodeGenTileLangCUDA::Finish() { ...@@ -260,6 +260,18 @@ std::string CodeGenTileLangCUDA::Finish() {
if (need_mma_h_) { if (need_mma_h_) {
decl_stream << "#include <mma.h>\n"; decl_stream << "#include <mma.h>\n";
} }
if (need_mma_instruction_h_) {
decl_stream << "#include <tl_templates/cuda/instruction/mma.h>\n";
}
if (need_wgmma_instruction_h_) {
decl_stream << "#include <tl_templates/cuda/instruction/wgmma.h>\n";
}
if (need_tcgen05mma_instruction_h_) {
decl_stream << "#include <tl_templates/cuda/instruction/tcgen05mma.h>\n";
}
if (need_tcgen05_common_h_) {
decl_stream << "#include <tl_templates/cuda/tcgen_05.h>\n";
}
if (enable_fp8_) { if (enable_fp8_) {
decl_stream << "#include <tl_templates/cuda/cuda_fp8.h>\n"; decl_stream << "#include <tl_templates/cuda/cuda_fp8.h>\n";
} }
...@@ -1277,7 +1289,7 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t, ...@@ -1277,7 +1289,7 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t,
if (scope.empty()) { if (scope.empty()) {
scope = GetPtrStorageScope(buffer->data); scope = GetPtrStorageScope(buffer->data);
} }
if (scope == "local.var" || scope == "local.descriptor") { if (scope == "local.var" || scope.find("local.descriptor") == 0) {
os << vid; os << vid;
return os.str(); return os.str();
} }
...@@ -1597,6 +1609,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { ...@@ -1597,6 +1609,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
int num_mma = Downcast<IntImm>(op->args[0])->value; int num_mma = Downcast<IntImm>(op->args[0])->value;
this->stream << "tl::warpgroup_wait<" << std::to_string(num_mma) this->stream << "tl::warpgroup_wait<" << std::to_string(num_mma)
<< ">();\n"; << ">();\n";
} else if (op->op.same_as(tl::warpgroup_fence_operand())) {
ICHECK_EQ(op->args.size(), 4U);
std::string dtype = Downcast<StringImm>(op->args[0])->value;
std::string data_ptr = this->PrintExpr(op->args[1]);
std::string offset = this->PrintExpr(op->args[2]);
std::string num_regs = this->PrintExpr(op->args[3]);
auto dtype_enum = tl::codegen::ptx::DTypeFromString(dtype);
std::string cast_type = "uint32_t";
if (dtype_enum == tl::codegen::ptx::DataType::kFloat32 ||
dtype_enum == tl::codegen::ptx::DataType::kTensorFloat32) {
cast_type = "float";
}
this->PrintIndent();
this->stream << "tl::warpgroup_fence_operand(reinterpret_cast<" << cast_type
<< "*>(" << data_ptr << " + " << offset << "), " << num_regs
<< ");\n";
} else if (op->op.same_as(tl::set_max_nreg())) { } else if (op->op.same_as(tl::set_max_nreg())) {
this->PrintIndent(); this->PrintIndent();
int nreg = Downcast<IntImm>(op->args[0])->value; int nreg = Downcast<IntImm>(op->args[0])->value;
...@@ -1708,14 +1736,43 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { ...@@ -1708,14 +1736,43 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
std::string b_bias = this->PrintExpr(op->args[9]); std::string b_bias = this->PrintExpr(op->args[9]);
std::string c_ref = this->PrintExpr(op->args[10]); std::string c_ref = this->PrintExpr(op->args[10]);
std::string c_bias = this->PrintExpr(op->args[11]); std::string c_bias = this->PrintExpr(op->args[11]);
bool saturate = Downcast<Bool>(op->args[12])->value; auto dtype_a_enum = tl::codegen::ptx::DTypeFromString(A_dtype);
std::string bit_op = auto dtype_b_enum = tl::codegen::ptx::DTypeFromString(B_dtype);
op->args.size() > 13 ? Downcast<StringImm>(op->args[13])->value : ""; auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype);
std::string asm_code = PrintMMAAssembly( auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias,
b_ref, b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate); need_mma_instruction_h_ = true;
this->PrintIndent(); this->PrintIndent();
this->stream << asm_code; std::string mma_call =
"tl::mma_sync<(AType), (BType), (CType), (M), (N), (K), (TransA), "
"(TransB)>(reinterpret_cast<(CRegType)*>((C_ptr) + (C_offset)), "
"reinterpret_cast<const (ARegType)*>((A_ptr) + (A_offset)), "
"reinterpret_cast<const (BRegType)*>((B_ptr) + (B_offset)));\n";
tl::codegen::Replacer replacer;
replacer.register_rule("(AType)",
tl::codegen::ptx::DTypeEnumToString(dtype_a_enum));
replacer.register_rule("(BType)",
tl::codegen::ptx::DTypeEnumToString(dtype_b_enum));
replacer.register_rule("(CType)",
tl::codegen::ptx::DTypeEnumToString(dtype_c_enum));
replacer.register_rule("(M)", std::to_string(m));
replacer.register_rule("(N)", std::to_string(n));
replacer.register_rule("(K)", std::to_string(k));
replacer.register_rule("(TransA)", A_layout == "row" ? "false" : "true");
replacer.register_rule("(TransB)", B_layout == "row" ? "false" : "true");
replacer.register_rule("(ARegType)",
tl::codegen::GetMMARegisterType(dtype_a_enum));
replacer.register_rule("(BRegType)",
tl::codegen::GetMMARegisterType(dtype_b_enum));
replacer.register_rule("(CRegType)",
tl::codegen::GetMMARegisterType(dtype_c_enum));
replacer.register_rule("(A_ptr)", a_ref);
replacer.register_rule("(A_offset)", a_bias);
replacer.register_rule("(B_ptr)", b_ref);
replacer.register_rule("(B_offset)", b_bias);
replacer.register_rule("(C_ptr)", c_ref);
replacer.register_rule("(C_offset)", c_bias);
this->stream << replacer.rewrite(mma_call);
} else if (op->op.same_as(builtin::ptx_mma_sp())) { } else if (op->op.same_as(builtin::ptx_mma_sp())) {
// arg 0: shape: mXnXkX // arg 0: shape: mXnXkX
// arg 1: A layout: row/col // arg 1: A layout: row/col
...@@ -1792,6 +1849,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { ...@@ -1792,6 +1849,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
A_offset, b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a, A_offset, b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a,
scale_in_b, a_is_shared, "", "", "", false); scale_in_b, a_is_shared, "", "", "", false);
auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape); auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
need_wgmma_instruction_h_ = true;
std::string wgmma_asm_code = std::string wgmma_asm_code =
"tl::wgmma_ss<(AType), (BType), (CType), (M), (N), (K), (tnspA), " "tl::wgmma_ss<(AType), (BType), (CType), (M), (N), (K), (tnspA), "
"(tnspB), (scaleA), (scaleB)>(uint64_t((desc_a) + (A_offset)), " "(tnspB), (scaleA), (scaleB)>(uint64_t((desc_a) + (A_offset)), "
...@@ -1820,41 +1878,173 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { ...@@ -1820,41 +1878,173 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
wgmma_asm_code = replacer.rewrite(wgmma_asm_code); wgmma_asm_code = replacer.rewrite(wgmma_asm_code);
this->stream << wgmma_asm_code; this->stream << wgmma_asm_code;
} else if (op->op.same_as(tl::ptx_wgmma_rs())) { } else if (op->op.same_as(tl::ptx_wgmma_rs())) {
// arg 0: dtype // arg 0: shape
// arg 1: shape // arg 1: B_layout
// arg 2: A_layout // arg 2: A_dtype
// arg 3: B_layout // arg 3: B_dtype
// arg 4: A_dtype // arg 4: C_dtype
// arg 5: B_dtype // arg 5: multiplicand_a
// arg 6: C_dtype // arg 6: multiplicand_a offset
// arg 7: multiplicand_a // arg 7: multiplicand_b descriptor
// arg 8: multiplicand_b // arg 8: multiplicand_b offset
// arg 9: accumulator // arg 9: accumulator
// arg 10: saturate // arg 10: accumulator offset
ICHECK_EQ(op->args.size(), 15U) << "ptx_wgmma_rs args is " << op->args; // arg 11: scale_out
// arg 12: scale_in_a
// arg 13: scale_in_b
ICHECK_EQ(op->args.size(), 14U) << "ptx_wgmma_rs args is " << op->args;
std::string shape = Downcast<StringImm>(op->args[0])->value; std::string shape = Downcast<StringImm>(op->args[0])->value;
bool A_layout = Downcast<Bool>(op->args[1])->value; bool b_is_k_major = Downcast<Bool>(op->args[1])->value;
bool B_layout = Downcast<Bool>(op->args[2])->value; std::string A_dtype = Downcast<StringImm>(op->args[2])->value;
std::string A_dtype = Downcast<StringImm>(op->args[3])->value; std::string B_dtype = Downcast<StringImm>(op->args[3])->value;
std::string B_dtype = Downcast<StringImm>(op->args[4])->value; std::string C_dtype = Downcast<StringImm>(op->args[4])->value;
std::string C_dtype = Downcast<StringImm>(op->args[5])->value; std::string a_ref = this->PrintExpr(op->args[5]);
std::string a_ref = this->PrintExpr(op->args[6]); std::string A_offset = this->PrintExpr(op->args[6]);
std::string A_offset = this->PrintExpr(op->args[7]); std::string b_desc = this->PrintExpr(op->args[7]);
std::string b_desc = this->PrintExpr(op->args[8]); std::string B_offset = this->PrintExpr(op->args[8]);
std::string B_offset = this->PrintExpr(op->args[9]); std::string c_ref = this->PrintExpr(op->args[9]);
std::string c_ref = this->PrintExpr(op->args[10]); std::string c_offset = this->PrintExpr(op->args[10]);
std::string c_offset = this->PrintExpr(op->args[11]); bool scale_out = Downcast<Bool>(op->args[11])->value;
bool scale_out = Downcast<Bool>(op->args[12])->value; bool scale_in_a = Downcast<Bool>(op->args[12])->value;
bool scale_in_a = Downcast<Bool>(op->args[13])->value; bool scale_in_b = Downcast<Bool>(op->args[13])->value;
bool scale_in_b = Downcast<Bool>(op->args[14])->value;
auto dtype_a_enum = tl::codegen::ptx::DTypeFromString(A_dtype);
auto dtype_b_enum = tl::codegen::ptx::DTypeFromString(B_dtype);
auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype);
auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
const bool a_is_shared = false; need_wgmma_instruction_h_ = true;
this->PrintIndent(); this->PrintIndent();
std::string asm_code = PrintWGMMAAssembly( std::string wgmma_call =
shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, A_offset, "tl::wgmma_rs<(AType), (BType), (CType), (M), (N), (K), (tnspA), "
b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a, scale_in_b, "(tnspB), (scaleA), (scaleB)>(reinterpret_cast<const "
a_is_shared, "", "", "", false); "uint32_t*>((A_ptr) + (A_offset)), "
this->stream << asm_code; "uint64_t((desc_b) + (B_offset)), "
"reinterpret_cast<uint32_t*>((C_ptr) + (C_offset)), "
"(scale_out));\n";
tl::codegen::Replacer replacer;
replacer.register_rule("(AType)",
tl::codegen::ptx::DTypeEnumToString(dtype_a_enum));
replacer.register_rule("(BType)",
tl::codegen::ptx::DTypeEnumToString(dtype_b_enum));
replacer.register_rule("(CType)",
tl::codegen::ptx::DTypeEnumToString(dtype_c_enum));
replacer.register_rule("(M)", std::to_string(m));
replacer.register_rule("(N)", std::to_string(n));
replacer.register_rule("(K)", std::to_string(k));
replacer.register_rule("(tnspA)", "false");
replacer.register_rule("(tnspB)", b_is_k_major ? "false" : "true");
replacer.register_rule("(scaleA)", scale_in_a ? "1" : "-1");
replacer.register_rule("(scaleB)", scale_in_b ? "1" : "-1");
replacer.register_rule("(A_ptr)", a_ref);
replacer.register_rule("(A_offset)", A_offset);
replacer.register_rule("(desc_b)", b_desc);
replacer.register_rule("(B_offset)", B_offset);
replacer.register_rule("(C_ptr)", c_ref);
replacer.register_rule("(C_offset)", c_offset);
replacer.register_rule("(scale_out)", scale_out ? "true" : "false");
wgmma_call = replacer.rewrite(wgmma_call);
this->stream << wgmma_call;
} else if (op->op.same_as(tl::ptx_tcgen05_mma_ss())) {
ICHECK_EQ(op->args.size(), 14U)
<< "ptx_tcgen05_mma_ss args is " << op->args;
std::string C_dtype = Downcast<StringImm>(op->args[0])->value;
std::string a_desc = this->PrintExpr(op->args[1]);
std::string A_offset = this->PrintExpr(op->args[2]);
std::string b_desc = this->PrintExpr(op->args[3]);
std::string B_offset = this->PrintExpr(op->args[4]);
std::string c_ref = this->PrintExpr(op->args[5]);
std::string c_offset = this->PrintExpr(op->args[6]);
PrimExpr desc_expr = op->args[7];
std::string scale_out = this->PrintExpr(op->args[8]);
std::string mask0 = this->PrintExpr(op->args[9]);
std::string mask1 = this->PrintExpr(op->args[10]);
std::string mask2 = this->PrintExpr(op->args[11]);
std::string mask3 = this->PrintExpr(op->args[12]);
bool enable_ws = Downcast<Bool>(op->args[13])->value;
auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype);
need_tcgen05mma_instruction_h_ = true;
this->PrintIndent();
std::string tcgen05_call =
"tl::(tcgen05_name)<(CType)>(uint64_t((desc_a) + (A_offset)), "
"uint64_t((desc_b) + (B_offset)), (*reinterpret_cast<uint32_t*>((C))) "
"+ (C_offset), "
"(scale_out), static_cast<uint32_t>((desc_val)), (mask0), (mask1), "
"(mask2), (mask3));\n";
tl::codegen::Replacer replacer;
replacer.register_rule("(CType)",
tl::codegen::ptx::DTypeEnumToString(dtype_c_enum));
replacer.register_rule("(desc_a)", a_desc);
replacer.register_rule("(A_offset)", A_offset);
replacer.register_rule("(desc_b)", b_desc);
replacer.register_rule("(B_offset)", B_offset);
replacer.register_rule("(C)", c_ref);
replacer.register_rule("(C_offset)", c_offset);
replacer.register_rule("(tcgen05_name)",
enable_ws ? "tcgen05mma_ws_ss" : "tcgen05mma_ss");
replacer.register_rule("(scale_out)", scale_out);
replacer.register_rule("(desc_val)", this->PrintExpr(desc_expr));
replacer.register_rule("(mask0)", mask0);
replacer.register_rule("(mask1)", mask1);
replacer.register_rule("(mask2)", mask2);
replacer.register_rule("(mask3)", mask3);
tcgen05_call = replacer.rewrite(tcgen05_call);
this->stream << tcgen05_call;
} else if (op->op.same_as(tl::ptx_tcgen05_mma_ts())) {
// TS: A from TMEM, B from SMEM (desc)
ICHECK_EQ(op->args.size(), 13U)
<< "ptx_tcgen05_mma_ts args is " << op->args;
std::string kind_dtype = Downcast<StringImm>(op->args[0])->value;
std::string a_ref = this->PrintExpr(op->args[1]);
std::string A_offset = this->PrintExpr(op->args[2]);
std::string b_desc = this->PrintExpr(op->args[3]);
std::string B_offset = this->PrintExpr(op->args[4]);
std::string c_ref = this->PrintExpr(op->args[5]);
std::string c_offset = this->PrintExpr(op->args[6]);
PrimExpr desc_expr = op->args[7];
std::string scale_out = this->PrintExpr(op->args[8]);
std::string mask0 = this->PrintExpr(op->args[9]);
std::string mask1 = this->PrintExpr(op->args[10]);
std::string mask2 = this->PrintExpr(op->args[11]);
std::string mask3 = this->PrintExpr(op->args[12]);
auto dtype_enum = tl::codegen::ptx::DTypeFromString(kind_dtype);
need_tcgen05mma_instruction_h_ = true;
this->PrintIndent();
std::string tcgen05_call =
"tl::tcgen05mma_ts<(CType)>( (*reinterpret_cast<uint32_t*>((A))) + "
"(A_offset), "
"uint64_t((desc_b) + (B_offset)), (*reinterpret_cast<uint32_t*>((C))) "
"+ (C_offset), "
"(scale_out), static_cast<uint32_t>((desc_val)), (mask0), (mask1), "
"(mask2), (mask3));\n";
tl::codegen::Replacer replacer;
replacer.register_rule("(CType)",
tl::codegen::ptx::DTypeEnumToString(dtype_enum));
replacer.register_rule("(A)", a_ref);
replacer.register_rule("(A_offset)", A_offset);
replacer.register_rule("(desc_b)", b_desc);
replacer.register_rule("(B_offset)", B_offset);
replacer.register_rule("(C)", c_ref);
replacer.register_rule("(C_offset)", c_offset);
replacer.register_rule("(scale_out)", scale_out);
replacer.register_rule("(desc_val)", this->PrintExpr(desc_expr));
replacer.register_rule("(mask0)", mask0);
replacer.register_rule("(mask1)", mask1);
replacer.register_rule("(mask2)", mask2);
replacer.register_rule("(mask3)", mask3);
tcgen05_call = replacer.rewrite(tcgen05_call);
this->stream << tcgen05_call;
} else if (op->op.same_as(tl::tcgen05_mma_arrive())) {
ICHECK_EQ(op->args.size(), 1U) << "tcgen05_mma_arrive expects 1 argument";
need_tcgen05_common_h_ = true;
this->PrintIndent();
this->stream << "tl::tcgen05_mma_arrive(" << this->PrintExpr(op->args[0])
<< ");\n";
} else if (op->op.same_as(builtin::ptx_ldmatrix())) { } else if (op->op.same_as(builtin::ptx_ldmatrix())) {
// arg 0: whether the matrix is loaded in column major format or not. // arg 0: whether the matrix is loaded in column major format or not.
// arg 1: number of matrices to load. // arg 1: number of matrices to load.
...@@ -2214,19 +2404,35 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { ...@@ -2214,19 +2404,35 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
os << ")"; os << ")";
} else if (op->op.same_as(tl::tl_shuffle_elect())) { } else if (op->op.same_as(tl::tl_shuffle_elect())) {
os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()"; os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()";
} else if (op->op.same_as(tl::initialize_descriptor())) { } else if (op->op.same_as(tl::initialize_wgmma_descriptor())) {
ICHECK(op->args.size() == 5) ICHECK(op->args.size() == 5)
<< "tl_initialize_descriptor expects 5 arguments but got " << "tl_initialize_wgmma_descriptor expects 5 arguments but got "
<< op->args.size(); << op->args.size();
auto descriptor = op->args[0]; auto descriptor = op->args[0];
auto start_address = op->args[1]; auto start_address = op->args[1];
auto layout_type = op->args[2]; auto layout_type = op->args[2];
auto leading_byte_offset = op->args[3]; auto leading_byte_offset = op->args[3];
auto stride_byte_offset = op->args[4]; auto stride_byte_offset = op->args[4];
os << "tl::initialize_descriptor<" << PrintExpr(layout_type) << ", " os << "tl::initialize_wgmma_descriptor<" << PrintExpr(layout_type) << ", "
<< PrintExpr(leading_byte_offset) << ", " << PrintExpr(leading_byte_offset) << ", "
<< PrintExpr(stride_byte_offset) << ">(" << PrintExpr(descriptor) << ", " << PrintExpr(stride_byte_offset) << ">(" << PrintExpr(descriptor) << ", "
<< PrintExpr(start_address) << ")"; << PrintExpr(start_address) << ")";
} else if (op->op.same_as(tl::initialize_tcgen05_descriptor())) {
ICHECK(op->args.size() == 7)
<< "tl_initialize_tcgen05_descriptor expects 7 arguments but got "
<< op->args.size();
auto descriptor = op->args[0];
auto start_address = op->args[1];
auto leading_byte_offset = op->args[2];
auto stride_byte_offset = op->args[3];
auto base_offset = op->args[4];
auto leading_abs = op->args[5];
auto swizzle_mode = op->args[6];
os << "tl::initialize_tcgen05_descriptor(" << PrintExpr(descriptor) << ", "
<< PrintExpr(start_address) << ", " << PrintExpr(leading_byte_offset)
<< ", " << PrintExpr(stride_byte_offset) << ", "
<< PrintExpr(base_offset) << ", " << PrintExpr(leading_abs) << ", "
<< PrintExpr(swizzle_mode) << ")";
} else if (op->op.same_as(tl::increase_descriptor_offset())) { } else if (op->op.same_as(tl::increase_descriptor_offset())) {
ICHECK(op->args.size() == 2) ICHECK(op->args.size() == 2)
<< "tl_increase_descriptor_offset expects 2 arguments but got " << "tl_increase_descriptor_offset expects 2 arguments but got "
...@@ -2377,8 +2583,12 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) { ...@@ -2377,8 +2583,12 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
<< "Accumulator only support half, float and int type for now"; << "Accumulator only support half, float and int type for now";
} }
PrintWmmaScope(scope, op->dtype, buffer, stream); PrintWmmaScope(scope, op->dtype, buffer, stream);
} else if (scope == "local.descriptor") { } else if (scope == "local.descriptor.wgmma") {
stream << "tl::GmmaDescriptor " << vid << ";\n"; stream << "tl::GmmaDescriptor " << vid << ";\n";
} else if (scope == "local.descriptor.tcgen05_smem") {
stream << "tl::Tcgen05SMemDescriptor " << vid << ";\n";
} else if (scope == "local.descriptor.tcgen05_instr") {
stream << "tl::Tcgen05InstrDescriptor " << vid << ";\n";
} else { } else {
PrintStorageScope(scope, stream); PrintStorageScope(scope, stream);
PrintType(op->dtype, stream); PrintType(op->dtype, stream);
...@@ -2420,7 +2630,7 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) { ...@@ -2420,7 +2630,7 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
init = user_init; init = user_init;
} }
stream << ' ' << vid << " = " << PrintExpr(init) << ";\n"; stream << ' ' << vid << " = " << PrintExpr(init) << ";\n";
} else if (scope != "local.descriptor") { } else if (scope.find("local.descriptor") != 0) {
ICHECK(false) << "Unsupported scope: " << scope; ICHECK(false) << "Unsupported scope: " << scope;
} }
} }
......
...@@ -108,6 +108,14 @@ private: ...@@ -108,6 +108,14 @@ private:
bool need_math_constants_h_{false}; bool need_math_constants_h_{false};
// whether need mma.h // whether need mma.h
bool need_mma_h_{false}; bool need_mma_h_{false};
// whether need tl mma instruction header
bool need_mma_instruction_h_{false};
// whether need tl wgmma instruction header
bool need_wgmma_instruction_h_{false};
// whether need tl tcgen05mma instruction header
bool need_tcgen05mma_instruction_h_{false};
// whether need tcgen_05 common header
bool need_tcgen05_common_h_{false};
// whether need cast_smem_ptr_to_int helper function // whether need cast_smem_ptr_to_int helper function
bool need_cast_smem_ptr_to_int_{false}; bool need_cast_smem_ptr_to_int_{false};
// whether need cooperative_groups.h // whether need cooperative_groups.h
......
...@@ -74,9 +74,9 @@ DataType DTypeFromString(const std::string str) { ...@@ -74,9 +74,9 @@ DataType DTypeFromString(const std::string str) {
return DataType::kInt64; return DataType::kInt64;
} else if (str == "uint64" || str == ".u64") { } else if (str == "uint64" || str == ".u64") {
return DataType::kUInt64; return DataType::kUInt64;
} else if (str == "e4m3" || str == ".e4m3") { } else if (str == "float8_e4m3" || str == "e4m3" || str == ".e4m3") {
return DataType::kFloat8_e4m3; return DataType::kFloat8_e4m3;
} else if (str == "e5m2" || str == ".e5m2") { } else if (str == "float8_e5m2" || str == "e5m2" || str == ".e5m2") {
return DataType::kFloat8_e5m2; return DataType::kFloat8_e5m2;
} else if (str == "float16" || str == "fp16" || str == ".f16") { } else if (str == "float16" || str == "fp16" || str == ".f16") {
return DataType::kFloat16; return DataType::kFloat16;
...@@ -1529,5 +1529,20 @@ std::string PrintWaitBarrierAsm(const std::string &barrier) { ...@@ -1529,5 +1529,20 @@ std::string PrintWaitBarrierAsm(const std::string &barrier) {
return predicated_asm_code; return predicated_asm_code;
} }
std::string GetMMARegisterType(const ptx::DataType &dtype) {
switch (dtype) {
case ptx::DataType::kInt32:
return "unsigned";
case ptx::DataType::kUInt32:
return "unsigned";
case ptx::DataType::kFloat32:
return "float";
case ptx::DataType::kFloat64:
return "double";
default:
return "unsigned";
}
}
} // namespace codegen } // namespace codegen
} // namespace tvm::tl } // namespace tvm::tl
...@@ -269,6 +269,11 @@ std::string PrintArriveBarrierExpectTxAsm(const std::string &barrier, ...@@ -269,6 +269,11 @@ std::string PrintArriveBarrierExpectTxAsm(const std::string &barrier,
*/ */
std::string PrintWaitBarrierAsm(const std::string &barrier); std::string PrintWaitBarrierAsm(const std::string &barrier);
/*!
* \brief Return the register-level C++ type used by MMA fragments.
*/
std::string GetMMARegisterType(const ptx::DataType &dtype);
} // namespace codegen } // namespace codegen
} // namespace tvm::tl } // namespace tvm::tl
......
...@@ -288,6 +288,138 @@ union GmmaDescriptor { ...@@ -288,6 +288,138 @@ union GmmaDescriptor {
} }
}; };
union Tcgen05SMemDescriptor {
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor() noexcept : desc_(0) {}
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor(uint64_t desc) noexcept
: desc_(desc) {}
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor(
Tcgen05SMemDescriptor const &t) noexcept
: desc_(t.desc_) {}
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor(
Tcgen05SMemDescriptor &&t) noexcept
: desc_(t.desc_) {}
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor &
operator=(Tcgen05SMemDescriptor const &t) noexcept {
desc_ = t.desc_;
return *this;
}
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor &
operator=(Tcgen05SMemDescriptor &&t) noexcept {
desc_ = t.desc_;
return *this;
}
uint64_t desc_;
uint32_t reg32_[2];
// Bitfield implementation avoids the need for shifts in assignment
struct {
// start_address, bit [0,14), 4LSB not included
uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused
// leading dimension byte offset, bit [16,30), 4LSB not included
uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused
// stride dimension byte offset, bit [32,46), 4LSB not included
uint16_t stride_byte_offset_ : 14,
version_ : 2; // 14 bits [0,14), 2 bits [14,16)
// base_offset, bit [49,52). leading_byte_offset_mode, bit [52,53).
uint8_t : 1, base_offset_ : 3, lbo_mode_ : 1,
: 3; // 1 bit unused, 3 bits [1,4), 1 bit [4,5), 3 bits unused
// layout type, bit [61,64), SWIZZLE_NONE matrix descriptor = 0,
// SWIZZLE_128B matrix descriptor = 2, SWIZZLE_64B descriptor = 4,
// SWIZZLE_32B descriptor = 6, SWIZZLE_128B_BASE32B = 1, N/A = 3, N/A = 5,
// N/A = 7
uint8_t : 5, layout_type_ : 3; // 6 bits unused, 3 bits [5,8)
} bitfield;
// Separate the field, as we may only update one part of desc
struct {
uint32_t lo;
uint32_t hi;
} words;
CUTE_HOST_DEVICE constexpr operator uint64_t() const noexcept {
return desc_;
}
template <typename T>
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor
operator+(const T &offset) const {
Tcgen05SMemDescriptor ret;
// Address addition is in units of 16 bytes (4 LSB not encoded)
ret.reg32_[0] = reg32_[0] + (uint32_t(offset) >> 4);
ret.reg32_[1] = reg32_[1];
return ret;
}
};
//
// Tcgen05 instruction descriptor (wraps cute::UMMA::InstrDescriptor layout)
//
union Tcgen05InstrDescriptor {
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor() noexcept : desc_(0) {}
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor(uint32_t desc) noexcept
: desc_(desc) {}
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor(
Tcgen05InstrDescriptor const &t) noexcept
: desc_(t.desc_) {}
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor(
Tcgen05InstrDescriptor &&t) noexcept
: desc_(t.desc_) {}
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor &
operator=(Tcgen05InstrDescriptor const &t) noexcept {
desc_ = t.desc_;
return *this;
}
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor &
operator=(Tcgen05InstrDescriptor &&t) noexcept {
desc_ = t.desc_;
return *this;
}
uint32_t desc_;
uint16_t reg16_[2];
// Bitfield implementation mirrors cute::UMMA::InstrDescriptor
struct {
// bit [ 0, 2) : Sparse meta data id2
uint16_t sparse_id2_ : 2,
// bit [ 2, 3) : 0 = dense. 1 = sparse. Only valid for
// F32F16/S8/MXF8F6F4
sparse_flag_ : 1,
// bit [ 3, 4) : 0 = no saturate. 1 = saturate. Only valid for S8
saturate_ : 1,
// bit [ 4, 6) : 0 = F16. 1 = F32, 2 = S32
c_format_ : 2,
// padding
: 1,
// bit [ 7,10) : see UMMA format encoding
a_format_ : 3,
// bit [10,13) : see UMMA format encoding
b_format_ : 3,
// bit [13,14) : 0 = no negate. 1 = negate
a_negate_ : 1,
// bit [14,15) : 0 = no negate. 1 = negate
b_negate_ : 1,
// bit [15,16) : 0 = K-major. 1 = MN-major
a_major_ : 1;
// Upper 16 bits
uint16_t b_major_ : 1, // bit [16,17)
n_dim_ : 6, // bit [17,23) : 3 LSBs not included
: 1, // padding
m_dim_ : 5, // bit [24,29) : 4 LSBs not included
: 1, // padding
max_shift_ : 2; // bit [30,32)
} bitfield;
// Decay to a uint32_t
CUTE_HOST_DEVICE constexpr explicit operator uint32_t() const noexcept {
return desc_;
}
};
// Any // Any
template <typename T> TL_DEVICE bool Any(T *a, int size) { template <typename T> TL_DEVICE bool Any(T *a, int size) {
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
...@@ -326,8 +458,8 @@ TL_DEVICE void __sync_thread_partial() { ...@@ -326,8 +458,8 @@ TL_DEVICE void __sync_thread_partial() {
template <int layout_type = 0, int leading_byte_offset = 0, template <int layout_type = 0, int leading_byte_offset = 0,
int stride_byte_offset = 0, typename T> int stride_byte_offset = 0, typename T>
TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor, TL_DEVICE void initialize_wgmma_descriptor(GmmaDescriptor &descriptor,
T *start_address) { T *start_address) {
descriptor.bitfield.start_address_ = descriptor.bitfield.start_address_ =
cute::cast_smem_ptr_to_uint(start_address) >> 4; cute::cast_smem_ptr_to_uint(start_address) >> 4;
descriptor.bitfield.layout_type_ = layout_type; descriptor.bitfield.layout_type_ = layout_type;
...@@ -336,6 +468,23 @@ TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor, ...@@ -336,6 +468,23 @@ TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor,
descriptor.bitfield.stride_byte_offset_ = stride_byte_offset; descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
} }
template <typename T>
TL_DEVICE void
initialize_tcgen05_descriptor(Tcgen05SMemDescriptor &descriptor,
T *start_address, int leading_byte_offset,
int stride_byte_offset, int base_offset,
bool leading_is_absolute, int swizzle_mode) {
descriptor.bitfield.start_address_ =
static_cast<uint16_t>(cast_smem_ptr_to_uint(start_address) >> 4);
descriptor.bitfield.leading_byte_offset_ = leading_byte_offset;
descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
descriptor.bitfield.version_ = 1;
descriptor.bitfield.base_offset_ = base_offset & 0x7;
descriptor.bitfield.lbo_mode_ = leading_is_absolute ? 1 : 0;
descriptor.bitfield.layout_type_ = swizzle_mode & 0x7;
}
template <typename T> template <typename T>
TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor, TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor,
T offset) { T offset) {
......
#pragma once
#include "../common.h"
#include <cute/arch/mma_sm80.hpp>
#include <cute/arch/mma_sm89.hpp>
#include <type_traits>
#include <utility>
namespace tl {
#ifndef TL_ALWAYS_FALSE_V_DEFINED
#define TL_ALWAYS_FALSE_V_DEFINED
template <class> inline constexpr bool always_false_v = false;
#endif
namespace detail {
template <class Impl> struct MmaImplTraits {
using DReg = std::remove_extent_t<typename Impl::DRegisters>;
using AReg = std::remove_extent_t<typename Impl::ARegisters>;
using BReg = std::remove_extent_t<typename Impl::BRegisters>;
using CReg = std::remove_extent_t<typename Impl::CRegisters>;
static constexpr int kDRegs = std::extent_v<typename Impl::DRegisters>;
static constexpr int kARegs = std::extent_v<typename Impl::ARegisters>;
static constexpr int kBRegs = std::extent_v<typename Impl::BRegisters>;
static constexpr int kCRegs = std::extent_v<typename Impl::CRegisters>;
};
template <class Impl, size_t... DIdx, size_t... AIdx, size_t... BIdx,
size_t... CIdx>
TL_DEVICE void
call_fma_impl(typename MmaImplTraits<Impl>::DReg *d,
const typename MmaImplTraits<Impl>::AReg *a,
const typename MmaImplTraits<Impl>::BReg *b,
const typename MmaImplTraits<Impl>::CReg *c,
std::index_sequence<DIdx...>, std::index_sequence<AIdx...>,
std::index_sequence<BIdx...>, std::index_sequence<CIdx...>) {
Impl::fma(d[DIdx]..., a[AIdx]..., b[BIdx]..., c[CIdx]...);
}
template <class Impl>
TL_DEVICE void call_fma(typename MmaImplTraits<Impl>::DReg *d,
const typename MmaImplTraits<Impl>::AReg *a,
const typename MmaImplTraits<Impl>::BReg *b,
const typename MmaImplTraits<Impl>::CReg *c) {
call_fma_impl<Impl>(d, a, b, c,
std::make_index_sequence<MmaImplTraits<Impl>::kDRegs>{},
std::make_index_sequence<MmaImplTraits<Impl>::kARegs>{},
std::make_index_sequence<MmaImplTraits<Impl>::kBRegs>{},
std::make_index_sequence<MmaImplTraits<Impl>::kCRegs>{});
}
template <DataType AType, DataType BType, DataType CType, int M, int N, int K,
bool TransA, bool TransB, bool Saturate>
struct MmaDispatcher {
using CRegType = void;
using ARegType = void;
using BRegType = void;
static TL_DEVICE void exec(CRegType *, const ARegType *, const BRegType *,
const CRegType *) {
static_assert(always_false_v<std::integral_constant<int, M>>,
"tl::mma_sync: unsupported configuration");
}
};
#define TL_DEFINE_MMA_DISPATCHER(ATypeEnum, BTypeEnum, CTypeEnum, MValue, \
NValue, KValue, TransAValue, TransBValue, \
SaturateValue, ImplType) \
template <> \
struct MmaDispatcher<DataType::ATypeEnum, DataType::BTypeEnum, \
DataType::CTypeEnum, MValue, NValue, KValue, \
TransAValue, TransBValue, SaturateValue> { \
using Impl = ImplType; \
using Traits = MmaImplTraits<Impl>; \
using CRegType = typename Traits::DReg; \
using ARegType = typename Traits::AReg; \
using BRegType = typename Traits::BReg; \
static_assert( \
std::is_same_v<typename Traits::DReg, typename Traits::CReg>, \
"tl::mma_sync requires matching accumulator/output regs"); \
static TL_DEVICE void exec(CRegType *d, const ARegType *a, \
const BRegType *b, const CRegType *c) { \
call_fma<Impl>(d, a, b, c); \
} \
};
// FP16 inputs (TN layout: A row-major, B column-major)
TL_DEFINE_MMA_DISPATCHER(kFloat16, kFloat16, kFloat16, 16, 8, 16, false, true,
false, cute::SM80_16x8x16_F16F16F16F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat16, kFloat16, kFloat32, 16, 8, 16, false, true,
false, cute::SM80_16x8x16_F32F16F16F32_TN)
// BF16 inputs
TL_DEFINE_MMA_DISPATCHER(kBFloat16, kBFloat16, kFloat32, 16, 8, 16, false, true,
false, cute::SM80_16x8x16_F32BF16BF16F32_TN)
// INT8 inputs (k32)
TL_DEFINE_MMA_DISPATCHER(kInt8, kInt8, kInt32, 16, 8, 32, false, true, false,
cute::SM80_16x8x32_S32S8S8S32_TN)
TL_DEFINE_MMA_DISPATCHER(kUInt8, kUInt8, kInt32, 16, 8, 32, false, true, false,
cute::SM80_16x8x32_S32U8U8S32_TN)
// INT4 inputs (k32)
TL_DEFINE_MMA_DISPATCHER(kInt4, kInt4, kInt32, 16, 8, 32, false, true, false,
cute::SM80_16x8x32_S32S4S4S32_TN)
TL_DEFINE_MMA_DISPATCHER(kUInt4, kUInt4, kInt32, 16, 8, 32, false, true, false,
cute::SM80_16x8x32_S32U4U4S32_TN)
// FP8 inputs (k32)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F16E4M3E4M3F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F32E4M3E4M3F32_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F16E4M3E5M2F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F32E4M3E5M2F32_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F16E5M2E4M3F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F32E5M2E4M3F32_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F16E5M2E5M2F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F32E5M2E5M2F32_TN)
#undef TL_DEFINE_MMA_DISPATCHER
} // namespace detail
template <DataType AType, DataType BType, DataType CType, int M, int N, int K,
bool TransA, bool TransB, bool Saturate = false>
TL_DEVICE void mma_sync(
typename detail::MmaDispatcher<AType, BType, CType, M, N, K, TransA, TransB,
Saturate>::CRegType *c,
const typename detail::MmaDispatcher<AType, BType, CType, M, N, K, TransA,
TransB, Saturate>::ARegType *a,
const typename detail::MmaDispatcher<AType, BType, CType, M, N, K, TransA,
TransB, Saturate>::BRegType *b) {
using Dispatcher = detail::MmaDispatcher<AType, BType, CType, M, N, K, TransA,
TransB, Saturate>;
static_assert(!std::is_void_v<typename Dispatcher::CRegType>,
"tl::mma_sync: unsupported configuration");
Dispatcher::exec(c, a, b, c);
}
} // namespace tl
#pragma once
#include "../common.h"
#include <cute/arch/cluster_sm90.hpp>
namespace tl {
#ifndef TL_ALWAYS_FALSE_V_DEFINED
#define TL_ALWAYS_FALSE_V_DEFINED
template <class> inline constexpr bool always_false_v = false;
#endif
// Generic declaration: unsupported by default
template <DataType C_type>
TL_DEVICE void
tcgen05mma_ss(uint64_t const & /*desc_a*/, uint64_t const & /*desc_b*/,
uint32_t const & /*tmem_c*/, uint32_t const & /*scalec*/,
uint32_t const & /*desc_val*/, int const & /*mask0*/,
int const & /*mask1*/, int const & /*mask2*/,
int const & /*mask3*/) {
static_assert(
always_false_v<std::integral_constant<int, static_cast<int>(C_type)>>,
"tl::tcgen05mma_ss: unsupported accumulator type");
}
// TS variants: A from TMEM, B from SMEM (desc)
// Generic declaration: unsupported by default
template <DataType C_type>
TL_DEVICE void
tcgen05mma_ts(uint32_t const & /*tmem_a*/, uint64_t const & /*desc_b*/,
uint32_t const & /*tmem_c*/, uint32_t const & /*scalec*/,
uint32_t const & /*desc_val*/, int const & /*mask0*/,
int const & /*mask1*/, int const & /*mask2*/,
int const & /*mask3*/) {
static_assert(
always_false_v<std::integral_constant<int, static_cast<int>(C_type)>>,
"tl::tcgen05mma_ts: unsupported accumulator type");
}
// F16/BF16 instruction kind (maps to kind::f16)
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kFloat16>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f16 [%0], [%1], %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// BF16 maps to the same f16-kind instruction
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kBFloat16>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ts<DataType::kFloat16>(tmem_a, desc_b, tmem_c, scalec, desc_val,
mask0, mask1, mask2, mask3);
}
// TF32 instruction kind
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kTensorFloat32>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::tf32 [%0], [%1], %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// INT8 instruction kind
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kInt8>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::i8 [%0], [%1], %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// FP8 family instruction kind (maps to f8f6f4)
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kFloat8_e4m3>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], [%1], %2, %3, "
"{%5, %6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kFloat8_e5m2>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ts<DataType::kFloat8_e4m3>(tmem_a, desc_b, tmem_c, scalec,
desc_val, mask0, mask1, mask2, mask3);
}
// F16/BF16 instruction kind (maps to kind::f16)
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kFloat16>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
// idescE upper 32 bits carry the instruction descriptor; lower 32 ignored for
// SS Load TMEM base from shared memory slot handled by caller
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// BF16 maps to the same f16-kind instruction
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kBFloat16>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ss<DataType::kFloat16>(desc_a, desc_b, tmem_c, scalec, desc_val,
mask0, mask1, mask2, mask3);
}
// TF32 instruction kind
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kTensorFloat32>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::tf32 [%0], %1, %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// INT8 instruction kind
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kInt8>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::i8 [%0], %1, %2, %3, {%5, %6, "
"%7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// FP8 family instruction kind (maps to f8f6f4)
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kFloat8_e4m3>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kFloat8_e5m2>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ss<DataType::kFloat8_e4m3>(desc_a, desc_b, tmem_c, scalec,
desc_val, mask0, mask1, mask2, mask3);
}
// WS variants: tcgen05.mma.ws.cta_group::1.kind::xxx
// Generic declaration falls back to static assert
template <DataType C_type>
TL_DEVICE void
tcgen05mma_ws_ss(uint64_t const & /*desc_a*/, uint64_t const & /*desc_b*/,
uint32_t const & /*tmem_c*/, uint32_t const & /*scalec*/,
uint32_t const & /*desc_val*/, int const & /*mask0*/,
int const & /*mask1*/, int const & /*mask2*/,
int const & /*mask3*/) {
static_assert(
always_false_v<std::integral_constant<int, static_cast<int>(C_type)>>,
"tl::tcgen05mma_ws_ss: unsupported accumulator type");
}
// F16/BF16 ws
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kFloat16>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec));
}
}
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kBFloat16>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ws_ss<DataType::kFloat16>(desc_a, desc_b, tmem_c, scalec, desc_val,
mask0, mask1, mask2, mask3);
}
// TF32 ws
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kTensorFloat32>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.ws.cta_group::1.kind::tf32 [%0], %1, %2, %3, p, 0; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec));
}
}
// INT8 ws
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kInt8>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.ws.cta_group::1.kind::i8 [%0], %1, %2, %3, p, 0; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec));
}
}
// FP8 ws (maps to f8f6f4)
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kFloat8_e4m3>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.ws.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p, 0; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec));
}
}
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kFloat8_e5m2>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ws_ss<DataType::kFloat8_e4m3>(
desc_a, desc_b, tmem_c, scalec, desc_val, mask0, mask1, mask2, mask3);
}
} // namespace tl
#pragma once #pragma once
#include "../common.h" #include "../common.h"
#include "cute/arch/mma_sm90_gmma.hpp" #include <cute/arch/mma_sm90_gmma.hpp>
#include <cute/arch/mma_sm90_gmma_ext.hpp>
#include <type_traits>
#include <utility>
namespace tl { namespace tl {
#ifndef TL_ALWAYS_FALSE_V_DEFINED
#define TL_ALWAYS_FALSE_V_DEFINED
template <class> inline constexpr bool always_false_v = false; template <class> inline constexpr bool always_false_v = false;
#endif
// 主类模板 - 移除默认参数,因为特化不能有默认参数 namespace detail {
template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
int K, bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
printf("DEBUG: WgmmaSSImpl fallback - A_type=%d (kFloat16=%d), B_type=%d, "
"C_type=%d, M=%d, N=%d, K=%d, tnspA=%d, tnspB=%d, scaleA=%d, "
"scaleB=%d\n",
(int)A_type, (int)DataType::kFloat16, (int)B_type, (int)C_type, M, N,
K, (int)tnspA, (int)tnspB, scaleA, scaleB);
// 暂时注释掉 static_assert 来看调试输出
// static_assert(always_false_v<decltype(c)>,
// "wgmma_ss: No specialization available for given template
// parameters!");
};
};
// ================================= F16 x F16 -> F16
// =================================
// M64N8K16 F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
64, 8, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// M64N16K16 F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
64, 16, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %6, 0;\n"
"wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 "
"{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// M64N32K16 F16 template <bool IsMnMajor> struct MajorValue {
template <bool tnspA, bool tnspB, int scaleA, int scaleB> static constexpr auto value =
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16, IsMnMajor ? cute::SM90::GMMA::Major::MN : cute::SM90::GMMA::Major::K;
64, 32, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile(
"{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %10, 0;\n"
"wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)),
"n"(int32_t(tnspB)));
}
}; };
// M64N64K16 F16 template <int Scale> struct ScaleInValue {
template <bool tnspA, bool tnspB, int scaleA, int scaleB> static_assert(Scale == 1 || Scale == -1,
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16, "tl::wgmma requires scale factors of +1 or -1.");
64, 64, 16, tnspA, tnspB, scaleA, scaleB> { static constexpr auto value = Scale == 1 ? cute::SM90::GMMA::ScaleIn::One
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, : cute::SM90::GMMA::ScaleIn::Neg;
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %18, 0;\n"
"wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
"%8, %9, %10, %11, %12, %13, %14, %15},"
" %16, %17, p, %19, %20, %21, %22;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]),
"+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]),
"+r"(c[14]), "+r"(c[15])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
}; };
// M64N96K16 F16 template <int Scale>
template <bool tnspA, bool tnspB, int scaleA, int scaleB> inline constexpr bool IsValidScale = (Scale == 1 || Scale == -1);
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16,
64, 96, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %26, 0;\n"
"wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
"%8, %9, %10, %11, %12, %13, %14, %15, "
"%16, %17, %18, %19, %20, %21, %22, %23}, "
"%24, %25, p, %27, %28, %29, %30;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]),
"+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]),
"+r"(c[14]), "+r"(c[15]), "+r"(c[16]), "+r"(c[17]),
"+r"(c[18]), "+r"(c[19]), "+r"(c[20]), "+r"(c[21]),
"+r"(c[22]), "+r"(c[23])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// M64N128K16 F16 template <class Impl> struct CallWgmmaSS {
template <bool tnspA, bool tnspB, int scaleA, int scaleB> using CReg = std::remove_extent_t<typename Impl::CRegisters>;
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16, static constexpr int kCRegs = std::extent_v<typename Impl::CRegisters>;
64, 128, 16, tnspA, tnspB, scaleA, scaleB> { static_assert(sizeof(CReg) == sizeof(uint32_t),
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, "tl::wgmma_ss expects 32-bit accumulator registers.");
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %34, 0;\n"
"wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
"%8, %9, %10, %11, %12, %13, %14, %15, "
"%16, %17, %18, %19, %20, %21, %22, %23, "
"%24, %25, %26, %27, %28, %29, %30, %31}, "
"%32, %33, p, %35, %36, %37, %38;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]),
"+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]),
"+r"(c[14]), "+r"(c[15]), "+r"(c[16]), "+r"(c[17]),
"+r"(c[18]), "+r"(c[19]), "+r"(c[20]), "+r"(c[21]),
"+r"(c[22]), "+r"(c[23]), "+r"(c[24]), "+r"(c[25]),
"+r"(c[26]), "+r"(c[27]), "+r"(c[28]), "+r"(c[29]),
"+r"(c[30]), "+r"(c[31])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// M64N192K16 F16 template <size_t... Idx>
template <bool tnspA, bool tnspB, int scaleA, int scaleB> TL_DEVICE static void Run(uint64_t desc_a, uint64_t desc_b, CReg *c,
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16, cute::SM90::GMMA::ScaleOut scale,
64, 192, 16, tnspA, tnspB, scaleA, scaleB> { std::index_sequence<Idx...>) {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, Impl::fma(desc_a, desc_b, c[Idx]..., scale);
bool scale_out) {
asm volatile(
"{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %50, 0;\n"
"wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
"%8, %9, %10, %11, %12, %13, %14, %15, "
"%16, %17, %18, %19, %20, %21, %22, %23, "
"%24, %25, %26, %27, %28, %29, %30, %31, "
"%32, %33, %34, %35, %36, %37, %38, %39, "
"%40, %41, %42, %43, %44, %45, %46, %47}, "
"%48, %49, p, %51, %52, %53, %54;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]),
"+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]), "+r"(c[14]),
"+r"(c[15]), "+r"(c[16]), "+r"(c[17]), "+r"(c[18]), "+r"(c[19]),
"+r"(c[20]), "+r"(c[21]), "+r"(c[22]), "+r"(c[23]), "+r"(c[24]),
"+r"(c[25]), "+r"(c[26]), "+r"(c[27]), "+r"(c[28]), "+r"(c[29]),
"+r"(c[30]), "+r"(c[31]), "+r"(c[32]), "+r"(c[33]), "+r"(c[34]),
"+r"(c[35]), "+r"(c[36]), "+r"(c[37]), "+r"(c[38]), "+r"(c[39]),
"+r"(c[40]), "+r"(c[41]), "+r"(c[42]), "+r"(c[43]), "+r"(c[44]),
"+r"(c[45]), "+r"(c[46]), "+r"(c[47])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)),
"n"(int32_t(tnspB)));
} }
};
// M64N256K16 F16 TL_DEVICE static void exec(uint64_t desc_a, uint64_t desc_b, uint32_t *c_raw,
template <bool tnspA, bool tnspB, int scaleA, int scaleB> bool scale_out) {
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat16, auto scale = scale_out ? cute::SM90::GMMA::ScaleOut::One
64, 256, 16, tnspA, tnspB, scaleA, scaleB> { : cute::SM90::GMMA::ScaleOut::Zero;
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, auto c = reinterpret_cast<CReg *>(c_raw);
bool scale_out) { Run(desc_a, desc_b, c, scale, std::make_index_sequence<kCRegs>{});
asm volatile(
"{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %66, 0;\n"
"wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
"%8, %9, %10, %11, %12, %13, %14, %15, "
"%16, %17, %18, %19, %20, %21, %22, %23, "
"%24, %25, %26, %27, %28, %29, %30, %31, "
"%32, %33, %34, %35, %36, %37, %38, %39, "
"%40, %41, %42, %43, %44, %45, %46, %47, "
"%48, %49, %50, %51, %52, %53, %54, %55, "
"%56, %57, %58, %59, %60, %61, %62, %63}, "
"%64, %65, p, %67, %68, %69, %70;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]),
"+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]), "+r"(c[14]),
"+r"(c[15]), "+r"(c[16]), "+r"(c[17]), "+r"(c[18]), "+r"(c[19]),
"+r"(c[20]), "+r"(c[21]), "+r"(c[22]), "+r"(c[23]), "+r"(c[24]),
"+r"(c[25]), "+r"(c[26]), "+r"(c[27]), "+r"(c[28]), "+r"(c[29]),
"+r"(c[30]), "+r"(c[31]), "+r"(c[32]), "+r"(c[33]), "+r"(c[34]),
"+r"(c[35]), "+r"(c[36]), "+r"(c[37]), "+r"(c[38]), "+r"(c[39]),
"+r"(c[40]), "+r"(c[41]), "+r"(c[42]), "+r"(c[43]), "+r"(c[44]),
"+r"(c[45]), "+r"(c[46]), "+r"(c[47]), "+r"(c[48]), "+r"(c[49]),
"+r"(c[50]), "+r"(c[51]), "+r"(c[52]), "+r"(c[53]), "+r"(c[54]),
"+r"(c[55]), "+r"(c[56]), "+r"(c[57]), "+r"(c[58]), "+r"(c[59]),
"+r"(c[60]), "+r"(c[61]), "+r"(c[62]), "+r"(c[63])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)),
"n"(int32_t(tnspB)));
} }
}; };
// ================================= F16 x F16 -> F32 template <class Impl> struct CallWgmmaRS {
// ================================= using AReg = std::remove_extent_t<typename Impl::ARegisters>;
using CReg = std::remove_extent_t<typename Impl::CRegisters>;
// M64N8K16 F16->F32 static constexpr int kARegs = std::extent_v<typename Impl::ARegisters>;
template <bool tnspA, bool tnspB, int scaleA, int scaleB> static constexpr int kCRegs = std::extent_v<typename Impl::CRegisters>;
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32, static_assert(sizeof(AReg) == sizeof(uint32_t),
64, 8, 16, tnspA, tnspB, scaleA, scaleB> { "tl::wgmma_rs expects 32-bit register operands for A.");
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, static_assert(sizeof(CReg) == sizeof(uint32_t) ||
bool scale_out) { sizeof(CReg) == sizeof(float),
asm volatile("{\n" "tl::wgmma_rs expects 32-bit accumulator registers.");
".reg .pred p;\n"
"setp.ne.b32 p, %6, 0;\n" template <size_t... AIdx, size_t... CIdx>
"wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " TL_DEVICE static void
"{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" Run(const AReg *a, uint64_t desc_b, CReg *c, cute::SM90::GMMA::ScaleOut scale,
"}\n" std::index_sequence<AIdx...>, std::index_sequence<CIdx...>) {
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) Impl::fma(a[AIdx]..., desc_b, c[CIdx]..., scale);
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
} }
};
// M64N16K16 F16->F32 TL_DEVICE static void exec(const uint32_t *a_raw, uint64_t desc_b,
template <bool tnspA, bool tnspB, int scaleA, int scaleB> uint32_t *c_raw, bool scale_out) {
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32, auto scale = scale_out ? cute::SM90::GMMA::ScaleOut::One
64, 16, 16, tnspA, tnspB, scaleA, scaleB> { : cute::SM90::GMMA::ScaleOut::Zero;
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, auto a = reinterpret_cast<const AReg *>(a_raw);
bool scale_out) { auto c = reinterpret_cast<CReg *>(c_raw);
asm volatile( Run(a, desc_b, c, scale, std::make_index_sequence<kARegs>{},
"{\n" std::make_index_sequence<kCRegs>{});
".reg .pred p;\n"
"setp.ne.b32 p, %10, 0;\n"
"wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)),
"n"(int32_t(tnspB)));
} }
}; };
// M64N32K16 F16->F32 } // namespace detail
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32,
64, 32, 16, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %18, 0;\n"
"wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
"%8, %9, %10, %11, %12, %13, %14, %15}, "
"%16, %17, p, %19, %20, %21, %22;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]),
"+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]),
"+r"(c[14]), "+r"(c[15])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// M64N64K16 F16->F32 template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
template <bool tnspA, bool tnspB, int scaleA, int scaleB> int K, bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat16, DataType::kFloat16, DataType::kFloat32, struct WgmmaSSImpl {
64, 64, 16, tnspA, tnspB, scaleA, scaleB> { static_assert(detail::IsValidScale<scaleA>, "tl::wgmma_ss: invalid scaleA");
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, static_assert(detail::IsValidScale<scaleB>, "tl::wgmma_ss: invalid scaleB");
bool scale_out) { TL_DEVICE static void execute(uint64_t, uint64_t, uint32_t *, bool) {
asm volatile("{\n" static_assert(always_false_v<std::integral_constant<int, M>>,
".reg .pred p;\n" "tl::wgmma_ss: unsupported configuration");
"setp.ne.b32 p, %34, 0;\n"
"wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 "
"{%0, %1, %2, %3, %4, %5, %6, %7, "
"%8, %9, %10, %11, %12, %13, %14, %15, "
"%16, %17, %18, %19, %20, %21, %22, %23, "
"%24, %25, %26, %27, %28, %29, %30, %31}, "
"%32, %33, p, %35, %36, %37, %38;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]),
"+r"(c[5]), "+r"(c[6]), "+r"(c[7]), "+r"(c[8]), "+r"(c[9]),
"+r"(c[10]), "+r"(c[11]), "+r"(c[12]), "+r"(c[13]),
"+r"(c[14]), "+r"(c[15]), "+r"(c[16]), "+r"(c[17]),
"+r"(c[18]), "+r"(c[19]), "+r"(c[20]), "+r"(c[21]),
"+r"(c[22]), "+r"(c[23]), "+r"(c[24]), "+r"(c[25]),
"+r"(c[26]), "+r"(c[27]), "+r"(c[28]), "+r"(c[29]),
"+r"(c[30]), "+r"(c[31])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
} }
}; };
// ================================= BF16 x BF16 -> F32 template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
// ================================= int K, bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaRSImpl {
// M64N8K16 BF16->F32 static_assert(detail::IsValidScale<scaleA>, "tl::wgmma_rs: invalid scaleA");
template <bool tnspA, bool tnspB, int scaleA, int scaleB> static_assert(detail::IsValidScale<scaleB>, "tl::wgmma_rs: invalid scaleB");
struct WgmmaSSImpl<DataType::kBFloat16, DataType::kBFloat16, DataType::kFloat32, TL_DEVICE static void execute(const uint32_t *, uint64_t, uint32_t *, bool) {
64, 8, 16, tnspA, tnspB, scaleA, scaleB> { static_assert(always_false_v<std::integral_constant<int, M>>,
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, "tl::wgmma_rs: unsupported configuration");
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %6, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 "
"{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
} }
}; };
// M64N16K16 BF16->F32 #define TL_WGMMA_DEFINE_SS_GENERAL(AType, BType, CType, M, N, K, ImplName) \
template <bool tnspA, bool tnspB, int scaleA, int scaleB> template <bool tnspA, bool tnspB, int scaleA, int scaleB> \
struct WgmmaSSImpl<DataType::kBFloat16, DataType::kBFloat16, DataType::kFloat32, struct WgmmaSSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
64, 16, 16, tnspA, tnspB, scaleA, scaleB> { K, tnspA, tnspB, scaleA, scaleB> { \
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, static_assert(detail::IsValidScale<scaleA>, \
bool scale_out) { "tl::wgmma_ss: invalid scaleA"); \
asm volatile( static_assert(detail::IsValidScale<scaleB>, \
"{\n" "tl::wgmma_ss: invalid scaleB"); \
".reg .pred p;\n" using Impl = \
"setp.ne.b32 p, %10, 0;\n" cute::SM90::GMMA::ImplName<detail::MajorValue<tnspA>::value, \
"wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " detail::MajorValue<tnspB>::value, \
"{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n" detail::ScaleInValue<scaleA>::value, \
"}\n" detail::ScaleInValue<scaleB>::value>; \
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \
"+r"(c[5]), "+r"(c[6]), "+r"(c[7]) uint32_t *c, bool scale_out) { \
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), detail::CallWgmmaSS<Impl>::exec(desc_a, desc_b, c, scale_out); \
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), } \
"n"(int32_t(tnspB))); };
}
};
// ================================= TF32 x TF32 -> F32 #define TL_WGMMA_DEFINE_SS_TN(AType, BType, CType, M, N, K, ImplName) \
// ================================= template <int scaleA, int scaleB> \
struct WgmmaSSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
// M64N8K8 TF32->F32 K, false, false, scaleA, scaleB> { \
template <bool tnspA, bool tnspB, int scaleA, int scaleB> static_assert(detail::IsValidScale<scaleA>, \
struct WgmmaSSImpl<DataType::kTensorFloat32, DataType::kTensorFloat32, "tl::wgmma_ss: invalid scaleA"); \
DataType::kFloat32, 64, 8, 8, tnspA, tnspB, scaleA, scaleB> { static_assert(detail::IsValidScale<scaleB>, \
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, "tl::wgmma_ss: invalid scaleB"); \
bool scale_out) { using Impl = \
asm volatile("{\n" cute::SM90::GMMA::ImplName<detail::ScaleInValue<scaleA>::value, \
".reg .pred p;\n" detail::ScaleInValue<scaleB>::value>; \
"setp.ne.b32 p, %6, 0;\n" TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \
"wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " uint32_t *c, bool scale_out) { \
"{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" detail::CallWgmmaSS<Impl>::exec(desc_a, desc_b, c, scale_out); \
"}\n" } \
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) };
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// M64N16K8 TF32->F32 #define TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(AType, BType, CType, M, N, K, \
template <bool tnspA, bool tnspB, int scaleA, int scaleB> ImplName) \
struct WgmmaSSImpl<DataType::kTensorFloat32, DataType::kTensorFloat32, template <int scaleA, int scaleB> \
DataType::kFloat32, 64, 16, 8, tnspA, tnspB, scaleA, struct WgmmaSSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
scaleB> { K, false, false, scaleA, scaleB> { \
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, static_assert(detail::IsValidScale<scaleA>, \
bool scale_out) { "tl::wgmma_ss: invalid scaleA"); \
asm volatile( static_assert(detail::IsValidScale<scaleB>, \
"{\n" "tl::wgmma_ss: invalid scaleB"); \
".reg .pred p;\n" static_assert(scaleA == 1 && scaleB == 1, \
"setp.ne.b32 p, %10, 0;\n" "tl::wgmma_ss: only +1 scaling supported for this WGMMA"); \
"wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " using Impl = cute::SM90::GMMA::ImplName; \
"{%0, %1, %2, %3, %4, %5, %6, %7}, %8, %9, p, %11, %12, %13, %14;\n" TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, \
"}\n" uint32_t *c, bool scale_out) { \
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]), "+r"(c[4]), detail::CallWgmmaSS<Impl>::exec(desc_a, desc_b, c, scale_out); \
"+r"(c[5]), "+r"(c[6]), "+r"(c[7]) } \
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), };
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)),
"n"(int32_t(tnspB)));
}
};
// ================================= INT8 x INT8 -> INT32 #define TL_WGMMA_DEFINE_RS_GENERAL(AType, BType, CType, M, N, K, ImplName) \
// ================================= template <bool tnspA, bool tnspB, int scaleA, int scaleB> \
struct WgmmaRSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
// M64N8K32 S8->S32 K, tnspA, tnspB, scaleA, scaleB> { \
template <bool tnspA, bool tnspB, int scaleA, int scaleB> static_assert(!tnspA, "tl::wgmma_rs: operand A must be K-major"); \
struct WgmmaSSImpl<DataType::kInt8, DataType::kInt8, DataType::kInt32, 64, 8, static_assert(detail::IsValidScale<scaleA>, \
32, tnspA, tnspB, scaleA, scaleB> { "tl::wgmma_rs: invalid scaleA"); \
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, static_assert(detail::IsValidScale<scaleB>, \
bool scale_out) { "tl::wgmma_rs: invalid scaleB"); \
asm volatile("{\n" using Impl = \
".reg .pred p;\n" cute::SM90::GMMA::ImplName<detail::MajorValue<tnspA>::value, \
"setp.ne.b32 p, %4, 0;\n" detail::MajorValue<tnspB>::value, \
"wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " detail::ScaleInValue<scaleA>::value, \
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" detail::ScaleInValue<scaleB>::value>; \
"}\n" TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \
: "+r"(c[0]), "+r"(c[1]) uint32_t *c, bool scale_out) { \
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), detail::CallWgmmaRS<Impl>::exec(a, desc_b, c, scale_out); \
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), } \
"n"(int32_t(tnspA)), "n"(int32_t(tnspB))); };
}
};
// M64N16K32 S8->S32 #define TL_WGMMA_DEFINE_RS_TN(AType, BType, CType, M, N, K, ImplName) \
template <bool tnspA, bool tnspB, int scaleA, int scaleB> template <int scaleA, int scaleB> \
struct WgmmaSSImpl<DataType::kInt8, DataType::kInt8, DataType::kInt32, 64, 16, struct WgmmaRSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
32, tnspA, tnspB, scaleA, scaleB> { K, false, false, scaleA, scaleB> { \
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, static_assert(detail::IsValidScale<scaleA>, \
bool scale_out) { "tl::wgmma_rs: invalid scaleA"); \
asm volatile("{\n" static_assert(detail::IsValidScale<scaleB>, \
".reg .pred p;\n" "tl::wgmma_rs: invalid scaleB"); \
"setp.ne.b32 p, %6, 0;\n" using Impl = \
"wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " cute::SM90::GMMA::ImplName<detail::ScaleInValue<scaleA>::value, \
"{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" detail::ScaleInValue<scaleB>::value>; \
"}\n" TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) uint32_t *c, bool scale_out) { \
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), detail::CallWgmmaRS<Impl>::exec(a, desc_b, c, scale_out); \
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), } \
"n"(int32_t(tnspA)), "n"(int32_t(tnspB))); };
}
};
// ================================= FP8 x FP8 -> F16/F32 #define TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(AType, BType, CType, M, N, K, \
// ================================= ImplName) \
template <int scaleA, int scaleB> \
// M64N8K32 E4M3->F16 struct WgmmaRSImpl<DataType::AType, DataType::BType, DataType::CType, M, N, \
template <bool tnspA, bool tnspB, int scaleA, int scaleB> K, false, false, scaleA, scaleB> { \
struct WgmmaSSImpl<DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, static_assert(detail::IsValidScale<scaleA>, \
DataType::kFloat16, 64, 8, 32, tnspA, tnspB, scaleA, "tl::wgmma_rs: invalid scaleA"); \
scaleB> { static_assert(detail::IsValidScale<scaleB>, \
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, "tl::wgmma_rs: invalid scaleB"); \
bool scale_out) { static_assert(scaleA == 1 && scaleB == 1, \
asm volatile("{\n" "tl::wgmma_rs: only +1 scaling supported for this WGMMA"); \
".reg .pred p;\n" using Impl = cute::SM90::GMMA::ImplName; \
"setp.ne.b32 p, %4, 0;\n" TL_DEVICE static void execute(const uint32_t *a, uint64_t desc_b, \
"wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 " uint32_t *c, bool scale_out) { \
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n" detail::CallWgmmaRS<Impl>::exec(a, desc_b, c, scale_out); \
"}\n" } \
: "+r"(c[0]), "+r"(c[1]) };
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// M64N8K32 E4M3->F32 #define TL_WGMMA_FOREACH_N_FLOAT_MUL8(OP) \
template <bool tnspA, bool tnspB, int scaleA, int scaleB> OP(8) \
struct WgmmaSSImpl<DataType::kFloat8_e4m3, DataType::kFloat8_e4m3, OP(16) \
DataType::kFloat32, 64, 8, 32, tnspA, tnspB, scaleA, OP(24) \
scaleB> { OP(32) \
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c, OP(40) \
bool scale_out) { OP(48) \
asm volatile("{\n" OP(56) \
".reg .pred p;\n" OP(64) \
"setp.ne.b32 p, %6, 0;\n" OP(72) \
"wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 " OP(80) \
"{%0, %1, %2, %3}, %4, %5, p, %7, %8, %9, %10;\n" OP(88) \
"}\n" OP(96) \
: "+r"(c[0]), "+r"(c[1]), "+r"(c[2]), "+r"(c[3]) OP(104) \
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)), OP(112) \
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), OP(120) \
"n"(int32_t(tnspA)), "n"(int32_t(tnspB))); OP(128) \
} OP(136) \
}; OP(144) \
OP(152) \
OP(160) \
OP(168) \
OP(176) \
OP(184) \
OP(192) \
OP(200) \
OP(208) \
OP(216) \
OP(224) \
OP(232) \
OP(240) \
OP(248) \
OP(256)
#define TL_WGMMA_FOREACH_N_INT32_MUL8(OP) \
OP(8) \
OP(16) \
OP(24) \
OP(32) \
OP(48) \
OP(64) \
OP(80) \
OP(96) \
OP(112) \
OP(128) \
OP(144) \
OP(160) \
OP(176) \
OP(192) \
OP(208) \
OP(224) \
OP(240) \
OP(256)
#define TL_WGMMA_DEFINE_F16_F16_F16_SS(N) \
TL_WGMMA_DEFINE_SS_GENERAL(kFloat16, kFloat16, kFloat16, 64, N, 16, \
MMA_64x##N##x16_F16F16F16_SS)
#define TL_WGMMA_DEFINE_F16_F16_F32_SS(N) \
TL_WGMMA_DEFINE_SS_GENERAL(kFloat16, kFloat16, kFloat32, 64, N, 16, \
MMA_64x##N##x16_F32F16F16_SS)
#define TL_WGMMA_DEFINE_BF16_BF16_F32_SS(N) \
TL_WGMMA_DEFINE_SS_GENERAL(kBFloat16, kBFloat16, kFloat32, 64, N, 16, \
MMA_64x##N##x16_F32BF16BF16_SS)
#define TL_WGMMA_DEFINE_F32_TF32_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kTensorFloat32, kTensorFloat32, kFloat32, 64, N, 8, \
MMA_64x##N##x8_F32TF32TF32_SS_TN)
#define TL_WGMMA_DEFINE_S32_S8S8_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kInt8, kInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32S8S8_SS_TN)
#define TL_WGMMA_DEFINE_S32_S8U8_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kInt8, kUInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32S8U8_SS_TN)
#define TL_WGMMA_DEFINE_S32_U8S8_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kUInt8, kInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32U8S8_SS_TN)
#define TL_WGMMA_DEFINE_S32_U8U8_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE(kUInt8, kUInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32U8U8_SS_TN)
#define TL_WGMMA_DEFINE_F16_E4M3E4M3_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E4M3E4M3_SS_TN)
#define TL_WGMMA_DEFINE_F32_E4M3E4M3_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E4M3E4M3_SS_TN)
#define TL_WGMMA_DEFINE_F16_E4M3E5M2_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E4M3E5M2_SS_TN)
#define TL_WGMMA_DEFINE_F32_E4M3E5M2_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E4M3E5M2_SS_TN)
#define TL_WGMMA_DEFINE_F16_E5M2E4M3_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E5M2E4M3_SS_TN)
#define TL_WGMMA_DEFINE_F32_E5M2E4M3_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E5M2E4M3_SS_TN)
#define TL_WGMMA_DEFINE_F16_E5M2E5M2_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E5M2E5M2_SS_TN)
#define TL_WGMMA_DEFINE_F32_E5M2E5M2_SS_TN(N) \
TL_WGMMA_DEFINE_SS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E5M2E5M2_SS_TN)
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F16_SS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F32_SS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_BF16_BF16_F32_SS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_TF32_SS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8S8_SS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8U8_SS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8S8_SS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8U8_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E4M3_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E4M3_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E5M2_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E5M2_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E4M3_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E4M3_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E5M2_SS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E5M2_SS_TN);
#define TL_WGMMA_DEFINE_F16_F16_F16_RS(N) \
TL_WGMMA_DEFINE_RS_GENERAL(kFloat16, kFloat16, kFloat16, 64, N, 16, \
MMA_64x##N##x16_F16F16F16_RS)
#define TL_WGMMA_DEFINE_F16_F16_F32_RS(N) \
TL_WGMMA_DEFINE_RS_GENERAL(kFloat16, kFloat16, kFloat32, 64, N, 16, \
MMA_64x##N##x16_F32F16F16_RS)
#define TL_WGMMA_DEFINE_BF16_BF16_F32_RS(N) \
TL_WGMMA_DEFINE_RS_GENERAL(kBFloat16, kBFloat16, kFloat32, 64, N, 16, \
MMA_64x##N##x16_F32BF16BF16_RS)
#define TL_WGMMA_DEFINE_F32_TF32_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kTensorFloat32, kTensorFloat32, kFloat32, 64, N, 8, \
MMA_64x##N##x8_F32TF32TF32_RS_TN)
#define TL_WGMMA_DEFINE_S32_S8S8_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kInt8, kInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32S8S8_RS_TN)
#define TL_WGMMA_DEFINE_S32_S8U8_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kInt8, kUInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32S8U8_RS_TN)
#define TL_WGMMA_DEFINE_S32_U8S8_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kUInt8, kInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32U8S8_RS_TN)
#define TL_WGMMA_DEFINE_S32_U8U8_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE(kUInt8, kUInt8, kInt32, 64, N, 32, \
MMA_64x##N##x32_S32U8U8_RS_TN)
#define TL_WGMMA_DEFINE_F16_E4M3E4M3_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E4M3E4M3_RS_TN)
#define TL_WGMMA_DEFINE_F32_E4M3E4M3_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E4M3E4M3_RS_TN)
#define TL_WGMMA_DEFINE_F16_E4M3E5M2_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E4M3E5M2_RS_TN)
#define TL_WGMMA_DEFINE_F32_E4M3E5M2_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E4M3E5M2_RS_TN)
#define TL_WGMMA_DEFINE_F16_E5M2E4M3_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E5M2E4M3_RS_TN)
#define TL_WGMMA_DEFINE_F32_E5M2E4M3_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E5M2E4M3_RS_TN)
#define TL_WGMMA_DEFINE_F16_E5M2E5M2_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 64, N, 32, \
MMA_64x##N##x32_F16E5M2E5M2_RS_TN)
#define TL_WGMMA_DEFINE_F32_E5M2E5M2_RS_TN(N) \
TL_WGMMA_DEFINE_RS_TN(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 64, N, 32, \
MMA_64x##N##x32_F32E5M2E5M2_RS_TN)
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F16_RS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_F16_F32_RS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_BF16_BF16_F32_RS);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_TF32_RS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8S8_RS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_S8U8_RS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8S8_RS_TN);
TL_WGMMA_FOREACH_N_INT32_MUL8(TL_WGMMA_DEFINE_S32_U8U8_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E4M3_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E4M3_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E4M3E5M2_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E4M3E5M2_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E4M3_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E4M3_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F16_E5M2E5M2_RS_TN);
TL_WGMMA_FOREACH_N_FLOAT_MUL8(TL_WGMMA_DEFINE_F32_E5M2E5M2_RS_TN);
#undef TL_WGMMA_DEFINE_F16_F16_F16_SS
#undef TL_WGMMA_DEFINE_F16_F16_F32_SS
#undef TL_WGMMA_DEFINE_BF16_BF16_F32_SS
#undef TL_WGMMA_DEFINE_F32_TF32_SS_TN
#undef TL_WGMMA_DEFINE_S32_S8S8_SS_TN
#undef TL_WGMMA_DEFINE_S32_S8U8_SS_TN
#undef TL_WGMMA_DEFINE_S32_U8S8_SS_TN
#undef TL_WGMMA_DEFINE_S32_U8U8_SS_TN
#undef TL_WGMMA_DEFINE_F16_E4M3E4M3_SS_TN
#undef TL_WGMMA_DEFINE_F32_E4M3E4M3_SS_TN
#undef TL_WGMMA_DEFINE_F16_E4M3E5M2_SS_TN
#undef TL_WGMMA_DEFINE_F32_E4M3E5M2_SS_TN
#undef TL_WGMMA_DEFINE_F16_E5M2E4M3_SS_TN
#undef TL_WGMMA_DEFINE_F32_E5M2E4M3_SS_TN
#undef TL_WGMMA_DEFINE_F16_E5M2E5M2_SS_TN
#undef TL_WGMMA_DEFINE_F32_E5M2E5M2_SS_TN
#undef TL_WGMMA_DEFINE_F16_F16_F16_RS
#undef TL_WGMMA_DEFINE_F16_F16_F32_RS
#undef TL_WGMMA_DEFINE_BF16_BF16_F32_RS
#undef TL_WGMMA_DEFINE_F32_TF32_RS_TN
#undef TL_WGMMA_DEFINE_S32_S8S8_RS_TN
#undef TL_WGMMA_DEFINE_S32_S8U8_RS_TN
#undef TL_WGMMA_DEFINE_S32_U8S8_RS_TN
#undef TL_WGMMA_DEFINE_S32_U8U8_RS_TN
#undef TL_WGMMA_DEFINE_F16_E4M3E4M3_RS_TN
#undef TL_WGMMA_DEFINE_F32_E4M3E4M3_RS_TN
#undef TL_WGMMA_DEFINE_F16_E4M3E5M2_RS_TN
#undef TL_WGMMA_DEFINE_F32_E4M3E5M2_RS_TN
#undef TL_WGMMA_DEFINE_F16_E5M2E4M3_RS_TN
#undef TL_WGMMA_DEFINE_F32_E5M2E4M3_RS_TN
#undef TL_WGMMA_DEFINE_F16_E5M2E5M2_RS_TN
#undef TL_WGMMA_DEFINE_F32_E5M2E5M2_RS_TN
#undef TL_WGMMA_FOREACH_N_FLOAT_MUL8
#undef TL_WGMMA_FOREACH_N_INT32_MUL8
#undef TL_WGMMA_DEFINE_SS_TN_FIXED_SCALE
#undef TL_WGMMA_DEFINE_SS_GENERAL
#undef TL_WGMMA_DEFINE_SS_TN
#undef TL_WGMMA_DEFINE_RS_TN_FIXED_SCALE
#undef TL_WGMMA_DEFINE_RS_GENERAL
#undef TL_WGMMA_DEFINE_RS_TN
// 函数模板委托给类模板
template <DataType A_type, DataType B_type, DataType C_type, int M, int N, template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
int K, bool tnspA, bool tnspB, int scaleA = 1, int scaleB = 1> int K, bool tnspA, bool tnspB, int scaleA = 1, int scaleB = 1>
TL_DEVICE void wgmma_ss(uint64_t desc_a, uint64_t desc_b, uint32_t *c, TL_DEVICE void wgmma_ss(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
...@@ -519,129 +460,12 @@ TL_DEVICE void wgmma_ss(uint64_t desc_a, uint64_t desc_b, uint32_t *c, ...@@ -519,129 +460,12 @@ TL_DEVICE void wgmma_ss(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
scaleB>::execute(desc_a, desc_b, c, scale_out); scaleB>::execute(desc_a, desc_b, c, scale_out);
} }
// ================================= Mixed Precision Support template <DataType A_type, DataType B_type, DataType C_type, int M, int N,
// ================================= int K, bool tnspA, bool tnspB, int scaleA = 1, int scaleB = 1>
TL_DEVICE void wgmma_rs(const uint32_t *a, uint64_t desc_b, uint32_t *c,
// Mixed precision: S8 x U8 -> S32 bool scale_out) {
template <bool tnspA, bool tnspB, int scaleA, int scaleB> WgmmaRSImpl<A_type, B_type, C_type, M, N, K, tnspA, tnspB, scaleA,
struct WgmmaSSImpl<DataType::kInt8, DataType::kUInt8, DataType::kInt32, 64, 8, scaleB>::execute(a, desc_b, c, scale_out);
32, tnspA, tnspB, scaleA, scaleB> { }
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// Mixed precision: U8 x S8 -> S32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kUInt8, DataType::kInt8, DataType::kInt32, 64, 8,
32, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// Mixed precision: U8 x U8 -> S32
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kUInt8, DataType::kUInt8, DataType::kInt32, 64, 8,
32, tnspA, tnspB, scaleA, scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// Mixed precision FP8: E4M3 x E5M2 -> F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat8_e4m3, DataType::kFloat8_e5m2,
DataType::kFloat16, 64, 8, 32, tnspA, tnspB, scaleA,
scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// Mixed precision FP8: E5M2 x E4M3 -> F16
template <bool tnspA, bool tnspB, int scaleA, int scaleB>
struct WgmmaSSImpl<DataType::kFloat8_e5m2, DataType::kFloat8_e4m3,
DataType::kFloat16, 64, 8, 32, tnspA, tnspB, scaleA,
scaleB> {
TL_DEVICE static void execute(uint64_t desc_a, uint64_t desc_b, uint32_t *c,
bool scale_out) {
asm volatile("{\n"
".reg .pred p;\n"
"setp.ne.b32 p, %4, 0;\n"
"wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 "
"{%0, %1}, %2, %3, p, %5, %6, %7, %8;\n"
"}\n"
: "+r"(c[0]), "+r"(c[1])
: "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_out)),
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)),
"n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
}
};
// ================================= Convenience Templates
// =================================
// Type trait to determine the number of output registers needed
template <DataType C_type, int M, int N> struct WgmmaOutputRegs {
static constexpr int value =
(M * N * (C_type == DataType::kFloat32 ? 32 : 16)) / (32 * 8);
};
// Type trait to get element size in bits
template <DataType dtype> struct ElementBits {
static constexpr int value =
(dtype == DataType::kFloat32 || dtype == DataType::kTensorFloat32 ||
dtype == DataType::kInt32)
? 32
: (dtype == DataType::kFloat16 || dtype == DataType::kBFloat16 ||
dtype == DataType::kInt16 || dtype == DataType::kUInt16)
? 16
: (dtype == DataType::kInt8 || dtype == DataType::kUInt8 ||
dtype == DataType::kFloat8_e4m3 || dtype == DataType::kFloat8_e5m2)
? 8
: (dtype == DataType::kInt4 || dtype == DataType::kUInt4) ? 4
: 8;
};
} // namespace tl } // namespace tl
\ No newline at end of file
...@@ -67,6 +67,20 @@ template <int NumMma> TL_DEVICE void warpgroup_wait() { ...@@ -67,6 +67,20 @@ template <int NumMma> TL_DEVICE void warpgroup_wait() {
cute::warpgroup_wait<NumMma>(); cute::warpgroup_wait<NumMma>();
} }
TL_DEVICE void warpgroup_fence_operand(uint32_t *regs, int count) {
#pragma unroll
for (int i = 0; i < count; ++i) {
cute::warpgroup_fence_operand(regs[i]);
}
}
TL_DEVICE void warpgroup_fence_operand(float *regs, int count) {
#pragma unroll
for (int i = 0; i < count; ++i) {
cute::warpgroup_fence_operand(regs[i]);
}
}
// Template parameter: // Template parameter:
// thread_extent: the logical size (in number of threads) of each "group" // thread_extent: the logical size (in number of threads) of each "group"
// within which we want to elect exactly ONE representative // within which we want to elect exactly ONE representative
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#endif #endif
#include "common.h" #include "common.h"
#include <cute/arch/cluster_sm90.hpp>
namespace tl { namespace tl {
...@@ -59,12 +60,15 @@ inline void __device__ amma_fp16bf16_ss(uint64_t const desc_a, ...@@ -59,12 +60,15 @@ inline void __device__ amma_fp16bf16_ss(uint64_t const desc_a,
"r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]));
} }
inline __device__ void amma_commit(uint64_t const *smem_ptr) { // Wrapper for CUTLASS umma_arrive: elect one lane, then arrive the mbarrier
TL_DEVICE void tcgen05_mma_arrive(void const *smem_ptr) {
uint32_t bar_intptr = smem_ptr_to_uint(smem_ptr); uint32_t bar_intptr = smem_ptr_to_uint(smem_ptr);
asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::" if (cute::elect_one_sync()) {
"cluster.b64 [%0];" asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::"
: "cluster.b64 [%0];"
: "r"(bar_intptr)); :
: "r"(bar_intptr));
}
} }
} // namespace tl } // namespace tl
\ No newline at end of file
...@@ -108,7 +108,8 @@ bool IsKnownGeneric(const CallNode *call) { ...@@ -108,7 +108,8 @@ bool IsKnownGeneric(const CallNode *call) {
return false; return false;
} }
return call->op.same_as(ptx_ldmatrix()) || call->op.same_as(ptx_stmatrix()) || return call->op.same_as(ptx_ldmatrix()) || call->op.same_as(ptx_stmatrix()) ||
call->op.same_as(initialize_descriptor()); call->op.same_as(initialize_wgmma_descriptor()) ||
call->op.same_as(initialize_tcgen05_descriptor());
} }
ProxyKind ProxyFromAttrValue(const ObjectRef &value) { ProxyKind ProxyFromAttrValue(const ObjectRef &value) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment