Unverified Commit aef0a6bb authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Language] Expose `T.warpgroup_fence_operand` for nvcc code motion (#986)



* remove debug print

* pipeline fix

* use the correct buffer access scope

* rs support

* warp warpgroup_fence_operand

* fix

* fp8 dtype ptx enhance

* mma fix

* TCGEN05 Interface

* tcgen05 support

* rebase

* update

* Enhance TCGEN05 support by adding new intrinsic operations and descriptors. Introduced `ptx_tcgen05_mma_ts` for tensor-memory to shared-memory instructions and `tcgen05_mma_arrive` for signaling barrier completion. Updated existing descriptors and code generation logic to accommodate these changes, ensuring compatibility with new instruction sets. Refactored related allocation functions and improved handling of shared memory descriptors.

* lint fix

* Refactor buffer reference handling in CUDA code generation and update test execution in tilelang. Ensure default annotations for unrolling are set correctly in TIR IR module.

* wgmma fix

---------
Co-authored-by: default avatarZhiwen Mo <zm125@ic.ac.uk>
parent c85bb3ac
Subproject commit 0f1ebab7b66732f34b652ce807c9ff0748cd473c
Subproject commit 1815c3e0b6ec4ead36370bbd1562025d8529017c
......@@ -17,7 +17,7 @@ The pass is conservative: unknown extern calls are treated as async so that the
### Timeline View
```
generic initialize_descriptor → generic shared-store → async wgmma
generic initialize_wgmma_descriptor → generic shared-store → async wgmma
│ │ │
└─ generic proxy ┴─ generic proxy ┴─ async proxy
│ fence inserted here ↑
......@@ -53,7 +53,7 @@ def kernel():
with T.Kernel(1):
desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
smem = T.decl_buffer((128,), "float16", scope="shared")
T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
smem[0] = T.float16(0)
T.ptx_wgmma_ss(
"float16",
......@@ -83,7 +83,7 @@ def kernel():
with T.Kernel(1):
desc = T.decl_buffer((1,), "uint64", scope="local.descriptor")
smem = T.decl_buffer((128,), "float16", scope="shared")
T.initialize_descriptor(desc, T.uint64(0), 2, 1, 32)
T.initialize_wgmma_descriptor(desc, T.uint64(0), 2, 1, 32)
smem[0] = T.float16(0)
T.fence_proxy_async()
T.ptx_wgmma_ss(
......
......@@ -546,6 +546,12 @@ TVM_FFI_STATIC_INIT_BLOCK() {
return makeGemmABLayoutHopper(stride, mat_continuous, continuity,
element_size, k_inner);
})
.def("tl.make_tcgen05mma_swizzled_layout",
[](int stride, int mat_continuous, int continuity, int element_size,
bool k_inner) {
return makeGemmABLayoutSm100(stride, mat_continuous, continuity,
element_size, k_inner);
})
.def("tl.make_full_bank_swizzled_layout",
[](int stride, int continuous, int element_size) {
return makeFullBankSwizzleLayout(stride, continuous, element_size);
......
......@@ -155,6 +155,16 @@ TIR_DEFINE_TL_BUILTIN(ptx_wgmma_rs)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ss)
.set_num_inputs(14)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ts)
.set_num_inputs(13)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_init_tensor_memory)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
......@@ -219,6 +229,11 @@ TIR_DEFINE_TL_BUILTIN(warpgroup_wait)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(warpgroup_fence_operand)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(get_lane_idx)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
......@@ -286,11 +301,16 @@ TIR_DEFINE_TL_BUILTIN(tl_shuffle_elect)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
TIR_DEFINE_TL_BUILTIN(initialize_descriptor)
TIR_DEFINE_TL_BUILTIN(initialize_wgmma_descriptor)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(initialize_tcgen05_descriptor)
.set_num_inputs(7)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
......@@ -311,5 +331,10 @@ TIR_DEFINE_TL_BUILTIN(device_assert_with_msg)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(tcgen05_mma_arrive)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm
......@@ -241,14 +241,24 @@ TVM_DLL const Op &ptx_wgmma_ss();
/*!
* \brief tvm intrinsics for ptx tensor core wgmma instructions.
*
* void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, bool
* a_is_k_major, bool b_is_k_major, StringImm a_dtype_abbrv, StringImm
* b_dtype_abbrv, StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr
* A_offset, Var B_descriptor, Var B_offset, Var C_data, Var C_offset, bool
* scale_out, bool scale_in_a, bool scale_in_b);
* void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix,
* bool b_is_k_major, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv,
* StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr A_offset, Var
* B_descriptor, Var B_offset, Var C_data, Var C_offset, bool scale_out,
* bool scale_in_a, bool scale_in_b);
*/
TVM_DLL const Op &ptx_wgmma_rs();
/*!
* \brief tvm intrinsic for tcgen05 mma shared-shared instructions.
*/
TVM_DLL const Op &ptx_tcgen05_mma_ss();
/*!
* \brief tvm intrinsic for tcgen05 mma tensor-shared instructions.
*/
TVM_DLL const Op &ptx_tcgen05_mma_ts();
/*!
* \brief tvm intrinsics for initializing tensor memory
*
......@@ -361,6 +371,14 @@ TVM_DLL const Op &warpgroup_commit_batch();
*/
TVM_DLL const Op &warpgroup_wait();
/*!
* \brief Fence accumulator operand registers for upcoming WGMMA operations
*
* warpgroup_fence_operand(dtype, ptr, offset, num_regs)
*
*/
TVM_DLL const Op &warpgroup_fence_operand();
/*!
* \brief Return the canonical lane index for the calling thread.
*
......@@ -494,7 +512,21 @@ TVM_DLL const Op &tl_shuffle_elect();
* This op is used to represent a descriptor initialization operation in
* tilelang.
*/
TVM_DLL const Op &initialize_descriptor();
TVM_DLL const Op &initialize_wgmma_descriptor();
/*!
* \brief tilelang intrinsic for initializing a descriptor buffer for
* tcgen05 mma.
*/
TVM_DLL const Op &initialize_tcgen05_descriptor();
/*!
* \brief tilelang intrinsic for committing UMMA (TCGEN05) barrier arrive.
*
* This op wraps the device-side arrive used to signal completion of MMA work
* to a shared-memory mbarrier. It mirrors CUTLASS's umma_arrive.
*/
TVM_DLL const Op &tcgen05_mma_arrive();
/*!
* \brief tilelang intrinsic for setting the start address of a descriptor
......
......@@ -12,77 +12,13 @@
#include <tvm/tir/transform.h>
#include "../target/utils.h"
#include "tcgen5_meta.h"
namespace tvm {
namespace tl {
using namespace tir;
struct TCGEN5MMAMeta {
int atom_m, atom_n, atom_k;
};
// Return {is_success, meta}
static inline std::pair<bool, TCGEN5MMAMeta>
GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
#define FAIL \
return { false, TCGEN5MMAMeta{0, 0, 0} }
#define SUCCESS(atom_m, atom_n, atom_k) \
return { \
true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \
}
std::vector<int> ws_valid_atom_ns = {256, 128, 64};
if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
(c_dtype.is_float() && c_dtype.bits() == 32)) {
if (K % 16 != 0)
FAIL;
if (M % 128 == 0) {
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 16);
FAIL;
} else if (M % 64 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(64, atom_n, 16);
FAIL;
} else if (M % 32 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(32, atom_n, 16);
FAIL;
} else {
FAIL;
}
} else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) &&
(c_dtype.is_float() && c_dtype.bits() == 32)) {
if (K % 32 != 0)
FAIL;
if (M % 128 == 0) {
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 32);
FAIL;
} else if (M % 64 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(64, atom_n, 32);
FAIL;
} else if (M % 32 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(32, atom_n, 32);
FAIL;
} else {
FAIL;
}
}
FAIL;
#undef FAIL
#undef SUCCESS
}
/**
* @brief Construct a Gemm operator from serialized TL arguments and a buffer
* map.
......@@ -186,6 +122,8 @@ bool GemmNode::AllowWGMMA(int block_size, Target target) const {
GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
bool allow_tcgen5mma = AllowTCGEN5MMA(target);
bool allow_wgmma = AllowWGMMA(block_size, target);
LOG(INFO) << "allow_tcgen5mma: " << allow_tcgen5mma
<< ", allow_wgmma: " << allow_wgmma;
if (allow_tcgen5mma) {
return GemmInst::kTCGEN5MMA;
} else if (allow_wgmma) {
......@@ -195,7 +133,7 @@ GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
} else if (TargetIsCuda(target)) {
return GemmInst::kMMA;
} else {
ICHECK(0) << "Unsupported target for gemm: " << target->str();
ICHECK(0) << "Unsupported target for gemm: " << target;
}
}
......@@ -578,6 +516,8 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (A.scope() == "local.fragment") {
ICHECK(B.scope() != "local.fragment");
ICHECK(!trans_A)
<< "gemm_rs requires the A operand to be in non-transposed layout.";
op_name = "tl::gemm_rs";
} else if (B.scope() == "local.fragment") {
op_name = "tl::gemm_sr";
......
......@@ -13,6 +13,8 @@
#include "../support/ffi_aliases.h"
#include "../target/utils.h"
#include "tcgen5_meta.h"
#include "tvm/ffi/string.h"
namespace tvm {
namespace tl {
......@@ -49,7 +51,6 @@ using namespace tir;
*/
GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmPyNode> node = tvm::ffi::make_object<GemmPyNode>();
node->Aptr = args[0];
node->Bptr = args[1];
node->Cptr = args[2];
......@@ -76,6 +77,19 @@ GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
if (args.size() > 15) {
node->wg_wait = args[15].as<IntImm>().value()->value;
}
if (args.size() > 16) {
node->mbarptr = args[16];
} else {
node->mbarptr = IntImm(DataType::UInt(32), 0);
}
if (args.size() > 18) {
node->C_coords = Array<PrimExpr>({args[17], args[18]});
} else if (args.size() > 17) {
node->C_coords = Array<PrimExpr>({args[17], IntImm(DataType::Int(32), 0)});
} else {
node->C_coords = Array<PrimExpr>(
{IntImm(DataType::Int(32), 0), IntImm(DataType::Int(32), 0)});
}
data_ = std::move(node);
}
......@@ -92,16 +106,37 @@ TileOperator GemmPyNode::Clone() const {
return GemmPy(op);
}
GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const {
bool GemmPyNode::AllowTCGEN5MMA(Target target) const {
return TargetIsSm100(target) &&
((A.scope() == "shared.dyn" || A.scope() == "shared" ||
A.scope() == "shared.tmem") &&
(B.scope() == "shared.dyn" || B.scope() == "shared") &&
C.scope() == "shared.tmem") &&
GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype).first;
}
bool GemmPyNode::AllowWGMMA(int block_size, Target target) const {
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
int warp_size = TargetGetWarpSize(target);
int num_warps = block_size / warp_size;
bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
(num_warps % 4 == 0) && CheckWGMMA();
if (allow_wgmma) {
return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) &&
CheckWGMMA();
}
GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const {
bool allow_tcgen5mma = AllowTCGEN5MMA(target);
bool allow_wgmma = AllowWGMMA(block_size, target);
if (allow_tcgen5mma) {
return GemmInst::kTCGEN5MMA;
} else if (allow_wgmma) {
return GemmInst::kWGMMA;
} else if (TargetIsCDNA(target)) {
return GemmInst::kMFMA;
} else if (TargetIsCuda(target)) {
} else if (TargetIsVolta(target) || TargetIsAmpere(target) ||
TargetIsTuring(target) || TargetIsHopper(target) ||
TargetIsSm100(target)) {
return GemmInst::kMMA;
} else {
ICHECK(0) << "Unsupported target for gemm: " << target->str();
......@@ -290,5 +325,31 @@ TVM_FFI_STATIC_INIT_BLOCK() {
});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def(
"tl.get_tcgen5_mma_meta",
[](int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
auto [success, meta] = GetTCGEN5MMAMeta(M, N, K, ab_dtype, c_dtype);
Array<Integer> result;
if (success) {
result.push_back(Integer(meta.atom_m));
result.push_back(Integer(meta.atom_n));
result.push_back(Integer(meta.atom_k));
}
return result;
});
refl::GlobalDef().def(
"tl.get_tcgen5_instr_desc",
[](int atom_m, int atom_n, int atom_k, DataType ab_dtype,
DataType c_dtype, bool a_is_k_major, bool b_is_k_major, int scale_in_a,
int scale_in_b) {
uint32_t desc = GetTCGEN5InstrDesc(atom_m, atom_n, atom_k, ab_dtype,
c_dtype, a_is_k_major, b_is_k_major,
scale_in_a, scale_in_b);
return Integer(static_cast<int64_t>(desc));
});
}
} // namespace tl
} // namespace tvm
......@@ -19,6 +19,8 @@ using namespace tir;
class GemmPyNode : public TileOperatorNode {
public:
bool CheckWGMMA() const;
bool AllowTCGEN5MMA(Target target) const;
bool AllowWGMMA(int block_size, Target target) const;
tir::Buffer A, B, C;
// pointer to the A, B, C
PrimExpr Aptr, Bptr, Cptr;
......@@ -27,6 +29,8 @@ public:
int stride_A, stride_B;
int offset_A, offset_B;
PrimExpr clear_accum = const_false();
PrimExpr mbarptr;
Array<PrimExpr> C_coords;
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// only will be enabled under cdna mfma instructions
int kPack = 1;
......@@ -54,6 +58,8 @@ public:
.def_ro("offset_A", &GemmPyNode::offset_A)
.def_ro("offset_B", &GemmPyNode::offset_B)
.def_ro("clear_accum", &GemmPyNode::clear_accum)
.def_ro("mbarptr", &GemmPyNode::mbarptr)
.def_ro("C_coords", &GemmPyNode::C_coords)
.def_ro("kPack", &GemmPyNode::kPack)
.def_ro("wg_wait", &GemmPyNode::wg_wait)
.def_ro("policy", &GemmPyNode::policy);
......
#ifndef TVM_TL_OP_TCGEN5_META_H_
#define TVM_TL_OP_TCGEN5_META_H_
#include <cstdint>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/logging.h>
#include <utility>
#include <vector>
namespace tvm {
namespace tl {
using runtime::DataType;
struct TCGEN5MMAMeta {
int atom_m, atom_n, atom_k;
};
inline std::pair<bool, TCGEN5MMAMeta>
GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
#define FAIL \
return { false, TCGEN5MMAMeta{0, 0, 0} }
#define SUCCESS(atom_m, atom_n, atom_k) \
return { \
true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \
}
std::vector<int> ws_valid_atom_ns = {256, 128, 64};
if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
(c_dtype.is_float() && c_dtype.bits() == 32)) {
if (K % 16 != 0)
FAIL;
if (M % 128 == 0) {
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 16);
FAIL;
} else if (M % 64 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(64, atom_n, 16);
FAIL;
} else if (M % 32 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(32, atom_n, 16);
FAIL;
} else {
FAIL;
}
} else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) &&
(c_dtype.is_float() && c_dtype.bits() == 32)) {
if (K % 32 != 0)
FAIL;
if (M % 128 == 0) {
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 32);
FAIL;
} else if (M % 64 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(64, atom_n, 32);
FAIL;
} else if (M % 32 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(32, atom_n, 32);
FAIL;
} else {
FAIL;
}
}
FAIL;
#undef FAIL
#undef SUCCESS
}
inline uint32_t GetTCGEN5InstrDesc(int atom_m, int atom_n, int atom_k,
DataType ab_dtype, DataType c_dtype,
bool a_is_k_major, bool b_is_k_major,
int scale_in_a, int scale_in_b) {
ICHECK(atom_m % 16 == 0) << "atom_m must be divisible by 16";
ICHECK(atom_n % 8 == 0) << "atom_n must be divisible by 8";
ICHECK(atom_k == 16 || atom_k == 32)
<< "Unsupported atom_k for TCGEN5MMA descriptor: " << atom_k;
ICHECK(scale_in_a == 1 || scale_in_a == -1)
<< "scale_in_a must be +/-1 for TCGEN5MMA";
ICHECK(scale_in_b == 1 || scale_in_b == -1)
<< "scale_in_b must be +/-1 for TCGEN5MMA";
auto encode_dtype = [&](DataType dtype) -> uint32_t {
if (dtype.is_float16()) {
return static_cast<uint32_t>(0);
} else if (dtype.is_bfloat16()) {
return static_cast<uint32_t>(1);
} else if (dtype.is_float8_e4m3fn() || dtype.is_float8_e4m3fnuz() ||
dtype.is_float8_e4m3()) {
return static_cast<uint32_t>(0);
} else if (dtype.is_float8_e5m2fnuz() || dtype.is_float8_e5m2()) {
return static_cast<uint32_t>(1);
}
LOG(FATAL) << "Unsupported dtype for TCGEN5MMA descriptor: " << dtype;
return 0u;
};
uint32_t a_format = encode_dtype(ab_dtype);
uint32_t b_format = a_format;
uint32_t c_format = 0;
if (c_dtype.is_float16()) {
c_format = 0;
} else if (c_dtype.is_float()) {
c_format = 1;
} else if (c_dtype.is_int()) {
c_format = 2;
} else {
LOG(FATAL) << "Unsupported accumulator dtype for TCGEN5MMA descriptor: "
<< c_dtype;
}
auto set_bits = [](uint32_t value, int start, int width) -> uint32_t {
uint32_t mask = (width == 32) ? 0xFFFFFFFFu : ((1u << width) - 1);
return (value & mask) << start;
};
uint32_t desc = 0;
desc |= set_bits(0, 0, 2); // sparse_id2
desc |= set_bits(0, 2, 1); // sparse_flag
desc |= set_bits(0, 3, 1); // saturate
desc |= set_bits(c_format, 4, 2);
desc |= set_bits(a_format, 7, 3);
desc |= set_bits(b_format, 10, 3);
uint32_t a_neg = (scale_in_a == -1) ? 1u : 0u;
uint32_t b_neg = (scale_in_b == -1) ? 1u : 0u;
desc |= set_bits(a_neg, 13, 1);
desc |= set_bits(b_neg, 14, 1);
uint32_t a_major = a_is_k_major ? 0u : 1u;
uint32_t b_major = b_is_k_major ? 0u : 1u;
desc |= set_bits(a_major, 15, 1);
desc |= set_bits(b_major, 16, 1);
uint32_t n_dim = static_cast<uint32_t>(atom_n >> 3);
uint32_t m_dim = static_cast<uint32_t>(atom_m >> 4);
desc |= set_bits(n_dim, 17, 6);
desc |= set_bits(0, 23, 1);
desc |= set_bits(m_dim, 24, 5);
desc |= set_bits(0, 29, 1);
uint32_t max_shift = 0u;
desc |= set_bits(max_shift, 30, 2);
return desc;
}
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_TCGEN5_META_H_
......@@ -260,6 +260,18 @@ std::string CodeGenTileLangCUDA::Finish() {
if (need_mma_h_) {
decl_stream << "#include <mma.h>\n";
}
if (need_mma_instruction_h_) {
decl_stream << "#include <tl_templates/cuda/instruction/mma.h>\n";
}
if (need_wgmma_instruction_h_) {
decl_stream << "#include <tl_templates/cuda/instruction/wgmma.h>\n";
}
if (need_tcgen05mma_instruction_h_) {
decl_stream << "#include <tl_templates/cuda/instruction/tcgen05mma.h>\n";
}
if (need_tcgen05_common_h_) {
decl_stream << "#include <tl_templates/cuda/tcgen_05.h>\n";
}
if (enable_fp8_) {
decl_stream << "#include <tl_templates/cuda/cuda_fp8.h>\n";
}
......@@ -1277,7 +1289,7 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t,
if (scope.empty()) {
scope = GetPtrStorageScope(buffer->data);
}
if (scope == "local.var" || scope == "local.descriptor") {
if (scope == "local.var" || scope.find("local.descriptor") == 0) {
os << vid;
return os.str();
}
......@@ -1597,6 +1609,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
int num_mma = Downcast<IntImm>(op->args[0])->value;
this->stream << "tl::warpgroup_wait<" << std::to_string(num_mma)
<< ">();\n";
} else if (op->op.same_as(tl::warpgroup_fence_operand())) {
ICHECK_EQ(op->args.size(), 4U);
std::string dtype = Downcast<StringImm>(op->args[0])->value;
std::string data_ptr = this->PrintExpr(op->args[1]);
std::string offset = this->PrintExpr(op->args[2]);
std::string num_regs = this->PrintExpr(op->args[3]);
auto dtype_enum = tl::codegen::ptx::DTypeFromString(dtype);
std::string cast_type = "uint32_t";
if (dtype_enum == tl::codegen::ptx::DataType::kFloat32 ||
dtype_enum == tl::codegen::ptx::DataType::kTensorFloat32) {
cast_type = "float";
}
this->PrintIndent();
this->stream << "tl::warpgroup_fence_operand(reinterpret_cast<" << cast_type
<< "*>(" << data_ptr << " + " << offset << "), " << num_regs
<< ");\n";
} else if (op->op.same_as(tl::set_max_nreg())) {
this->PrintIndent();
int nreg = Downcast<IntImm>(op->args[0])->value;
......@@ -1708,14 +1736,43 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
std::string b_bias = this->PrintExpr(op->args[9]);
std::string c_ref = this->PrintExpr(op->args[10]);
std::string c_bias = this->PrintExpr(op->args[11]);
bool saturate = Downcast<Bool>(op->args[12])->value;
std::string bit_op =
op->args.size() > 13 ? Downcast<StringImm>(op->args[13])->value : "";
std::string asm_code = PrintMMAAssembly(
shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias,
b_ref, b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate);
auto dtype_a_enum = tl::codegen::ptx::DTypeFromString(A_dtype);
auto dtype_b_enum = tl::codegen::ptx::DTypeFromString(B_dtype);
auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype);
auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
need_mma_instruction_h_ = true;
this->PrintIndent();
this->stream << asm_code;
std::string mma_call =
"tl::mma_sync<(AType), (BType), (CType), (M), (N), (K), (TransA), "
"(TransB)>(reinterpret_cast<(CRegType)*>((C_ptr) + (C_offset)), "
"reinterpret_cast<const (ARegType)*>((A_ptr) + (A_offset)), "
"reinterpret_cast<const (BRegType)*>((B_ptr) + (B_offset)));\n";
tl::codegen::Replacer replacer;
replacer.register_rule("(AType)",
tl::codegen::ptx::DTypeEnumToString(dtype_a_enum));
replacer.register_rule("(BType)",
tl::codegen::ptx::DTypeEnumToString(dtype_b_enum));
replacer.register_rule("(CType)",
tl::codegen::ptx::DTypeEnumToString(dtype_c_enum));
replacer.register_rule("(M)", std::to_string(m));
replacer.register_rule("(N)", std::to_string(n));
replacer.register_rule("(K)", std::to_string(k));
replacer.register_rule("(TransA)", A_layout == "row" ? "false" : "true");
replacer.register_rule("(TransB)", B_layout == "row" ? "false" : "true");
replacer.register_rule("(ARegType)",
tl::codegen::GetMMARegisterType(dtype_a_enum));
replacer.register_rule("(BRegType)",
tl::codegen::GetMMARegisterType(dtype_b_enum));
replacer.register_rule("(CRegType)",
tl::codegen::GetMMARegisterType(dtype_c_enum));
replacer.register_rule("(A_ptr)", a_ref);
replacer.register_rule("(A_offset)", a_bias);
replacer.register_rule("(B_ptr)", b_ref);
replacer.register_rule("(B_offset)", b_bias);
replacer.register_rule("(C_ptr)", c_ref);
replacer.register_rule("(C_offset)", c_bias);
this->stream << replacer.rewrite(mma_call);
} else if (op->op.same_as(builtin::ptx_mma_sp())) {
// arg 0: shape: mXnXkX
// arg 1: A layout: row/col
......@@ -1792,6 +1849,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
A_offset, b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a,
scale_in_b, a_is_shared, "", "", "", false);
auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
need_wgmma_instruction_h_ = true;
std::string wgmma_asm_code =
"tl::wgmma_ss<(AType), (BType), (CType), (M), (N), (K), (tnspA), "
"(tnspB), (scaleA), (scaleB)>(uint64_t((desc_a) + (A_offset)), "
......@@ -1820,41 +1878,173 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
wgmma_asm_code = replacer.rewrite(wgmma_asm_code);
this->stream << wgmma_asm_code;
} else if (op->op.same_as(tl::ptx_wgmma_rs())) {
// arg 0: dtype
// arg 1: shape
// arg 2: A_layout
// arg 3: B_layout
// arg 4: A_dtype
// arg 5: B_dtype
// arg 6: C_dtype
// arg 7: multiplicand_a
// arg 8: multiplicand_b
// arg 0: shape
// arg 1: B_layout
// arg 2: A_dtype
// arg 3: B_dtype
// arg 4: C_dtype
// arg 5: multiplicand_a
// arg 6: multiplicand_a offset
// arg 7: multiplicand_b descriptor
// arg 8: multiplicand_b offset
// arg 9: accumulator
// arg 10: saturate
ICHECK_EQ(op->args.size(), 15U) << "ptx_wgmma_rs args is " << op->args;
// arg 10: accumulator offset
// arg 11: scale_out
// arg 12: scale_in_a
// arg 13: scale_in_b
ICHECK_EQ(op->args.size(), 14U) << "ptx_wgmma_rs args is " << op->args;
std::string shape = Downcast<StringImm>(op->args[0])->value;
bool A_layout = Downcast<Bool>(op->args[1])->value;
bool B_layout = Downcast<Bool>(op->args[2])->value;
std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
std::string a_ref = this->PrintExpr(op->args[6]);
std::string A_offset = this->PrintExpr(op->args[7]);
std::string b_desc = this->PrintExpr(op->args[8]);
std::string B_offset = this->PrintExpr(op->args[9]);
std::string c_ref = this->PrintExpr(op->args[10]);
std::string c_offset = this->PrintExpr(op->args[11]);
bool scale_out = Downcast<Bool>(op->args[12])->value;
bool scale_in_a = Downcast<Bool>(op->args[13])->value;
bool scale_in_b = Downcast<Bool>(op->args[14])->value;
bool b_is_k_major = Downcast<Bool>(op->args[1])->value;
std::string A_dtype = Downcast<StringImm>(op->args[2])->value;
std::string B_dtype = Downcast<StringImm>(op->args[3])->value;
std::string C_dtype = Downcast<StringImm>(op->args[4])->value;
std::string a_ref = this->PrintExpr(op->args[5]);
std::string A_offset = this->PrintExpr(op->args[6]);
std::string b_desc = this->PrintExpr(op->args[7]);
std::string B_offset = this->PrintExpr(op->args[8]);
std::string c_ref = this->PrintExpr(op->args[9]);
std::string c_offset = this->PrintExpr(op->args[10]);
bool scale_out = Downcast<Bool>(op->args[11])->value;
bool scale_in_a = Downcast<Bool>(op->args[12])->value;
bool scale_in_b = Downcast<Bool>(op->args[13])->value;
auto dtype_a_enum = tl::codegen::ptx::DTypeFromString(A_dtype);
auto dtype_b_enum = tl::codegen::ptx::DTypeFromString(B_dtype);
auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype);
auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
const bool a_is_shared = false;
need_wgmma_instruction_h_ = true;
this->PrintIndent();
std::string asm_code = PrintWGMMAAssembly(
shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, A_offset,
b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a, scale_in_b,
a_is_shared, "", "", "", false);
this->stream << asm_code;
std::string wgmma_call =
"tl::wgmma_rs<(AType), (BType), (CType), (M), (N), (K), (tnspA), "
"(tnspB), (scaleA), (scaleB)>(reinterpret_cast<const "
"uint32_t*>((A_ptr) + (A_offset)), "
"uint64_t((desc_b) + (B_offset)), "
"reinterpret_cast<uint32_t*>((C_ptr) + (C_offset)), "
"(scale_out));\n";
tl::codegen::Replacer replacer;
replacer.register_rule("(AType)",
tl::codegen::ptx::DTypeEnumToString(dtype_a_enum));
replacer.register_rule("(BType)",
tl::codegen::ptx::DTypeEnumToString(dtype_b_enum));
replacer.register_rule("(CType)",
tl::codegen::ptx::DTypeEnumToString(dtype_c_enum));
replacer.register_rule("(M)", std::to_string(m));
replacer.register_rule("(N)", std::to_string(n));
replacer.register_rule("(K)", std::to_string(k));
replacer.register_rule("(tnspA)", "false");
replacer.register_rule("(tnspB)", b_is_k_major ? "false" : "true");
replacer.register_rule("(scaleA)", scale_in_a ? "1" : "-1");
replacer.register_rule("(scaleB)", scale_in_b ? "1" : "-1");
replacer.register_rule("(A_ptr)", a_ref);
replacer.register_rule("(A_offset)", A_offset);
replacer.register_rule("(desc_b)", b_desc);
replacer.register_rule("(B_offset)", B_offset);
replacer.register_rule("(C_ptr)", c_ref);
replacer.register_rule("(C_offset)", c_offset);
replacer.register_rule("(scale_out)", scale_out ? "true" : "false");
wgmma_call = replacer.rewrite(wgmma_call);
this->stream << wgmma_call;
} else if (op->op.same_as(tl::ptx_tcgen05_mma_ss())) {
ICHECK_EQ(op->args.size(), 14U)
<< "ptx_tcgen05_mma_ss args is " << op->args;
std::string C_dtype = Downcast<StringImm>(op->args[0])->value;
std::string a_desc = this->PrintExpr(op->args[1]);
std::string A_offset = this->PrintExpr(op->args[2]);
std::string b_desc = this->PrintExpr(op->args[3]);
std::string B_offset = this->PrintExpr(op->args[4]);
std::string c_ref = this->PrintExpr(op->args[5]);
std::string c_offset = this->PrintExpr(op->args[6]);
PrimExpr desc_expr = op->args[7];
std::string scale_out = this->PrintExpr(op->args[8]);
std::string mask0 = this->PrintExpr(op->args[9]);
std::string mask1 = this->PrintExpr(op->args[10]);
std::string mask2 = this->PrintExpr(op->args[11]);
std::string mask3 = this->PrintExpr(op->args[12]);
bool enable_ws = Downcast<Bool>(op->args[13])->value;
auto dtype_c_enum = tl::codegen::ptx::DTypeFromString(C_dtype);
need_tcgen05mma_instruction_h_ = true;
this->PrintIndent();
std::string tcgen05_call =
"tl::(tcgen05_name)<(CType)>(uint64_t((desc_a) + (A_offset)), "
"uint64_t((desc_b) + (B_offset)), (*reinterpret_cast<uint32_t*>((C))) "
"+ (C_offset), "
"(scale_out), static_cast<uint32_t>((desc_val)), (mask0), (mask1), "
"(mask2), (mask3));\n";
tl::codegen::Replacer replacer;
replacer.register_rule("(CType)",
tl::codegen::ptx::DTypeEnumToString(dtype_c_enum));
replacer.register_rule("(desc_a)", a_desc);
replacer.register_rule("(A_offset)", A_offset);
replacer.register_rule("(desc_b)", b_desc);
replacer.register_rule("(B_offset)", B_offset);
replacer.register_rule("(C)", c_ref);
replacer.register_rule("(C_offset)", c_offset);
replacer.register_rule("(tcgen05_name)",
enable_ws ? "tcgen05mma_ws_ss" : "tcgen05mma_ss");
replacer.register_rule("(scale_out)", scale_out);
replacer.register_rule("(desc_val)", this->PrintExpr(desc_expr));
replacer.register_rule("(mask0)", mask0);
replacer.register_rule("(mask1)", mask1);
replacer.register_rule("(mask2)", mask2);
replacer.register_rule("(mask3)", mask3);
tcgen05_call = replacer.rewrite(tcgen05_call);
this->stream << tcgen05_call;
} else if (op->op.same_as(tl::ptx_tcgen05_mma_ts())) {
// TS: A from TMEM, B from SMEM (desc)
ICHECK_EQ(op->args.size(), 13U)
<< "ptx_tcgen05_mma_ts args is " << op->args;
std::string kind_dtype = Downcast<StringImm>(op->args[0])->value;
std::string a_ref = this->PrintExpr(op->args[1]);
std::string A_offset = this->PrintExpr(op->args[2]);
std::string b_desc = this->PrintExpr(op->args[3]);
std::string B_offset = this->PrintExpr(op->args[4]);
std::string c_ref = this->PrintExpr(op->args[5]);
std::string c_offset = this->PrintExpr(op->args[6]);
PrimExpr desc_expr = op->args[7];
std::string scale_out = this->PrintExpr(op->args[8]);
std::string mask0 = this->PrintExpr(op->args[9]);
std::string mask1 = this->PrintExpr(op->args[10]);
std::string mask2 = this->PrintExpr(op->args[11]);
std::string mask3 = this->PrintExpr(op->args[12]);
auto dtype_enum = tl::codegen::ptx::DTypeFromString(kind_dtype);
need_tcgen05mma_instruction_h_ = true;
this->PrintIndent();
std::string tcgen05_call =
"tl::tcgen05mma_ts<(CType)>( (*reinterpret_cast<uint32_t*>((A))) + "
"(A_offset), "
"uint64_t((desc_b) + (B_offset)), (*reinterpret_cast<uint32_t*>((C))) "
"+ (C_offset), "
"(scale_out), static_cast<uint32_t>((desc_val)), (mask0), (mask1), "
"(mask2), (mask3));\n";
tl::codegen::Replacer replacer;
replacer.register_rule("(CType)",
tl::codegen::ptx::DTypeEnumToString(dtype_enum));
replacer.register_rule("(A)", a_ref);
replacer.register_rule("(A_offset)", A_offset);
replacer.register_rule("(desc_b)", b_desc);
replacer.register_rule("(B_offset)", B_offset);
replacer.register_rule("(C)", c_ref);
replacer.register_rule("(C_offset)", c_offset);
replacer.register_rule("(scale_out)", scale_out);
replacer.register_rule("(desc_val)", this->PrintExpr(desc_expr));
replacer.register_rule("(mask0)", mask0);
replacer.register_rule("(mask1)", mask1);
replacer.register_rule("(mask2)", mask2);
replacer.register_rule("(mask3)", mask3);
tcgen05_call = replacer.rewrite(tcgen05_call);
this->stream << tcgen05_call;
} else if (op->op.same_as(tl::tcgen05_mma_arrive())) {
ICHECK_EQ(op->args.size(), 1U) << "tcgen05_mma_arrive expects 1 argument";
need_tcgen05_common_h_ = true;
this->PrintIndent();
this->stream << "tl::tcgen05_mma_arrive(" << this->PrintExpr(op->args[0])
<< ");\n";
} else if (op->op.same_as(builtin::ptx_ldmatrix())) {
// arg 0: whether the matrix is loaded in column major format or not.
// arg 1: number of matrices to load.
......@@ -2214,19 +2404,35 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
os << ")";
} else if (op->op.same_as(tl::tl_shuffle_elect())) {
os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()";
} else if (op->op.same_as(tl::initialize_descriptor())) {
} else if (op->op.same_as(tl::initialize_wgmma_descriptor())) {
ICHECK(op->args.size() == 5)
<< "tl_initialize_descriptor expects 5 arguments but got "
<< "tl_initialize_wgmma_descriptor expects 5 arguments but got "
<< op->args.size();
auto descriptor = op->args[0];
auto start_address = op->args[1];
auto layout_type = op->args[2];
auto leading_byte_offset = op->args[3];
auto stride_byte_offset = op->args[4];
os << "tl::initialize_descriptor<" << PrintExpr(layout_type) << ", "
os << "tl::initialize_wgmma_descriptor<" << PrintExpr(layout_type) << ", "
<< PrintExpr(leading_byte_offset) << ", "
<< PrintExpr(stride_byte_offset) << ">(" << PrintExpr(descriptor) << ", "
<< PrintExpr(start_address) << ")";
} else if (op->op.same_as(tl::initialize_tcgen05_descriptor())) {
ICHECK(op->args.size() == 7)
<< "tl_initialize_tcgen05_descriptor expects 7 arguments but got "
<< op->args.size();
auto descriptor = op->args[0];
auto start_address = op->args[1];
auto leading_byte_offset = op->args[2];
auto stride_byte_offset = op->args[3];
auto base_offset = op->args[4];
auto leading_abs = op->args[5];
auto swizzle_mode = op->args[6];
os << "tl::initialize_tcgen05_descriptor(" << PrintExpr(descriptor) << ", "
<< PrintExpr(start_address) << ", " << PrintExpr(leading_byte_offset)
<< ", " << PrintExpr(stride_byte_offset) << ", "
<< PrintExpr(base_offset) << ", " << PrintExpr(leading_abs) << ", "
<< PrintExpr(swizzle_mode) << ")";
} else if (op->op.same_as(tl::increase_descriptor_offset())) {
ICHECK(op->args.size() == 2)
<< "tl_increase_descriptor_offset expects 2 arguments but got "
......@@ -2377,8 +2583,12 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
<< "Accumulator only support half, float and int type for now";
}
PrintWmmaScope(scope, op->dtype, buffer, stream);
} else if (scope == "local.descriptor") {
} else if (scope == "local.descriptor.wgmma") {
stream << "tl::GmmaDescriptor " << vid << ";\n";
} else if (scope == "local.descriptor.tcgen05_smem") {
stream << "tl::Tcgen05SMemDescriptor " << vid << ";\n";
} else if (scope == "local.descriptor.tcgen05_instr") {
stream << "tl::Tcgen05InstrDescriptor " << vid << ";\n";
} else {
PrintStorageScope(scope, stream);
PrintType(op->dtype, stream);
......@@ -2420,7 +2630,7 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
init = user_init;
}
stream << ' ' << vid << " = " << PrintExpr(init) << ";\n";
} else if (scope != "local.descriptor") {
} else if (scope.find("local.descriptor") != 0) {
ICHECK(false) << "Unsupported scope: " << scope;
}
}
......
......@@ -108,6 +108,14 @@ private:
bool need_math_constants_h_{false};
// whether need mma.h
bool need_mma_h_{false};
// whether need tl mma instruction header
bool need_mma_instruction_h_{false};
// whether need tl wgmma instruction header
bool need_wgmma_instruction_h_{false};
// whether need tl tcgen05mma instruction header
bool need_tcgen05mma_instruction_h_{false};
// whether need tcgen_05 common header
bool need_tcgen05_common_h_{false};
// whether need cast_smem_ptr_to_int helper function
bool need_cast_smem_ptr_to_int_{false};
// whether need cooperative_groups.h
......
......@@ -74,9 +74,9 @@ DataType DTypeFromString(const std::string str) {
return DataType::kInt64;
} else if (str == "uint64" || str == ".u64") {
return DataType::kUInt64;
} else if (str == "e4m3" || str == ".e4m3") {
} else if (str == "float8_e4m3" || str == "e4m3" || str == ".e4m3") {
return DataType::kFloat8_e4m3;
} else if (str == "e5m2" || str == ".e5m2") {
} else if (str == "float8_e5m2" || str == "e5m2" || str == ".e5m2") {
return DataType::kFloat8_e5m2;
} else if (str == "float16" || str == "fp16" || str == ".f16") {
return DataType::kFloat16;
......@@ -1529,5 +1529,20 @@ std::string PrintWaitBarrierAsm(const std::string &barrier) {
return predicated_asm_code;
}
std::string GetMMARegisterType(const ptx::DataType &dtype) {
switch (dtype) {
case ptx::DataType::kInt32:
return "unsigned";
case ptx::DataType::kUInt32:
return "unsigned";
case ptx::DataType::kFloat32:
return "float";
case ptx::DataType::kFloat64:
return "double";
default:
return "unsigned";
}
}
} // namespace codegen
} // namespace tvm::tl
......@@ -269,6 +269,11 @@ std::string PrintArriveBarrierExpectTxAsm(const std::string &barrier,
*/
std::string PrintWaitBarrierAsm(const std::string &barrier);
/*!
* \brief Return the register-level C++ type used by MMA fragments.
*/
std::string GetMMARegisterType(const ptx::DataType &dtype);
} // namespace codegen
} // namespace tvm::tl
......
......@@ -288,6 +288,138 @@ union GmmaDescriptor {
}
};
union Tcgen05SMemDescriptor {
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor() noexcept : desc_(0) {}
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor(uint64_t desc) noexcept
: desc_(desc) {}
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor(
Tcgen05SMemDescriptor const &t) noexcept
: desc_(t.desc_) {}
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor(
Tcgen05SMemDescriptor &&t) noexcept
: desc_(t.desc_) {}
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor &
operator=(Tcgen05SMemDescriptor const &t) noexcept {
desc_ = t.desc_;
return *this;
}
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor &
operator=(Tcgen05SMemDescriptor &&t) noexcept {
desc_ = t.desc_;
return *this;
}
uint64_t desc_;
uint32_t reg32_[2];
// Bitfield implementation avoids the need for shifts in assignment
struct {
// start_address, bit [0,14), 4LSB not included
uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused
// leading dimension byte offset, bit [16,30), 4LSB not included
uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused
// stride dimension byte offset, bit [32,46), 4LSB not included
uint16_t stride_byte_offset_ : 14,
version_ : 2; // 14 bits [0,14), 2 bits [14,16)
// base_offset, bit [49,52). leading_byte_offset_mode, bit [52,53).
uint8_t : 1, base_offset_ : 3, lbo_mode_ : 1,
: 3; // 1 bit unused, 3 bits [1,4), 1 bit [4,5), 3 bits unused
// layout type, bit [61,64), SWIZZLE_NONE matrix descriptor = 0,
// SWIZZLE_128B matrix descriptor = 2, SWIZZLE_64B descriptor = 4,
// SWIZZLE_32B descriptor = 6, SWIZZLE_128B_BASE32B = 1, N/A = 3, N/A = 5,
// N/A = 7
uint8_t : 5, layout_type_ : 3; // 6 bits unused, 3 bits [5,8)
} bitfield;
// Separate the field, as we may only update one part of desc
struct {
uint32_t lo;
uint32_t hi;
} words;
CUTE_HOST_DEVICE constexpr operator uint64_t() const noexcept {
return desc_;
}
template <typename T>
CUTE_HOST_DEVICE constexpr Tcgen05SMemDescriptor
operator+(const T &offset) const {
Tcgen05SMemDescriptor ret;
// Address addition is in units of 16 bytes (4 LSB not encoded)
ret.reg32_[0] = reg32_[0] + (uint32_t(offset) >> 4);
ret.reg32_[1] = reg32_[1];
return ret;
}
};
//
// Tcgen05 instruction descriptor (wraps cute::UMMA::InstrDescriptor layout)
//
union Tcgen05InstrDescriptor {
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor() noexcept : desc_(0) {}
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor(uint32_t desc) noexcept
: desc_(desc) {}
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor(
Tcgen05InstrDescriptor const &t) noexcept
: desc_(t.desc_) {}
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor(
Tcgen05InstrDescriptor &&t) noexcept
: desc_(t.desc_) {}
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor &
operator=(Tcgen05InstrDescriptor const &t) noexcept {
desc_ = t.desc_;
return *this;
}
CUTE_HOST_DEVICE constexpr Tcgen05InstrDescriptor &
operator=(Tcgen05InstrDescriptor &&t) noexcept {
desc_ = t.desc_;
return *this;
}
uint32_t desc_;
uint16_t reg16_[2];
// Bitfield implementation mirrors cute::UMMA::InstrDescriptor
struct {
// bit [ 0, 2) : Sparse meta data id2
uint16_t sparse_id2_ : 2,
// bit [ 2, 3) : 0 = dense. 1 = sparse. Only valid for
// F32F16/S8/MXF8F6F4
sparse_flag_ : 1,
// bit [ 3, 4) : 0 = no saturate. 1 = saturate. Only valid for S8
saturate_ : 1,
// bit [ 4, 6) : 0 = F16. 1 = F32, 2 = S32
c_format_ : 2,
// padding
: 1,
// bit [ 7,10) : see UMMA format encoding
a_format_ : 3,
// bit [10,13) : see UMMA format encoding
b_format_ : 3,
// bit [13,14) : 0 = no negate. 1 = negate
a_negate_ : 1,
// bit [14,15) : 0 = no negate. 1 = negate
b_negate_ : 1,
// bit [15,16) : 0 = K-major. 1 = MN-major
a_major_ : 1;
// Upper 16 bits
uint16_t b_major_ : 1, // bit [16,17)
n_dim_ : 6, // bit [17,23) : 3 LSBs not included
: 1, // padding
m_dim_ : 5, // bit [24,29) : 4 LSBs not included
: 1, // padding
max_shift_ : 2; // bit [30,32)
} bitfield;
// Decay to a uint32_t
CUTE_HOST_DEVICE constexpr explicit operator uint32_t() const noexcept {
return desc_;
}
};
// Any
template <typename T> TL_DEVICE bool Any(T *a, int size) {
for (int i = 0; i < size; i++) {
......@@ -326,8 +458,8 @@ TL_DEVICE void __sync_thread_partial() {
template <int layout_type = 0, int leading_byte_offset = 0,
int stride_byte_offset = 0, typename T>
TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor,
T *start_address) {
TL_DEVICE void initialize_wgmma_descriptor(GmmaDescriptor &descriptor,
T *start_address) {
descriptor.bitfield.start_address_ =
cute::cast_smem_ptr_to_uint(start_address) >> 4;
descriptor.bitfield.layout_type_ = layout_type;
......@@ -336,6 +468,23 @@ TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor,
descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
}
template <typename T>
TL_DEVICE void
initialize_tcgen05_descriptor(Tcgen05SMemDescriptor &descriptor,
T *start_address, int leading_byte_offset,
int stride_byte_offset, int base_offset,
bool leading_is_absolute, int swizzle_mode) {
descriptor.bitfield.start_address_ =
static_cast<uint16_t>(cast_smem_ptr_to_uint(start_address) >> 4);
descriptor.bitfield.leading_byte_offset_ = leading_byte_offset;
descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
descriptor.bitfield.version_ = 1;
descriptor.bitfield.base_offset_ = base_offset & 0x7;
descriptor.bitfield.lbo_mode_ = leading_is_absolute ? 1 : 0;
descriptor.bitfield.layout_type_ = swizzle_mode & 0x7;
}
template <typename T>
TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor,
T offset) {
......
#pragma once
#include "../common.h"
#include <cute/arch/mma_sm80.hpp>
#include <cute/arch/mma_sm89.hpp>
#include <type_traits>
#include <utility>
namespace tl {
#ifndef TL_ALWAYS_FALSE_V_DEFINED
#define TL_ALWAYS_FALSE_V_DEFINED
template <class> inline constexpr bool always_false_v = false;
#endif
namespace detail {
template <class Impl> struct MmaImplTraits {
using DReg = std::remove_extent_t<typename Impl::DRegisters>;
using AReg = std::remove_extent_t<typename Impl::ARegisters>;
using BReg = std::remove_extent_t<typename Impl::BRegisters>;
using CReg = std::remove_extent_t<typename Impl::CRegisters>;
static constexpr int kDRegs = std::extent_v<typename Impl::DRegisters>;
static constexpr int kARegs = std::extent_v<typename Impl::ARegisters>;
static constexpr int kBRegs = std::extent_v<typename Impl::BRegisters>;
static constexpr int kCRegs = std::extent_v<typename Impl::CRegisters>;
};
template <class Impl, size_t... DIdx, size_t... AIdx, size_t... BIdx,
size_t... CIdx>
TL_DEVICE void
call_fma_impl(typename MmaImplTraits<Impl>::DReg *d,
const typename MmaImplTraits<Impl>::AReg *a,
const typename MmaImplTraits<Impl>::BReg *b,
const typename MmaImplTraits<Impl>::CReg *c,
std::index_sequence<DIdx...>, std::index_sequence<AIdx...>,
std::index_sequence<BIdx...>, std::index_sequence<CIdx...>) {
Impl::fma(d[DIdx]..., a[AIdx]..., b[BIdx]..., c[CIdx]...);
}
template <class Impl>
TL_DEVICE void call_fma(typename MmaImplTraits<Impl>::DReg *d,
const typename MmaImplTraits<Impl>::AReg *a,
const typename MmaImplTraits<Impl>::BReg *b,
const typename MmaImplTraits<Impl>::CReg *c) {
call_fma_impl<Impl>(d, a, b, c,
std::make_index_sequence<MmaImplTraits<Impl>::kDRegs>{},
std::make_index_sequence<MmaImplTraits<Impl>::kARegs>{},
std::make_index_sequence<MmaImplTraits<Impl>::kBRegs>{},
std::make_index_sequence<MmaImplTraits<Impl>::kCRegs>{});
}
template <DataType AType, DataType BType, DataType CType, int M, int N, int K,
bool TransA, bool TransB, bool Saturate>
struct MmaDispatcher {
using CRegType = void;
using ARegType = void;
using BRegType = void;
static TL_DEVICE void exec(CRegType *, const ARegType *, const BRegType *,
const CRegType *) {
static_assert(always_false_v<std::integral_constant<int, M>>,
"tl::mma_sync: unsupported configuration");
}
};
#define TL_DEFINE_MMA_DISPATCHER(ATypeEnum, BTypeEnum, CTypeEnum, MValue, \
NValue, KValue, TransAValue, TransBValue, \
SaturateValue, ImplType) \
template <> \
struct MmaDispatcher<DataType::ATypeEnum, DataType::BTypeEnum, \
DataType::CTypeEnum, MValue, NValue, KValue, \
TransAValue, TransBValue, SaturateValue> { \
using Impl = ImplType; \
using Traits = MmaImplTraits<Impl>; \
using CRegType = typename Traits::DReg; \
using ARegType = typename Traits::AReg; \
using BRegType = typename Traits::BReg; \
static_assert( \
std::is_same_v<typename Traits::DReg, typename Traits::CReg>, \
"tl::mma_sync requires matching accumulator/output regs"); \
static TL_DEVICE void exec(CRegType *d, const ARegType *a, \
const BRegType *b, const CRegType *c) { \
call_fma<Impl>(d, a, b, c); \
} \
};
// FP16 inputs (TN layout: A row-major, B column-major)
TL_DEFINE_MMA_DISPATCHER(kFloat16, kFloat16, kFloat16, 16, 8, 16, false, true,
false, cute::SM80_16x8x16_F16F16F16F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat16, kFloat16, kFloat32, 16, 8, 16, false, true,
false, cute::SM80_16x8x16_F32F16F16F32_TN)
// BF16 inputs
TL_DEFINE_MMA_DISPATCHER(kBFloat16, kBFloat16, kFloat32, 16, 8, 16, false, true,
false, cute::SM80_16x8x16_F32BF16BF16F32_TN)
// INT8 inputs (k32)
TL_DEFINE_MMA_DISPATCHER(kInt8, kInt8, kInt32, 16, 8, 32, false, true, false,
cute::SM80_16x8x32_S32S8S8S32_TN)
TL_DEFINE_MMA_DISPATCHER(kUInt8, kUInt8, kInt32, 16, 8, 32, false, true, false,
cute::SM80_16x8x32_S32U8U8S32_TN)
// INT4 inputs (k32)
TL_DEFINE_MMA_DISPATCHER(kInt4, kInt4, kInt32, 16, 8, 32, false, true, false,
cute::SM80_16x8x32_S32S4S4S32_TN)
TL_DEFINE_MMA_DISPATCHER(kUInt4, kUInt4, kInt32, 16, 8, 32, false, true, false,
cute::SM80_16x8x32_S32U4U4S32_TN)
// FP8 inputs (k32)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e4m3, kFloat16, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F16E4M3E4M3F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e4m3, kFloat32, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F32E4M3E4M3F32_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e5m2, kFloat16, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F16E4M3E5M2F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e4m3, kFloat8_e5m2, kFloat32, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F32E4M3E5M2F32_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e4m3, kFloat16, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F16E5M2E4M3F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e4m3, kFloat32, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F32E5M2E4M3F32_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e5m2, kFloat16, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F16E5M2E5M2F16_TN)
TL_DEFINE_MMA_DISPATCHER(kFloat8_e5m2, kFloat8_e5m2, kFloat32, 16, 8, 32, false,
true, false, cute::SM89_16x8x32_F32E5M2E5M2F32_TN)
#undef TL_DEFINE_MMA_DISPATCHER
} // namespace detail
template <DataType AType, DataType BType, DataType CType, int M, int N, int K,
bool TransA, bool TransB, bool Saturate = false>
TL_DEVICE void mma_sync(
typename detail::MmaDispatcher<AType, BType, CType, M, N, K, TransA, TransB,
Saturate>::CRegType *c,
const typename detail::MmaDispatcher<AType, BType, CType, M, N, K, TransA,
TransB, Saturate>::ARegType *a,
const typename detail::MmaDispatcher<AType, BType, CType, M, N, K, TransA,
TransB, Saturate>::BRegType *b) {
using Dispatcher = detail::MmaDispatcher<AType, BType, CType, M, N, K, TransA,
TransB, Saturate>;
static_assert(!std::is_void_v<typename Dispatcher::CRegType>,
"tl::mma_sync: unsupported configuration");
Dispatcher::exec(c, a, b, c);
}
} // namespace tl
#pragma once
#include "../common.h"
#include <cute/arch/cluster_sm90.hpp>
namespace tl {
#ifndef TL_ALWAYS_FALSE_V_DEFINED
#define TL_ALWAYS_FALSE_V_DEFINED
template <class> inline constexpr bool always_false_v = false;
#endif
// Generic declaration: unsupported by default
template <DataType C_type>
TL_DEVICE void
tcgen05mma_ss(uint64_t const & /*desc_a*/, uint64_t const & /*desc_b*/,
uint32_t const & /*tmem_c*/, uint32_t const & /*scalec*/,
uint32_t const & /*desc_val*/, int const & /*mask0*/,
int const & /*mask1*/, int const & /*mask2*/,
int const & /*mask3*/) {
static_assert(
always_false_v<std::integral_constant<int, static_cast<int>(C_type)>>,
"tl::tcgen05mma_ss: unsupported accumulator type");
}
// TS variants: A from TMEM, B from SMEM (desc)
// Generic declaration: unsupported by default
template <DataType C_type>
TL_DEVICE void
tcgen05mma_ts(uint32_t const & /*tmem_a*/, uint64_t const & /*desc_b*/,
uint32_t const & /*tmem_c*/, uint32_t const & /*scalec*/,
uint32_t const & /*desc_val*/, int const & /*mask0*/,
int const & /*mask1*/, int const & /*mask2*/,
int const & /*mask3*/) {
static_assert(
always_false_v<std::integral_constant<int, static_cast<int>(C_type)>>,
"tl::tcgen05mma_ts: unsupported accumulator type");
}
// F16/BF16 instruction kind (maps to kind::f16)
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kFloat16>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f16 [%0], [%1], %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// BF16 maps to the same f16-kind instruction
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kBFloat16>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ts<DataType::kFloat16>(tmem_a, desc_b, tmem_c, scalec, desc_val,
mask0, mask1, mask2, mask3);
}
// TF32 instruction kind
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kTensorFloat32>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::tf32 [%0], [%1], %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// INT8 instruction kind
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kInt8>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::i8 [%0], [%1], %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// FP8 family instruction kind (maps to f8f6f4)
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kFloat8_e4m3>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], [%1], %2, %3, "
"{%5, %6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
template <>
TL_DEVICE void tcgen05mma_ts<DataType::kFloat8_e5m2>(
uint32_t const &tmem_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ts<DataType::kFloat8_e4m3>(tmem_a, desc_b, tmem_c, scalec,
desc_val, mask0, mask1, mask2, mask3);
}
// F16/BF16 instruction kind (maps to kind::f16)
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kFloat16>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
// idescE upper 32 bits carry the instruction descriptor; lower 32 ignored for
// SS Load TMEM base from shared memory slot handled by caller
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// BF16 maps to the same f16-kind instruction
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kBFloat16>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ss<DataType::kFloat16>(desc_a, desc_b, tmem_c, scalec, desc_val,
mask0, mask1, mask2, mask3);
}
// TF32 instruction kind
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kTensorFloat32>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::tf32 [%0], %1, %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// INT8 instruction kind
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kInt8>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::i8 [%0], %1, %2, %3, {%5, %6, "
"%7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
// FP8 family instruction kind (maps to f8f6f4)
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kFloat8_e4m3>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, {%5, "
"%6, %7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val),
"r"(scalec), "r"(mask0), "r"(mask1), "r"(mask2), "r"(mask3));
}
}
template <>
TL_DEVICE void tcgen05mma_ss<DataType::kFloat8_e5m2>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ss<DataType::kFloat8_e4m3>(desc_a, desc_b, tmem_c, scalec,
desc_val, mask0, mask1, mask2, mask3);
}
// WS variants: tcgen05.mma.ws.cta_group::1.kind::xxx
// Generic declaration falls back to static assert
template <DataType C_type>
TL_DEVICE void
tcgen05mma_ws_ss(uint64_t const & /*desc_a*/, uint64_t const & /*desc_b*/,
uint32_t const & /*tmem_c*/, uint32_t const & /*scalec*/,
uint32_t const & /*desc_val*/, int const & /*mask0*/,
int const & /*mask1*/, int const & /*mask2*/,
int const & /*mask3*/) {
static_assert(
always_false_v<std::integral_constant<int, static_cast<int>(C_type)>>,
"tl::tcgen05mma_ws_ss: unsupported accumulator type");
}
// F16/BF16 ws
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kFloat16>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec));
}
}
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kBFloat16>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ws_ss<DataType::kFloat16>(desc_a, desc_b, tmem_c, scalec, desc_val,
mask0, mask1, mask2, mask3);
}
// TF32 ws
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kTensorFloat32>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.ws.cta_group::1.kind::tf32 [%0], %1, %2, %3, p, 0; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec));
}
}
// INT8 ws
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kInt8>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.ws.cta_group::1.kind::i8 [%0], %1, %2, %3, p, 0; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec));
}
}
// FP8 ws (maps to f8f6f4)
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kFloat8_e4m3>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
if (cute::elect_one_sync()) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.ws.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p, 0; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(desc_val), "r"(scalec));
}
}
template <>
TL_DEVICE void tcgen05mma_ws_ss<DataType::kFloat8_e5m2>(
uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scalec, uint32_t const &desc_val, int const &mask0,
int const &mask1, int const &mask2, int const &mask3) {
tcgen05mma_ws_ss<DataType::kFloat8_e4m3>(
desc_a, desc_b, tmem_c, scalec, desc_val, mask0, mask1, mask2, mask3);
}
} // namespace tl
This diff is collapsed.
......@@ -67,6 +67,20 @@ template <int NumMma> TL_DEVICE void warpgroup_wait() {
cute::warpgroup_wait<NumMma>();
}
TL_DEVICE void warpgroup_fence_operand(uint32_t *regs, int count) {
#pragma unroll
for (int i = 0; i < count; ++i) {
cute::warpgroup_fence_operand(regs[i]);
}
}
TL_DEVICE void warpgroup_fence_operand(float *regs, int count) {
#pragma unroll
for (int i = 0; i < count; ++i) {
cute::warpgroup_fence_operand(regs[i]);
}
}
// Template parameter:
// thread_extent: the logical size (in number of threads) of each "group"
// within which we want to elect exactly ONE representative
......
......@@ -6,6 +6,7 @@
#endif
#include "common.h"
#include <cute/arch/cluster_sm90.hpp>
namespace tl {
......@@ -59,12 +60,15 @@ inline void __device__ amma_fp16bf16_ss(uint64_t const desc_a,
"r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]));
}
inline __device__ void amma_commit(uint64_t const *smem_ptr) {
// Wrapper for CUTLASS umma_arrive: elect one lane, then arrive the mbarrier
TL_DEVICE void tcgen05_mma_arrive(void const *smem_ptr) {
uint32_t bar_intptr = smem_ptr_to_uint(smem_ptr);
asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::"
"cluster.b64 [%0];"
:
: "r"(bar_intptr));
if (cute::elect_one_sync()) {
asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::"
"cluster.b64 [%0];"
:
: "r"(bar_intptr));
}
}
} // namespace tl
\ No newline at end of file
} // namespace tl
......@@ -108,7 +108,8 @@ bool IsKnownGeneric(const CallNode *call) {
return false;
}
return call->op.same_as(ptx_ldmatrix()) || call->op.same_as(ptx_stmatrix()) ||
call->op.same_as(initialize_descriptor());
call->op.same_as(initialize_wgmma_descriptor()) ||
call->op.same_as(initialize_tcgen05_descriptor());
}
ProxyKind ProxyFromAttrValue(const ObjectRef &value) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment