Unverified Commit bbbf4207 authored by guchaoyang, committed by GitHub

Merge branch 'main' into dcu

parents 8f4628e0 5eb30a4f
......@@ -241,14 +241,24 @@ TVM_DLL const Op &ptx_wgmma_ss();
/*!
* \brief tvm intrinsics for ptx tensor core wgmma instructions.
*
* void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix, bool
* a_is_k_major, bool b_is_k_major, StringImm a_dtype_abbrv, StringImm
* b_dtype_abbrv, StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr
* A_offset, Var B_descriptor, Var B_offset, Var C_data, Var C_offset, bool
* scale_out, bool scale_in_a, bool scale_in_b);
* void ptx_wgmma_rs(StringImm accum_dtype, StringImm wgmma_prefix,
* bool b_is_k_major, StringImm a_dtype_abbrv, StringImm b_dtype_abbrv,
* StringImm accum_dtype_abbrv, Var A_descriptor, PrimExpr A_offset, Var
* B_descriptor, Var B_offset, Var C_data, Var C_offset, bool scale_out,
* bool scale_in_a, bool scale_in_b);
*/
TVM_DLL const Op &ptx_wgmma_rs();
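// Illustrative usage sketch: the shape string, dtype abbreviations, and
// descriptor/offset variables below are placeholders that follow the
// documented prototype, not a validated WGMMA configuration.
//
//   using namespace tvm::tir;
//   Call wgmma_rs(DataType::Handle(), ptx_wgmma_rs(),
//                 {StringImm("float32"), StringImm("m64n64k16"),
//                  const_true(), // b_is_k_major
//                  StringImm("fp16"), StringImm("fp16"), StringImm("fp32"),
//                  A_descriptor, A_offset, B_descriptor, B_offset, C_data,
//                  C_offset, const_true(), const_true(), const_true()});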
/*!
* \brief tvm intrinsic for tcgen05 mma shared-shared instructions.
*/
TVM_DLL const Op &ptx_tcgen05_mma_ss();
/*!
* \brief tvm intrinsic for tcgen05 mma tensor-shared instructions.
*/
TVM_DLL const Op &ptx_tcgen05_mma_ts();
/*!
* \brief tvm intrinsics for initializing tensor memory
*
......@@ -265,6 +275,17 @@ TVM_DLL const Op &ptx_init_tensor_memory();
*/
TVM_DLL const Op &ptx_deallocate_tensor_memory();
/*!
* \brief tvm intrinsic for ptx tensor core mma instructions on SM70.
*
* void ptx_mma_sm70(StringImm shape, StringImm A_layout, StringImm B_layout,
* StringImm A_dtype, StringImm B_dtype, StringImm C_dtype,
* Var multiplicand_a, Expr a_index,
* Var multiplicand_b, Expr b_index,
* Var accumulator, Expr c_index, bool saturate);
*/
TVM_DLL const Op &ptx_mma_sm70();
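// Illustrative usage sketch: the m8n8k4 shape, "row"/"col" layouts, and dtype
// abbreviations are assumptions for illustration; the operand buffers and
// indices are placeholders following the prototype above.
//
//   using namespace tvm::tir;
//   Call mma_sm70(DataType::Handle(), ptx_mma_sm70(),
//                 {StringImm("m8n8k4"), StringImm("row"), StringImm("col"),
//                  StringImm("fp16"), StringImm("fp16"), StringImm("fp32"),
//                  multiplicand_a, a_index, multiplicand_b, b_index,
//                  accumulator, c_index, const_false()});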
/*!
* \brief tvm intrinsics for ldmatrix
*
......@@ -361,6 +382,14 @@ TVM_DLL const Op &warpgroup_commit_batch();
*/
TVM_DLL const Op &warpgroup_wait();
/*!
* \brief Fence accumulator operand registers for upcoming WGMMA operations
*
* warpgroup_fence_operand(dtype, ptr, offset, num_regs)
*
*/
TVM_DLL const Op &warpgroup_fence_operand();
/*!
* \brief Return the canonical lane index for the calling thread.
*
......@@ -494,7 +523,21 @@ TVM_DLL const Op &tl_shuffle_elect();
* This op is used to represent a descriptor initialization operation in
* tilelang.
*/
TVM_DLL const Op &initialize_descriptor();
TVM_DLL const Op &initialize_wgmma_descriptor();
/*!
* \brief tilelang intrinsic for initializing a descriptor buffer for
* tcgen05 mma.
*/
TVM_DLL const Op &initialize_tcgen05_descriptor();
/*!
* \brief tilelang intrinsic for committing a UMMA (TCGEN05) barrier arrive.
*
* This op wraps the device-side arrive used to signal completion of MMA work
* to a shared-memory mbarrier. It mirrors CUTLASS's umma_arrive.
*/
TVM_DLL const Op &tcgen05_mma_arrive();
/*!
* \brief tilelang intrinsic for setting the start address of a descriptor
......@@ -505,6 +548,7 @@ TVM_DLL const Op &initialize_descriptor();
*/
TVM_DLL const Op &increase_descriptor_offset();
/*!
* \brief tilelang intrinsic for element-wise atomic addition.
*
......@@ -513,6 +557,20 @@ TVM_DLL const Op &increase_descriptor_offset();
*/
TVM_DLL const Op &atomicadd_elem_op();
/*!
* \brief tilelang intrinsic for an assertion on the device.
*
* This op is used to represent an assertion on the device.
*/
TVM_DLL const Op &device_assert();
/*!
* \brief tilelang intrinsic for an assertion on the device with an additional
* message.
*
* This op is used to represent an assertion on the device with an additional
* message.
*/
TVM_DLL const Op &device_assert_with_msg();
} // namespace tl
} // namespace tvm
......
......@@ -130,7 +130,7 @@ template <typename T> static Array<T> ReverseArray(Array<T> array) {
* @param vmap BufferMap used to resolve RegionOp buffers and ranges.
*/
Copy::Copy(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<CopyNode> node = make_object<CopyNode>();
ObjectPtr<CopyNode> node = tvm::ffi::make_object<CopyNode>();
Array<Range> rgs[2];
Buffer bf[2];
for (int i = 0; i < 2; i++) {
......@@ -169,7 +169,7 @@ Copy::Copy(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A TileOperator owning the cloned CopyNode.
*/
TileOperator CopyNode::Clone() const {
auto op = make_object<CopyNode>(*this);
auto op = tvm::ffi::make_object<CopyNode>(*this);
if (par_op_.defined()) {
op->par_op_ = Downcast<ParallelOp>(par_op_->Clone());
}
......@@ -401,7 +401,7 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
using namespace tvm::transform;
PassContext pass_ctx = PassContext::Current();
bool disable_tma_lower =
pass_ctx->GetConfig<bool>(kDisableTMALower, false).value();
pass_ctx->GetConfig<Bool>(kDisableTMALower, Bool(false)).value();
auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma,
T.layout_map, T.analyzer, T.buffer_oob);
......@@ -793,7 +793,7 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
using namespace tvm::transform;
PassContext pass_ctx = PassContext::Current();
bool disable_tma_lower =
pass_ctx->GetConfig<bool>(kDisableTMALower, false).value();
pass_ctx->GetConfig<Bool>(kDisableTMALower, Bool(false)).value();
auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma,
T.layout_map, analyzer);
if (copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) {
......@@ -1504,7 +1504,12 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
}
auto inner_box_dim = as_const_int(desc.smem_box[0]);
ICHECK(inner_box_dim != nullptr);
if (inner_box_dim == nullptr) {
LOG(WARNING) << "inner_box_dim " << desc.smem_box[0]
<< " can only be a constant integer for TMA bulk copy, "
"fallback to normal copy";
return LowerNormalCopy(T, analyzer);
}
int instruction_dim = *inner_box_dim;
if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B)) {
instruction_dim = 64 / src->dtype.bytes();
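// For example, with 64B swizzle and 2-byte (fp16) elements this gives
// instruction_dim = 64 / 2 = 32 elements per TMA instruction.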
......@@ -1722,7 +1727,8 @@ Array<PrimExpr> TMADesc::EncodeCallArgs() const {
* @param vmap Mapping from original buffer variables to actual Buffer objects.
*/
Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<Conv2DIm2ColOpNode> node = make_object<Conv2DIm2ColOpNode>();
ObjectPtr<Conv2DIm2ColOpNode> node =
tvm::ffi::make_object<Conv2DIm2ColOpNode>();
node->src = vmap[GetVarFromAccessPtr(args[0])];
node->dst = vmap[GetVarFromAccessPtr(args[1])];
node->nhw_step = args[2];
......@@ -1747,7 +1753,7 @@ Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A TileOperator containing the cloned Conv2DIm2ColOpNode.
*/
TileOperator Conv2DIm2ColOpNode::Clone() const {
auto op = make_object<Conv2DIm2ColOpNode>(*this);
auto op = tvm::ffi::make_object<Conv2DIm2ColOpNode>(*this);
return Conv2DIm2ColOp(op);
}
......@@ -1973,9 +1979,9 @@ TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
CopyNode::RegisterReflection();
Conv2DIm2ColOpNode::RegisterReflection();
});
}
} // namespace tl
} // namespace tvm
......@@ -101,8 +101,7 @@ public:
};
uint8_t eviction_policy; // Policy for cache eviction
static constexpr const char *_type_key = "tl.Copy";
TVM_DECLARE_FINAL_OBJECT_INFO(CopyNode, TileOperatorNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Copy", CopyNode, TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
......@@ -114,23 +113,6 @@ public:
.def_ro("coalesced_width", &CopyNode::coalesced_width);
}
bool SEqualReduce(const CopyNode *other, SEqualReducer equal) const {
return equal(src, other->src) && equal(dst, other->dst) &&
equal(src_range, other->src_range) &&
equal(dst_range, other->dst_range) &&
equal(coalesced_width, other->coalesced_width);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(src);
hash_reduce(dst);
hash_reduce(src_range);
hash_reduce(dst_range);
hash_reduce(coalesced_width);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
/*!
* \brief Lower the copy operator to a TIR statement.
* \param T Arguments for lowering.
......@@ -291,7 +273,7 @@ protected:
class Copy : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Copy, TileOperator, CopyNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Copy, TileOperator, CopyNode);
/*!
* \brief Constructor.
......@@ -323,8 +305,8 @@ public:
PrimExpr nhw_step; // Step size in NHW dimensions
PrimExpr c_step; // Step size in channel dimension
static constexpr const char *_type_key = "tl.Conv2DIm2Col";
TVM_DECLARE_FINAL_OBJECT_INFO(Conv2DIm2ColOpNode, TileOperatorNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Conv2DIm2Col", Conv2DIm2ColOpNode,
TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
......@@ -338,26 +320,6 @@ public:
.def_ro("eviction_policy", &Conv2DIm2ColOpNode::eviction_policy);
}
bool SEqualReduce(const Conv2DIm2ColOpNode *other,
SEqualReducer equal) const {
return equal(src, other->src) && equal(dst, other->dst) &&
equal(stride, other->stride) && equal(padding, other->padding) &&
equal(dilation, other->dilation) && equal(kernel, other->kernel) &&
equal(eviction_policy, other->eviction_policy);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(src);
hash_reduce(dst);
hash_reduce(stride);
hash_reduce(padding);
hash_reduce(dilation);
hash_reduce(kernel);
hash_reduce(eviction_policy);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
/*!
* \brief Lower to TIR statement.
*/
......@@ -378,7 +340,7 @@ public:
class Conv2DIm2ColOp : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Conv2DIm2ColOp, TileOperator,
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Conv2DIm2ColOp, TileOperator,
Conv2DIm2ColOpNode);
TVM_DLL Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
......
......@@ -17,6 +17,7 @@
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"
#include "builtin.h"
#include "region.h"
namespace tvm {
namespace tl {
......@@ -60,9 +61,32 @@ using namespace tir;
* of bounds.
*/
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<FillNode> node = make_object<FillNode>();
ObjectPtr<FillNode> node = tvm::ffi::make_object<FillNode>();
if (args[0]->IsInstance<BufferLoadNode>()) {
// Case 1: Region descriptor call (tl.region)
if (const auto *call = args[0].as<CallNode>()) {
if (call->op.same_as(RegionOp::Get())) {
auto region = RegionOp(call->args, vmap);
node->dst = region->GetBuffer();
node->region = region->GetRanges();
} else if (call->op.same_as(builtin::tvm_access_ptr())) {
node->dst = vmap[GetVarFromAccessPtr(args[0])];
for (int i = 0; i < node->dst->shape.size(); i++) {
node->region.push_back(Range(0, node->dst->shape[i]));
}
} else {
ICHECK(false) << "Unsupported call op in tl.fill: "
<< Downcast<Op>(call->op)->name;
}
// Case 2: Explicit BufferRegion (legacy path)
} else if (args[0]->IsInstance<BufferRegionNode>()) {
auto region = Downcast<BufferRegion>(args[0]);
node->dst = region->buffer;
node->region = region->region;
// Case 3: Vector/scalar region expressed via BufferLoad indices
} else if (args[0]->IsInstance<BufferLoadNode>()) {
auto buffer_load = Downcast<BufferLoad>(args[0]);
for (const auto &index : buffer_load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
......@@ -77,6 +101,7 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
}
}
node->dst = buffer_load->buffer;
// Case 4: Access pointer, fill the full buffer
} else {
node->dst = vmap[GetVarFromAccessPtr(args[0])];
for (int i = 0; i < node->dst->shape.size(); i++) {
......@@ -95,14 +120,19 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
<< " != " << node->dst->shape.size();
for (int i = 0; i < node->region.size(); i++) {
// bound check if region is static
if (node->region[i]->min.as<IntImm>()) {
int64_t min = Downcast<IntImm>(node->region[i]->min)->value;
if (const auto *min_imm = node->region[i]->min.as<IntImmNode>()) {
int64_t min = min_imm->value;
ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0";
}
if (node->region[i]->extent.as<IntImm>()) {
int64_t extent = Downcast<IntImm>(node->region[i]->extent)->value;
ICHECK_LE(extent, Downcast<IntImm>(node->dst->shape[i])->value)
<< "region[" << i << "] = " << extent << " > " << node->dst->shape[i];
if (const auto *extent_imm = node->region[i]->extent.as<IntImmNode>()) {
// Only perform the upper-bound check when the destination shape
// extent is also statically known. If the shape is symbolic (e.g., Var),
// skip this static check to avoid invalid downcasts.
if (const auto *shape_imm = node->dst->shape[i].as<IntImmNode>()) {
ICHECK_LE(extent_imm->value, shape_imm->value)
<< "region[" << i << "] = " << extent_imm->value << " > "
<< node->dst->shape[i];
}
}
}
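// Example of the static check above: with dst shape [128, 128], a constant
// region extent of 256 in either dimension fails ICHECK_LE (256 > 128), while
// a symbolic shape such as [n, 128] skips the upper-bound comparison.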
data_ = std::move(node);
......@@ -117,7 +147,7 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A TileOperator that owns the copied FillNode.
*/
TileOperator FillNode::Clone() const {
auto op = make_object<FillNode>(*this);
auto op = tvm::ffi::make_object<FillNode>(*this);
return Fill(op);
}
......@@ -140,7 +170,8 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
for (int i = 0; i < ndim; i++) {
Var var = Var(std::string{char('i' + i)}, region[i]->extent->dtype);
loop_vars.push_back({region[i], var, IterVarType::kDataPar});
dst_indices.push_back(var);
// Offset the loop induction variable by region min to honor sliced regions
dst_indices.push_back(region[i]->min + var);
}
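// Example (assuming the generated loop iterates var over [0, extent)): a
// sliced region Range(min=8, extent=32) now stores to indices 8 + var,
// i.e. rows 8..39, instead of always starting at row 0.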
Stmt body = BufferStore(dst, value, dst_indices);
for (int i = ndim - 1; i >= 0; i--) {
......@@ -202,6 +233,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return vectorized_thread_loop;
} else {
LOG(FATAL) << "Unsupported scope " << dst.scope();
return Stmt();
}
}
......@@ -226,7 +258,7 @@ TIR_REGISTER_TL_OP(Fill, fill)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({ FillNode::RegisterReflection(); });
TVM_FFI_STATIC_INIT_BLOCK() { FillNode::RegisterReflection(); }
} // namespace tl
} // namespace tvm
......@@ -20,8 +20,7 @@ public:
tir::Buffer dst; ///< Destination buffer to fill
PrimExpr value; ///< Value to fill with
Array<Range> region; ///< Region to fill within the buffer
static constexpr const char *_type_key = "tl.Fill";
TVM_DECLARE_FINAL_OBJECT_INFO(FillNode, TileOperatorNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Fill", FillNode, TileOperatorNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
......@@ -35,19 +34,6 @@ public:
.def_ro("region", &FillNode::region);
}
bool SEqualReduce(const FillNode *other, SEqualReducer equal) const {
return equal(dst, other->dst) && equal(value, other->value) &&
equal(region, other->region);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dst);
hash_reduce(value);
hash_reduce(region);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
TileOperator Clone() const;
private:
......@@ -58,7 +44,7 @@ private:
/// Wrapper class for fill operations
class Fill : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Fill, TileOperator, FillNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fill, TileOperator, FillNode);
TVM_DLL Fill(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};
......
......@@ -33,7 +33,7 @@ using namespace tir;
* Buffer.
*/
FinalizeReducerOp::FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap) {
auto node = make_object<FinalizeReducerOpNode>();
auto node = tvm::ffi::make_object<FinalizeReducerOpNode>();
node->reducer = vmap[GetVarFromAccessPtr(args[0])];
node->op = (ReducerOpType)*as_const_int(args[1]);
data_ = std::move(node);
......@@ -152,7 +152,7 @@ LayoutMap FinalizeReducerOpNode::InferLayout(const LayoutInferArgs &T,
* @return TileOperator A TileOperator that contains a deep copy of this node.
*/
TileOperator FinalizeReducerOpNode::Clone() const {
auto node = make_object<FinalizeReducerOpNode>(*this);
auto node = tvm::ffi::make_object<FinalizeReducerOpNode>(*this);
return TileOperator(node);
}
......@@ -161,6 +161,6 @@ TIR_REGISTER_TL_OP(FinalizeReducerOp, finalize_reducer)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({ FinalizeReducerOpNode::RegisterReflection(); });
TVM_FFI_STATIC_INIT_BLOCK() { FinalizeReducerOpNode::RegisterReflection(); }
} // namespace tl
} // namespace tvm
......@@ -27,8 +27,8 @@ public:
tir::Buffer reducer;
ReducerOpType op;
static constexpr const char *_type_key = "tl.FinalizeReducerOp";
TVM_DECLARE_FINAL_OBJECT_INFO(FinalizeReducerOpNode, TileOperatorNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.FinalizeReducerOp",
FinalizeReducerOpNode, TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
......@@ -37,18 +37,6 @@ public:
.def_ro("op", &FinalizeReducerOpNode::op);
}
bool SEqualReduce(const FinalizeReducerOpNode *other,
SEqualReducer equal) const {
return equal(reducer, other->reducer) && equal(op, other->op);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(reducer);
hash_reduce(op);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
......@@ -58,7 +46,7 @@ public:
class FinalizeReducerOp : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(FinalizeReducerOp, TileOperator,
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FinalizeReducerOp, TileOperator,
FinalizeReducerOpNode);
TVM_DLL FinalizeReducerOp(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
......
......@@ -12,77 +12,14 @@
#include <tvm/tir/transform.h>
#include "../target/utils.h"
#include "region.h"
#include "tcgen5_meta.h"
namespace tvm {
namespace tl {
using namespace tir;
struct TCGEN5MMAMeta {
int atom_m, atom_n, atom_k;
};
// Return {is_success, meta}
static inline std::pair<bool, TCGEN5MMAMeta>
GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
#define FAIL \
return { false, TCGEN5MMAMeta{0, 0, 0} }
#define SUCCESS(atom_m, atom_n, atom_k) \
return { \
true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \
}
std::vector<int> ws_valid_atom_ns = {256, 128, 64};
if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
(c_dtype.is_float() && c_dtype.bits() == 32)) {
if (K % 16 != 0)
FAIL;
if (M % 128 == 0) {
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 16);
FAIL;
} else if (M % 64 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(64, atom_n, 16);
FAIL;
} else if (M % 32 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(32, atom_n, 16);
FAIL;
} else {
FAIL;
}
} else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) &&
(c_dtype.is_float() && c_dtype.bits() == 32)) {
if (K % 32 != 0)
FAIL;
if (M % 128 == 0) {
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 32);
FAIL;
} else if (M % 64 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(64, atom_n, 32);
FAIL;
} else if (M % 32 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(32, atom_n, 32);
FAIL;
} else {
FAIL;
}
}
FAIL;
#undef FAIL
#undef SUCCESS
}
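// Worked example of the selection above: M = 256, N = 256, K = 64 with
// bf16/fp16 inputs and fp32 accumulation passes K % 16 == 0 and M % 128 == 0,
// and the first atom_n in {256, 240, ...} dividing N is 256, giving
// {atom_m = 128, atom_n = 256, atom_k = 16}.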
/**
* @brief Construct a Gemm operator from serialized TL arguments and a buffer
* map.
......@@ -111,42 +48,130 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
* fails with an ICHECK (runtime assertion). No other validation is
* performed here.
*/
// Normalize a GEMM argument (BufferRegion/BufferLoad/tvm_access_ptr/tl.region)
// to BufferRegion
static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
const BufferMap &vmap) {
// Case 1: Already a BufferRegion
if (arg->IsInstance<BufferRegionNode>()) {
return Downcast<BufferRegion>(arg);
}
// Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else
// extent=1)
if (const auto *load = arg.as<BufferLoadNode>()) {
Array<Range> ranges;
for (const PrimExpr &index : load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
ICHECK(ramp->stride.as<IntImmNode>()) << "Ramp stride must be IntImm";
ICHECK_EQ(ramp->stride.as<IntImmNode>()->value, 1)
<< "Only stride-1 Ramp is supported in GEMM region conversion";
ICHECK(ramp->lanes.as<IntImmNode>())
<< "Scalable vector lanes not supported in GEMM region conversion";
ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else {
ranges.push_back(Range::FromMinExtent(index, 1));
}
}
return BufferRegion(load->buffer, ranges);
}
// Case 3: Call nodes
if (const auto *call = arg.as<CallNode>()) {
// tl.region(...) — reconstruct via RegionOp
if (call->op.same_as(RegionOp::Get())) {
RegionOp region(call->args, vmap);
return BufferRegion(region->GetBuffer(), region->GetRanges());
}
// builtin.tvm_access_ptr(...) — map var to Buffer and take full region
if (call->op.same_as(builtin::tvm_access_ptr())) {
Var var = Downcast<Var>(call->args[1]);
Buffer buf = vmap[var];
Array<Range> ranges;
for (PrimExpr extent : buf->shape) {
ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
}
return BufferRegion(buf, ranges);
}
}
LOG(FATAL) << "Unsupported GEMM argument for BufferRegion: " << arg;
throw; // Unreachable, keeps compiler happy
}
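// Worked example (illustrative): for a 2-D buffer A of shape [128, 64], a load
// A[16, Ramp(0, 1, 64)] normalizes to a BufferRegion over A with ranges
// {min = 16, extent = 1} x {min = 0, extent = 64}, i.e. one full row; a
// builtin::tvm_access_ptr argument instead maps back to the full buffer.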
// Build a tvm_access_ptr(handle) to the start of the 2D tile within a
// BufferRegion. Offset is computed from all but the last two dimensions; extent
// is the product of the last two extents. rw_mask: 1=read, 2=write,
// 3=readwrite.
static PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
int rw_mask) {
Buffer buf = region->buffer;
int ndim = static_cast<int>(buf->shape.size());
ICHECK(ndim >= 2) << "GEMM expects buffers with at least 2 dims";
// Compute row-major strides
std::vector<PrimExpr> strides(ndim);
PrimExpr one = make_const(buf->shape[0].dtype(), 1);
PrimExpr cur = one;
for (int i = ndim - 1; i >= 0; --i) {
strides[i] = cur;
cur = cur * buf->shape[i];
}
// Offset: sum_{i in [0..ndim-3]} min_i * stride_i
PrimExpr offset = make_const(buf->shape[0].dtype(), 0);
for (int i = 0; i < ndim - 2; ++i) {
offset = offset + region->region[i]->min * strides[i];
}
// Extent: last two extents product (elements)
PrimExpr extent =
region->region[ndim - 2]->extent * region->region[ndim - 1]->extent;
// ptype and return handle
PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
IntImm(DataType::Int(32), rw_mask)};
return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
}
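// Worked example (illustrative): for a buffer of shape [4, 128, 64] and a
// region with mins {2, 0, 0} and extents {1, 128, 64}, the row-major strides
// are {8192, 64, 1}, so offset = 2 * 8192 = 16384 and extent = 128 * 64 = 8192
// elements; rw_mask = 1 marks the resulting access pointer as read-only.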
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmNode> node = make_object<GemmNode>();
node->Aptr = args[0];
node->Bptr = args[1];
node->Cptr = args[2];
node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
node->trans_A = args[3].as<Bool>().value();
node->trans_B = args[4].as<Bool>().value();
node->M = args[5].as<IntImm>().value()->value;
node->N = args[6].as<IntImm>().value()->value;
node->K = args[7].as<IntImm>().value()->value;
node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
node->clear_accum = args[9].as<PrimExpr>().value();
node->stride_A = args[10].as<IntImm>().value()->value;
node->stride_B = args[11].as<IntImm>().value()->value;
node->offset_A = args[12].as<IntImm>().value()->value;
node->offset_B = args[13].as<IntImm>().value()->value;
ObjectPtr<GemmNode> node = tvm::ffi::make_object<GemmNode>();
node->aRegion_ = NormalizeToBufferRegion(args[0], vmap);
node->bRegion_ = NormalizeToBufferRegion(args[1], vmap);
node->cRegion_ = NormalizeToBufferRegion(args[2], vmap);
node->a_ = node->aRegion_->buffer;
node->b_ = node->bRegion_->buffer;
node->c_ = node->cRegion_->buffer;
node->transA_ = args[3].as<Bool>().value();
node->transB_ = args[4].as<Bool>().value();
node->m_ = args[5].as<IntImm>().value()->value;
node->n_ = args[6].as<IntImm>().value()->value;
node->k_ = args[7].as<IntImm>().value()->value;
node->policy_ = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
node->clearAccum_ = args[9].as<PrimExpr>().value();
node->strideA_ = args[10].as<IntImm>().value()->value;
node->strideB_ = args[11].as<IntImm>().value()->value;
node->offsetA_ = args[12].as<IntImm>().value()->value;
node->offsetB_ = args[13].as<IntImm>().value()->value;
if (args.size() > 14) {
node->kPack = args[14].as<IntImm>().value()->value;
if (node->kPack != 1 && node->kPack != 2) {
node->kPack_ = args[14].as<IntImm>().value()->value;
if (node->kPack_ != 1 && node->kPack_ != 2) {
ICHECK(false) << "kPack must be 1 or 2";
}
}
if (args.size() > 15) {
node->wg_wait = args[15].as<IntImm>().value()->value;
node->wgWait_ = args[15].as<IntImm>().value()->value;
}
node->mbarptr = args[16];
if (node->mbarptr.as<CallNode>()) {
node->mbar = vmap[GetVarFromAccessPtr(node->mbarptr)];
node->mbarPtr_ = args[16];
if (node->mbarPtr_.as<CallNode>()) {
node->mbar_ = vmap[GetVarFromAccessPtr(node->mbarPtr_)];
} else {
node->mbar = std::nullopt;
node->mbar_ = std::nullopt;
}
node->C_coords = Array<PrimExpr>(
node->cCoords_ = Array<PrimExpr>(
{args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
data_ = std::move(node);
}
......@@ -160,46 +185,45 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A Gemm operator that owns a copy of this node.
*/
TileOperator GemmNode::Clone() const {
auto op = make_object<GemmNode>(*this);
auto op = tvm::ffi::make_object<GemmNode>(*this);
return Gemm(op);
}
bool GemmNode::AllowTCGEN5MMA(Target target) const {
bool GemmNode::allowTcgen5Mma(Target target) const {
return TargetIsSm100(target) &&
((A.scope() == "shared.dyn" || A.scope() == "shared" ||
A.scope() == "shared.tmem") &&
(B.scope() == "shared.dyn" || B.scope() == "shared") &&
C.scope() == "shared.tmem") &&
GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype).first;
((a_.scope() == "shared.dyn" || a_.scope() == "shared" ||
a_.scope() == "shared.tmem") &&
(b_.scope() == "shared.dyn" || b_.scope() == "shared") &&
c_.scope() == "shared.tmem") &&
GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype).first;
}
bool GemmNode::AllowWGMMA(int block_size, Target target) const {
bool GemmNode::allowWgmma(int block_size, Target target) const {
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
int warp_size = TargetGetWarpSize(target);
int num_warps = block_size / warp_size;
return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) &&
CheckWGMMA();
TargetIsHopper(target) && (this->m_ >= 64) && (num_warps % 4 == 0) &&
checkWgmma();
}
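// For example, on a Hopper target with 32-thread warps and block_size = 384,
// num_warps = 12, so the (num_warps % 4 == 0) requirement holds; WGMMA is then
// chosen as long as the kDisableWGMMA pass config is not enabled, m_ >= 64,
// and checkWgmma() accepts the operand layouts and dtypes.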
GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
bool allow_tcgen5mma = AllowTCGEN5MMA(target);
bool allow_wgmma = AllowWGMMA(block_size, target);
if (allow_tcgen5mma) {
GemmInst GemmNode::getGemmInst(int block_size, Target target) const {
if (allowTcgen5Mma(target)) {
return GemmInst::kTCGEN5MMA;
} else if (allow_wgmma) {
} else if (allowWgmma(block_size, target)) {
return GemmInst::kWGMMA;
} else if (TargetIsCDNA(target)) {
return GemmInst::kMFMA;
} else if (TargetIsCuda(target)) {
return GemmInst::kMMA;
} else {
ICHECK(0) << "Unsupported target for gemm: " << target->str();
ICHECK(0) << "Unsupported target for gemm: " << target;
return GemmInst::kMMA;
}
}
std::pair<int, int> GemmWarpPolicyNode::ComputeWarpPartition(
std::pair<int, int> GemmWarpPolicyNode::computeWarpPartition(
int M, int N, int block_size, Target target, GemmInst gemm_inst) const {
int num_warps = block_size / TargetGetWarpSize(target);
if (gemm_inst == GemmInst::kTCGEN5MMA) {
......@@ -208,7 +232,10 @@ std::pair<int, int> GemmWarpPolicyNode::ComputeWarpPartition(
int m_warp = 1, n_warp = 1;
constexpr int kMPerWarp = 16; // Rows processed by a single warp
constexpr int kNPerWarp = 8; // Columns processed by a single warp
int kNPerWarp = 8; // Columns processed by a single warp
if (TargetIsVolta(target)) {
kNPerWarp = 16;
}
ICHECK(M % kMPerWarp == 0)
<< "M must be divisible by " << kMPerWarp << ", but got " << M;
ICHECK(N % kNPerWarp == 0)
......@@ -408,51 +435,52 @@ std::pair<int, int> GemmWarpPolicyNode::ComputeWarpPartition(
* @return true if WGMMA is supported for the current buffers, dtypes, and
* transpose/shape constraints; false otherwise.
*/
bool GemmNode::CheckWGMMA() const {
if (B.scope() != "shared.dyn" && B.scope() != "shared") {
bool GemmNode::checkWgmma() const {
if (b_.scope() != "shared.dyn" && b_.scope() != "shared") {
return false;
}
if (C->dtype == DataType::Float(16)) {
if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
return K % 16 == 0;
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
if (c_->dtype == DataType::Float(16)) {
if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
return k_ % 16 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else
return false;
} else if (C->dtype == DataType::Float(32)) {
if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
return K % 16 == 0;
else if (A->dtype == DataType::BFloat(16) &&
B->dtype == DataType::BFloat(16))
return K % 16 == 0;
else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
return (!trans_A) && trans_B && K % 8 == 0;
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
} else if (c_->dtype == DataType::Float(32)) {
if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
return k_ % 16 == 0;
else if (a_->dtype == DataType::BFloat(16) &&
b_->dtype == DataType::BFloat(16))
return k_ % 16 == 0;
else if (a_->dtype == DataType::Float(32) &&
b_->dtype == DataType::Float(32))
return (!transA_) && transB_ && k_ % 8 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else
return false;
} else if (C->dtype == DataType::Int(32)) {
if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
return (!trans_A) && trans_B && K % 32 == 0;
} else if (c_->dtype == DataType::Int(32)) {
if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::Int(8))
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::UInt(8))
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::Int(8))
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::UInt(8))
return (!transA_) && transB_ && k_ % 32 == 0;
else
return false;
} else {
......@@ -476,8 +504,8 @@ bool GemmNode::CheckWGMMA() const {
*/
static int GetArchInt(Target target) {
int arch_int = 0;
auto s = target->GetAttr<String>("arch");
ICHECK(s.defined());
auto s = target->GetAttr<tvm::ffi::String>("arch");
ICHECK(s.has_value());
std::string arch = s.value();
if (arch.rfind("sm_", 0) == 0) {
arch_int = std::stoi(arch.substr(3));
......@@ -502,56 +530,61 @@ static int GetArchInt(Target target) {
*/
Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto block_size = *as_const_int(T.thread_bounds->extent);
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
GemmInst gemm_inst = getGemmInst(block_size, T.target);
auto [warp_m, warp_n] =
policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst);
policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst);
// Build access pointers from regions locally
PrimExpr Aptr = MakeAccessPtrFromRegion(aRegion_, /*r*/ 1);
PrimExpr Bptr = MakeAccessPtrFromRegion(bRegion_, /*r*/ 1);
PrimExpr Cptr = MakeAccessPtrFromRegion(cRegion_, /*rw*/ 3);
std::stringstream ss;
std::string op_name;
if (gemm_inst == GemmInst::kTCGEN5MMA) {
auto [can_use_tcgen5mma, meta] =
GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype);
GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype);
ICHECK(can_use_tcgen5mma);
ICHECK(B.scope() == "shared.dyn" || B.scope() == "shared");
ICHECK(C.scope() == "shared.tmem");
ICHECK(mbar.has_value()) << "mbar must be provided for TCGEN5MMA";
if (A.scope() == "shared.tmem") {
ICHECK(b_.scope() == "shared.dyn" || b_.scope() == "shared");
ICHECK(c_.scope() == "shared.tmem");
ICHECK(mbar_.has_value()) << "mbar must be provided for TCGEN5MMA";
if (a_.scope() == "shared.tmem") {
op_name = "tl::tcgen5mma_gemm_ts";
} else if (A.scope() == "shared.dyn" || A.scope() == "shared") {
} else if (a_.scope() == "shared.dyn" || a_.scope() == "shared") {
op_name = "tl::tcgen5mma_gemm_ss";
} else {
ICHECK(0)
<< "Unsupported A scope for TCGEN5MMA: "
<< A.scope(); // If this is triggered, it means Tilelang has bugs.
<< a_.scope(); // If this is triggered, it indicates a bug in TileLang.
}
ICHECK(wg_wait == -1)
ICHECK(wgWait_ == -1)
<< "Currently only wg_wait == -1 is supported for TCGEN5MMA. Please "
"use "
"wg_wait = -1 and manually synchronize with mbarrier.";
std::string accum_dtype = "";
if (C->dtype.is_float()) {
if (C->dtype.bits() == 32) {
if (c_->dtype.is_float()) {
if (c_->dtype.bits() == 32) {
accum_dtype = "float";
}
}
ICHECK(!accum_dtype.empty())
<< "Unsupported C dtype for TCGEN5MMA: " << C->dtype;
ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
<< "Unsupported C dtype for TCGEN5MMA: " << c_->dtype;
ss << op_name << "<" << m_ << ", " << n_ << ", " << k_ << ", ";
ss << meta.atom_m << ", " << meta.atom_n << ", " << meta.atom_k << ", ";
ss << trans_A << ", " << trans_B << ", ";
ss << transA_ << ", " << transB_ << ", ";
ss << accum_dtype;
ss << ">";
auto C_buffer = T.buffer_remap.count(C) ? T.buffer_remap[C] : C;
auto C_buffer = T.buffer_remap.count(c_) ? T.buffer_remap[c_] : c_;
Array<PrimExpr> new_args;
new_args.push_back(StringImm(ss.str()));
new_args.push_back(Aptr);
new_args.push_back(Bptr);
new_args.push_back(BufferLoad(C_buffer, C_coords));
new_args.push_back(mbarptr);
new_args.push_back(clear_accum);
new_args.push_back(BufferLoad(C_buffer, cCoords_));
new_args.push_back(mbarPtr_);
new_args.push_back(clearAccum_);
auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);
// Since TCGEN5MMA atoms provided by CUTLASS always have an internal
......@@ -576,47 +609,49 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
}
}
if (A.scope() == "local.fragment") {
ICHECK(B.scope() != "local.fragment");
if (a_.scope() == "local.fragment") {
ICHECK(b_.scope() != "local.fragment");
ICHECK(!transA_)
<< "gemm_rs requires the A operand to be in non-transposed layout.";
op_name = "tl::gemm_rs";
} else if (B.scope() == "local.fragment") {
} else if (b_.scope() == "local.fragment") {
op_name = "tl::gemm_sr";
} else {
op_name = "tl::gemm_ss";
}
ICHECK(C.scope() == "local.fragment");
ICHECK(c_.scope() == "local.fragment");
ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
ss << op_name << "<" << m_ << ", " << n_ << ", " << k_ << ", ";
ss << warp_m << ", " << warp_n << ", ";
ss << trans_A << ", " << trans_B;
auto clear_accum_bool = clear_accum.as<Bool>();
ss << transA_ << ", " << transB_;
auto clear_accum_bool = clearAccum_.as<Bool>();
ICHECK(clear_accum_bool.has_value())
<< "clear_accum must be a constant Bool type, got " << clear_accum;
<< "clear_accum must be a constant Bool type, got " << clearAccum_;
ss << ", " << bool(clear_accum_bool.value());
if (TargetIsCuda(T.target) && (GetArchInt(T.target) >= 75)) {
ss << ", " << stride_A << ", " << stride_B;
ss << ", " << offset_A << ", " << offset_B;
ss << ", " << strideA_ << ", " << strideB_;
ss << ", " << offsetA_ << ", " << offsetB_;
}
if (TargetIsCDNA(T.target)) {
// for cdna gemm, we need to specify kPack
ss << ", " << kPack;
ss << ", " << kPack_;
} else if (TargetIsHopper(T.target)) {
ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
}
// Emit wg_wait if necessary
if (TargetIsHopper(T.target)) {
if (wg_wait != 0) {
ss << ", " << wg_wait;
if (wgWait_ != 0) {
ss << ", " << wgWait_;
}
} else if (TargetIsSm100(T.target)) {
// NOTE On sm100, only the leading thread issues the TCGEN5MMA instruction
// but all threads need to wait, so we emit another statement for cases
// where wg_wait == 0.
ICHECK(wg_wait == 0 || wg_wait == -1)
ICHECK(wgWait_ == 0 || wgWait_ == -1)
<< "wg_wait must be 0 or -1 for Sm100";
} else {
ICHECK(wg_wait == 0)
ICHECK(wgWait_ == 0)
<< "wg_wait must be 0 for non-Hopper and non-Sm100 targets";
}
ss << ">";
......@@ -652,151 +687,152 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
LayoutMap results;
auto thread_range = T.thread_bounds;
auto block_size = *as_const_int(thread_range->extent);
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
GemmInst gemm_inst = getGemmInst(block_size, T.target);
auto [warp_m, warp_n] =
policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst);
policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst);
if (TargetIsVolta(T.target)) {
ICHECK(C.scope() == "local.fragment")
ICHECK(c_.scope() == "local.fragment")
<< "Volta gemm only supports C in local.fragment scope, got "
<< C.scope();
<< c_.scope();
auto fragment = makeGemmVoltaFragmentC(m_, n_, m_ / warp_m, n_ / warp_n,
c_->dtype.bits());
results.Set(c_, fragment->BindThreadRange(thread_range));
if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
int dim_A = a_->shape.size();
results.Set(a_, makeGemmVoltaABLayout(*as_const_int(a_->shape[dim_A - 2]),
*as_const_int(a_->shape[dim_A - 1]),
true, !transA_));
} else if (a_.scope() == "local.fragment") {
ICHECK(transA_ == false);
auto fragment =
makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range));
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
int dim_A = A->shape.size();
results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[dim_A - 2]),
*as_const_int(A->shape[dim_A - 1]),
true, !trans_A));
} else if (A.scope() == "local.fragment") {
ICHECK(trans_A == false);
auto fragment = makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n);
results.Set(A, fragment->BindThreadRange(thread_range));
makeGemmVoltaFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n);
results.Set(a_, fragment->BindThreadRange(thread_range));
} else {
ICHECK(0);
}
ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
int dim_B = B->shape.size();
results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[dim_B - 2]),
*as_const_int(B->shape[dim_B - 1]),
false, trans_B));
ICHECK(b_.scope() == "shared" || b_.scope() == "shared.dyn");
int dim_B = b_->shape.size();
results.Set(b_, makeGemmVoltaABLayout(*as_const_int(b_->shape[dim_B - 2]),
*as_const_int(b_->shape[dim_B - 1]),
false, transB_));
} else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
TargetIsSM120(T.target) ||
(TargetIsSm100(T.target) && gemm_inst == GemmInst::kMMA)) {
ICHECK(C.scope() == "local.fragment")
<< "MMA only supports C in local.fragment scope, got " << C.scope();
ICHECK(c_.scope() == "local.fragment")
<< "MMA only supports C in local.fragment scope, got " << c_.scope();
auto fragment =
makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range));
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
int dim_A = A->shape.size();
const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
results.Set(A,
makeGemmFragmentC(m_, n_, m_ / warp_m, n_ / warp_n, c_->dtype.bits());
results.Set(c_, fragment->BindThreadRange(thread_range));
if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
int dim_A = a_->shape.size();
const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
results.Set(a_,
makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
A->dtype.bits(), !trans_A));
} else if (A.scope() == "local.fragment") {
auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
A->dtype.bits(), trans_A);
results.Set(A, fragment->BindThreadRange(thread_range));
a_->dtype.bits(), !transA_));
} else if (a_.scope() == "local.fragment") {
auto fragment = makeGemmFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n,
a_->dtype.bits(), transA_);
results.Set(a_, fragment->BindThreadRange(thread_range));
} else {
ICHECK(0);
}
if (B.scope() == "shared" || B.scope() == "shared.dyn") {
int dim_B = B->shape.size();
const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
results.Set(B,
if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
int dim_B = b_->shape.size();
const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
results.Set(b_,
makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
B->dtype.bits(), trans_B));
} else if (B.scope() == "local.fragment") {
b_->dtype.bits(), transB_));
} else if (b_.scope() == "local.fragment") {
auto fragment =
makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
results.Set(B, fragment->BindThreadRange(thread_range));
makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_);
results.Set(b_, fragment->BindThreadRange(thread_range));
} else {
ICHECK(0);
}
} else if (TargetIsHopper(T.target)) {
ICHECK(C.scope() == "local.fragment")
ICHECK(c_.scope() == "local.fragment")
<< (gemm_inst == GemmInst::kWGMMA ? "WGMMA " : "MMA ")
<< "only supports C in local.fragment scope, got " << C.scope();
auto fragment =
gemm_inst == GemmInst::kWGMMA
? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
C->dtype.bits())
: makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range));
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
int dim_A = A->shape.size();
const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
<< "only supports C in local.fragment scope, got " << c_.scope();
auto fragment = gemm_inst == GemmInst::kWGMMA
? makeGemmFragmentCHopper(m_, n_, m_ / warp_m,
n_ / warp_n, c_->dtype.bits())
: makeGemmFragmentC(m_, n_, m_ / warp_m, n_ / warp_n,
c_->dtype.bits());
results.Set(c_, fragment->BindThreadRange(thread_range));
if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
int dim_A = a_->shape.size();
const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
const int64_t continuity =
trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
transA_ ? 4 * mat_continuous / warp_m : mat_continuous;
auto ABLayout =
gemm_inst == GemmInst::kWGMMA
? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
A->dtype.bits(), !trans_A)
a_->dtype.bits(), !transA_)
: makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
A->dtype.bits(), !trans_A);
results.Set(A, ABLayout);
a_->dtype.bits(), !transA_);
results.Set(a_, ABLayout);
} else {
auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
A->dtype.bits(), trans_A);
results.Set(A, fragment->BindThreadRange(thread_range));
}
if (B.scope() == "shared" || B.scope() == "shared.dyn") {
int dim_B = B->shape.size();
const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
auto fragment = makeGemmFragmentA(m_, n_, k_, m_ / warp_m, n_ / warp_n,
a_->dtype.bits(), transA_);
results.Set(a_, fragment->BindThreadRange(thread_range));
}
if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
int dim_B = b_->shape.size();
const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
const int64_t continuity =
trans_B ? mat_continuous : mat_continuous / warp_n;
transB_ ? mat_continuous : mat_continuous / warp_n;
auto ABLayout =
gemm_inst == GemmInst::kWGMMA
? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
B->dtype.bits(), trans_B)
b_->dtype.bits(), transB_)
: makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
B->dtype.bits(), trans_B);
results.Set(B, ABLayout);
b_->dtype.bits(), transB_);
results.Set(b_, ABLayout);
} else {
auto fragment =
makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
results.Set(B, fragment->BindThreadRange(thread_range));
makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_);
results.Set(b_, fragment->BindThreadRange(thread_range));
}
} else if (gemm_inst == GemmInst::kTCGEN5MMA) {
ICHECK(C.scope() == "shared.tmem")
<< "TCGEN5MMA only supports C in shared.tmem scope, got " << C.scope();
ICHECK(A.scope() == "shared.dyn" || A.scope() == "shared")
ICHECK(c_.scope() == "shared.tmem")
<< "TCGEN5MMA only supports C in shared.tmem scope, got " << c_.scope();
ICHECK(a_.scope() == "shared.dyn" || a_.scope() == "shared")
<< "Current TCGEN5MMA only supports A in shared.dyn scope";
auto [can_use_tcgen5mma, meta] =
GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype);
GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype);
ICHECK(can_use_tcgen5mma);
{
int dim_A = A->shape.size();
const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
results.Set(A, makeGemmABLayoutSm100(mat_stride, mat_continuous,
mat_continuous, A->dtype.bits(),
trans_A ? 1 : 2));
int dim_A = a_->shape.size();
const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
results.Set(a_, makeGemmABLayoutSm100(mat_stride, mat_continuous,
mat_continuous, a_->dtype.bits(),
transA_ ? 1 : 2));
}
{
int dim_B = B->shape.size();
const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
int dim_B = b_->shape.size();
const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
const int64_t continuity = mat_continuous;
results.Set(B,
results.Set(b_,
makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity,
B->dtype.bits(), trans_B ? 2 : 1));
b_->dtype.bits(), transB_ ? 2 : 1));
}
{
Layout res;
IterVar i = make_itervar("i", M);
IterVar j = make_itervar("j", N);
ICHECK(M % meta.atom_m == 0);
IterVar i = make_itervar("i", m_);
IterVar j = make_itervar("j", n_);
ICHECK(m_ % meta.atom_m == 0);
PrimExpr atom_idx = FloorDiv(i, meta.atom_m) +
FloorDiv(j, meta.atom_n) * (M / meta.atom_m);
FloorDiv(j, meta.atom_n) * (m_ / meta.atom_m);
PrimExpr ai = FloorMod(i, meta.atom_m); // "ai" means "atom_i"
PrimExpr aj = FloorMod(j, meta.atom_n);
if (meta.atom_m == 128) {
......@@ -822,46 +858,46 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
} else {
ICHECK(0);
}
results.Set(C, res);
results.Set(c_, res);
}
} else if (TargetIsCDNA(T.target)) {
ICHECK(C.scope() == "local.fragment")
ICHECK(c_.scope() == "local.fragment")
<< "CDNA gemm (FMMA) only supports C in local.fragment scope, got "
<< C.scope();
<< c_.scope();
if (TargetIsDCU(T.target)) {
auto fragment =
makeGemmFragmentCDCU(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range));
makeGemmFragmentCDCU(m_, n_, m_ / warp_m, n_ / warp_n, c_->dtype.bits());
results.Set(c_, fragment->BindThreadRange(thread_range));
} else {
auto fragment =
makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range));
makeGemmFragmentCCDNA(m_, n_, m_ / warp_m, n_ / warp_n, c_->dtype.bits());
results.Set(c_, fragment->BindThreadRange(thread_range));
}
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
int dim_A = A->shape.size();
if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
int dim_A = a_->shape.size();
auto shared_layout = makeGemmABLayoutCDNA(
*as_const_int(A->shape[dim_A - 2]),
*as_const_int(A->shape[dim_A - 1]), A->dtype.bits(), kPack);
results.Set(A, shared_layout);
} else if (A.scope() == "local.fragment") {
auto fragment = makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n,
A->dtype.bits(), kPack, trans_A);
results.Set(A, fragment->BindThreadRange(thread_range));
*as_const_int(a_->shape[dim_A - 2]),
*as_const_int(a_->shape[dim_A - 1]), a_->dtype.bits(), kPack_);
results.Set(a_, shared_layout);
} else if (a_.scope() == "local.fragment") {
auto fragment =
makeGemmFragmentACDNA(m_, n_, k_, m_ / warp_m, n_ / warp_n,
a_->dtype.bits(), kPack_, transA_);
results.Set(a_, fragment->BindThreadRange(thread_range));
} else {
ICHECK(0);
}
if (B.scope() == "shared" || B.scope() == "shared.dyn") {
int dim_B = B->shape.size();
if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
int dim_B = b_->shape.size();
auto shared_layout = makeGemmABLayoutCDNA(
*as_const_int(B->shape[dim_B - 2]),
*as_const_int(B->shape[dim_B - 1]), B->dtype.bits(), kPack);
*as_const_int(b_->shape[dim_B - 2]),
*as_const_int(b_->shape[dim_B - 1]), b_->dtype.bits(), kPack_);
results.Set(B, shared_layout);
} else if (B.scope() == "local.fragment") {
results.Set(b_, shared_layout);
} else if (b_.scope() == "local.fragment") {
auto fragment =
makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
results.Set(B, fragment->BindThreadRange(thread_range));
makeGemmFragmentB(m_, n_, k_, m_ / warp_m, n_ / warp_n, transB_);
results.Set(b_, fragment->BindThreadRange(thread_range));
} else {
ICHECK(0);
}
......@@ -880,18 +916,17 @@ TIR_REGISTER_TL_OP(Gemm, gemm)
TVM_REGISTER_OP("tl.GemmWarpPolicy")
.set_attr<TScriptPrinterName>("TScriptPrinterName", "GemmWarpPolicy");
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
GemmNode::RegisterReflection();
GemmWarpPolicyNode::RegisterReflection();
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.GemmWarpPolicyComputeWarpPartition",
[](GemmWarpPolicy policy, int M, int N, int block_size,
Target target, GemmInst gemm_inst) {
policy->ComputeWarpPartition(M, N, block_size, target,
policy->computeWarpPartition(M, N, block_size, target,
gemm_inst);
return;
});
});
}
} // namespace tl
} // namespace tvm
......@@ -30,8 +30,7 @@ public:
mutable int n_warp{0};
int policy_type;
static constexpr const char *_type_key = "tl.GemmWarpPolicy";
TVM_DECLARE_FINAL_OBJECT_INFO(GemmWarpPolicyNode, Object);
TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmWarpPolicy", GemmWarpPolicyNode, Object);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
......@@ -41,22 +40,7 @@ public:
.def_ro("n_warp", &GemmWarpPolicyNode::n_warp);
}
bool SEqualReduce(const GemmWarpPolicyNode *other,
SEqualReducer equal) const {
return equal(policy_type, other->policy_type) &&
equal(m_warp, other->m_warp) && equal(n_warp, other->n_warp);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(policy_type);
hash_reduce(m_warp);
hash_reduce(n_warp);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
std::pair<int, int> ComputeWarpPartition(int M, int N, int block_size,
std::pair<int, int> computeWarpPartition(int M, int N, int block_size,
Target target,
GemmInst gemm_inst) const;
......@@ -74,22 +58,23 @@ public:
class GemmWarpPolicy : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(GemmWarpPolicy, ObjectRef, GemmWarpPolicyNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmWarpPolicy, ObjectRef,
GemmWarpPolicyNode);
explicit GemmWarpPolicy(GemmWarpPolicyType policy_type) {
auto node = make_object<GemmWarpPolicyNode>();
auto node = tvm::ffi::make_object<GemmWarpPolicyNode>();
node->policy_type = (int)policy_type;
data_ = std::move(node);
}
explicit GemmWarpPolicy(int policy_type) {
auto node = make_object<GemmWarpPolicyNode>();
auto node = tvm::ffi::make_object<GemmWarpPolicyNode>();
node->policy_type = policy_type;
data_ = std::move(node);
}
explicit GemmWarpPolicy(int m_warp, int n_warp) {
auto node = make_object<GemmWarpPolicyNode>();
auto node = tvm::ffi::make_object<GemmWarpPolicyNode>();
node->m_warp = m_warp;
node->n_warp = n_warp;
node->policy_type = (int)GemmWarpPolicyType::kFree;
......@@ -99,89 +84,48 @@ public:
class GemmNode : public TileOperatorNode {
public:
bool CheckWGMMA() const;
tir::Buffer A, B, C;
// pointer to the A, B, C
PrimExpr Aptr, Bptr, Cptr;
bool trans_A, trans_B;
int M, N, K;
int stride_A, stride_B;
int offset_A, offset_B;
PrimExpr clear_accum = const_false();
bool checkWgmma() const;
tir::Buffer a_, b_, c_;
// BufferRegion for A, B and C
BufferRegion aRegion_, bRegion_, cRegion_;
bool transA_, transB_;
int m_, n_, k_;
int strideA_, strideB_;
int offsetA_, offsetB_;
PrimExpr clearAccum_ = const_false();
// For k_pack, see bitblas/tl/mfma_macro_generator.py::k_pack;
// it is only enabled for CDNA MFMA instructions.
int kPack = 1;
int wg_wait = 0;
PrimExpr mbarptr;
std::optional<tir::Buffer> mbar; // mbar is optional, only used for TCGEN5MMA
Array<PrimExpr> C_coords;
mutable GemmWarpPolicy policy;
static constexpr const char *_type_key = "tl.Gemm";
TVM_DECLARE_FINAL_OBJECT_INFO(GemmNode, TileOperatorNode);
int kPack_ = 1;
int wgWait_ = 0;
PrimExpr mbarPtr_;
std::optional<tir::Buffer> mbar_; // mbar is optional, only used for TCGEN5MMA
Array<PrimExpr> cCoords_;
mutable GemmWarpPolicy policy_;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Gemm", GemmNode, TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GemmNode>()
.def_ro("A", &GemmNode::A)
.def_ro("B", &GemmNode::B)
.def_ro("C", &GemmNode::C)
.def_ro("Aptr", &GemmNode::Aptr)
.def_ro("Bptr", &GemmNode::Bptr)
.def_ro("Cptr", &GemmNode::Cptr)
.def_ro("trans_A", &GemmNode::trans_A)
.def_ro("trans_B", &GemmNode::trans_B)
.def_ro("M", &GemmNode::M)
.def_ro("N", &GemmNode::N)
.def_ro("K", &GemmNode::K)
.def_ro("stride_A", &GemmNode::stride_A)
.def_ro("stride_B", &GemmNode::stride_B)
.def_ro("offset_A", &GemmNode::offset_A)
.def_ro("offset_B", &GemmNode::offset_B)
.def_ro("clear_accum", &GemmNode::clear_accum)
.def_ro("kPack", &GemmNode::kPack)
.def_ro("wg_wait", &GemmNode::wg_wait)
.def_ro("policy", &GemmNode::policy);
}
bool SEqualReduce(const GemmNode *other, SEqualReducer equal) const {
return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) &&
equal(Aptr, other->Aptr) && equal(Bptr, other->Bptr) &&
equal(Cptr, other->Cptr) && equal(trans_A, other->trans_A) &&
equal(trans_B, other->trans_B) && equal(M, other->M) &&
equal(N, other->N) && equal(K, other->K) &&
equal(stride_A, other->stride_A) &&
equal(stride_B, other->stride_B) &&
equal(offset_A, other->offset_A) &&
equal(offset_B, other->offset_B) &&
equal(clear_accum, other->clear_accum) &&
equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
equal(policy, other->policy);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(A);
hash_reduce(B);
hash_reduce(C);
hash_reduce(Aptr);
hash_reduce(Bptr);
hash_reduce(Cptr);
hash_reduce(trans_A);
hash_reduce(trans_B);
hash_reduce(M);
hash_reduce(N);
hash_reduce(K);
hash_reduce(stride_A);
hash_reduce(stride_B);
hash_reduce(offset_A);
hash_reduce(offset_B);
hash_reduce(clear_accum);
hash_reduce(kPack);
hash_reduce(wg_wait);
hash_reduce(policy);
.def_ro("a", &GemmNode::a_)
.def_ro("b", &GemmNode::b_)
.def_ro("c", &GemmNode::c_)
.def_ro("aRegion", &GemmNode::aRegion_)
.def_ro("bRegion", &GemmNode::bRegion_)
.def_ro("cRegion", &GemmNode::cRegion_)
.def_ro("transA", &GemmNode::transA_)
.def_ro("transB", &GemmNode::transB_)
.def_ro("m", &GemmNode::m_)
.def_ro("n", &GemmNode::n_)
.def_ro("k", &GemmNode::k_)
.def_ro("strideA", &GemmNode::strideA_)
.def_ro("strideB", &GemmNode::strideB_)
.def_ro("offsetA", &GemmNode::offsetA_)
.def_ro("offsetB", &GemmNode::offsetB_)
.def_ro("clearAccum", &GemmNode::clearAccum_)
.def_ro("kPack", &GemmNode::kPack_)
.def_ro("wgWait", &GemmNode::wgWait_)
.def_ro("policy", &GemmNode::policy_);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
......@@ -190,16 +134,16 @@ public:
TileOperator Clone() const;
private:
GemmInst GetGemmInst(int block_size, Target target) const;
bool AllowTCGEN5MMA(Target target) const;
bool AllowWGMMA(int block_size, Target target) const;
GemmInst getGemmInst(int block_size, Target target) const;
bool allowTcgen5Mma(Target target) const;
bool allowWgmma(int block_size, Target target) const;
mutable bool completed_ = false;
};
class Gemm : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Gemm, TileOperator, GemmNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Gemm, TileOperator, GemmNode);
TVM_DLL Gemm(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};
......
......@@ -12,13 +12,101 @@
#include <tvm/tir/transform.h>
#include "../target/utils.h"
#include "tvm/ffi/string.h"
#include "region.h"
#include "tcgen5_meta.h"
namespace tvm {
namespace tl {
using namespace tir;
// Normalize a GEMM argument (BufferRegion/BufferLoad/tvm_access_ptr/tl.region)
// to BufferRegion
static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
const BufferMap &vmap) {
// Case 1: Already a BufferRegion
if (arg->IsInstance<BufferRegionNode>()) {
return Downcast<BufferRegion>(arg);
}
// Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else
// extent=1)
if (const auto *load = arg.as<BufferLoadNode>()) {
Array<Range> ranges;
for (const PrimExpr &index : load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
ICHECK(ramp->stride.as<IntImmNode>()) << "Ramp stride must be IntImm";
ICHECK_EQ(ramp->stride.as<IntImmNode>()->value, 1)
<< "Only stride-1 Ramp is supported in GEMM region conversion";
ICHECK(ramp->lanes.as<IntImmNode>())
<< "Scalable vector lanes not supported in GEMM region conversion";
ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else {
ranges.push_back(Range::FromMinExtent(index, 1));
}
}
return BufferRegion(load->buffer, ranges);
}
// Case 3: Call nodes
if (const auto *call = arg.as<CallNode>()) {
// tl.region(...) — reconstruct via RegionOp
if (call->op.same_as(RegionOp::Get())) {
RegionOp region(call->args, vmap);
return BufferRegion(region->GetBuffer(), region->GetRanges());
}
// builtin.tvm_access_ptr(...) — map var to Buffer and take full region
if (call->op.same_as(builtin::tvm_access_ptr())) {
Var var = Downcast<Var>(call->args[1]);
Buffer buf = vmap.at(var);
Array<Range> ranges;
for (PrimExpr extent : buf->shape) {
ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
}
return BufferRegion(buf, ranges);
}
}
LOG(FATAL) << "Unsupported GEMM argument for BufferRegion: " << arg;
throw; // Unreachable, keeps compiler happy
}
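// Illustrative sketch (not part of the lowering path; the buffer and index
// names below are hypothetical): a BufferLoad with a stride-1 Ramp index such
// as
//
//   Buffer A = decl_buffer({128, 64}, DataType::Float(16), "A");
//   Var i("i", DataType::Int(32));
//   PrimExpr load = BufferLoad(A, {i, Ramp(0, 1, 64)});
//   BufferRegion region = NormalizeToBufferRegion(load, /*vmap=*/{});
//
// normalizes to the region A[i:i+1, 0:64], i.e. {Range(i, 1), Range(0, 64)};
// a plain scalar index contributes an extent-1 range instead.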
// Build a tvm_access_ptr(handle) to the start of the 2D tile within a
// BufferRegion. Offset is computed from all but the last two dimensions; extent
// is the product of the last two extents. rw_mask: 1=read, 2=write,
// 3=readwrite.
static PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
int rw_mask) {
Buffer buf = region->buffer;
int ndim = static_cast<int>(buf->shape.size());
ICHECK(ndim >= 2) << "GEMM expects buffers with at least 2 dims";
// Compute row-major strides
std::vector<PrimExpr> strides(ndim);
PrimExpr one = make_const(buf->shape[0].dtype(), 1);
PrimExpr cur = one;
for (int i = ndim - 1; i >= 0; --i) {
strides[i] = cur;
cur = cur * buf->shape[i];
}
// Offset: sum_{i in [0..ndim-3]} min_i * stride_i
PrimExpr offset = make_const(buf->shape[0].dtype(), 0);
for (int i = 0; i < ndim - 2; ++i) {
offset = offset + region->region[i]->min * strides[i];
}
// Extent: last two extents product (elements)
PrimExpr extent =
region->region[ndim - 2]->extent * region->region[ndim - 1]->extent;
// ptype and return handle
PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
IntImm(DataType::Int(32), rw_mask)};
return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
}
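// Worked example (hypothetical shapes, for illustration only): for a buffer of
// shape [4, 64, 32] and the region [2:3, 0:64, 0:32], the row-major strides
// are [2048, 32, 1], so the returned tvm_access_ptr starts at element offset
// 2 * 2048 = 4096 and spans 64 * 32 = 2048 elements of the 2D tile.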
/**
* @brief Construct a Gemm operator from serialized TL arguments and a buffer
* map.
......@@ -48,34 +136,43 @@ using namespace tir;
* performed here.
*/
GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmPyNode> node = make_object<GemmPyNode>();
node->Aptr = args[0];
node->Bptr = args[1];
node->Cptr = args[2];
node->A = vmap[GetVarFromAccessPtr(node->Aptr)];
node->B = vmap[GetVarFromAccessPtr(node->Bptr)];
node->C = vmap[GetVarFromAccessPtr(node->Cptr)];
node->trans_A = args[3].as<Bool>().value();
node->trans_B = args[4].as<Bool>().value();
node->M = args[5].as<IntImm>().value()->value;
node->N = args[6].as<IntImm>().value()->value;
node->K = args[7].as<IntImm>().value()->value;
node->policy = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
node->clear_accum = args[9].as<PrimExpr>().value();
node->stride_A = args[10].as<IntImm>().value()->value;
node->stride_B = args[11].as<IntImm>().value()->value;
node->offset_A = args[12].as<IntImm>().value()->value;
node->offset_B = args[13].as<IntImm>().value()->value;
ObjectPtr<GemmPyNode> node = tvm::ffi::make_object<GemmPyNode>();
node->aRegion_ = NormalizeToBufferRegion(args[0], vmap);
node->bRegion_ = NormalizeToBufferRegion(args[1], vmap);
node->cRegion_ = NormalizeToBufferRegion(args[2], vmap);
node->a_ = node->aRegion_->buffer;
node->b_ = node->bRegion_->buffer;
node->c_ = node->cRegion_->buffer;
node->transA_ = args[3].as<Bool>().value();
node->transB_ = args[4].as<Bool>().value();
node->m_ = args[5].as<IntImm>().value()->value;
node->n_ = args[6].as<IntImm>().value()->value;
node->k_ = args[7].as<IntImm>().value()->value;
node->policy_ = GemmWarpPolicy(args[8].as<IntImm>().value()->value);
node->clearAccum_ = args[9].as<PrimExpr>().value();
node->strideA_ = args[10].as<IntImm>().value()->value;
node->strideB_ = args[11].as<IntImm>().value()->value;
node->offsetA_ = args[12].as<IntImm>().value()->value;
node->offsetB_ = args[13].as<IntImm>().value()->value;
if (args.size() > 14) {
node->kPack = args[14].as<IntImm>().value()->value;
if (node->kPack != 1 && node->kPack != 2) {
node->kPack_ = args[14].as<IntImm>().value()->value;
if (node->kPack_ != 1 && node->kPack_ != 2) {
ICHECK(false) << "kPack must be 1 or 2";
}
}
if (args.size() > 15) {
node->wg_wait = args[15].as<IntImm>().value()->value;
node->wgWait_ = args[15].as<IntImm>().value()->value;
}
node->mbarPtr_ = args[16];
if (node->mbarPtr_.as<CallNode>()) {
node->mbar_ = vmap[GetVarFromAccessPtr(node->mbarPtr_)];
} else {
node->mbar_ = std::nullopt;
}
node->cCoords_ = Array<PrimExpr>(
{args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
data_ = std::move(node);
}
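// Positional layout consumed above (descriptive summary): args[0..2] are the
// A/B/C regions, args[3..4] the transpose flags, args[5..7] M/N/K, args[8] the
// warp policy, args[9] clear_accum, args[10..13] strides and offsets of A/B,
// args[14]/args[15] kPack and wg_wait when present, args[16] the mbarrier
// pointer, and args[17..18] the C coordinates.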
......@@ -88,20 +185,41 @@ GemmPy::GemmPy(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A Gemm operator that owns a copy of this node.
*/
TileOperator GemmPyNode::Clone() const {
auto op = make_object<GemmPyNode>(*this);
auto op = tvm::ffi::make_object<GemmPyNode>(*this);
return GemmPy(op);
}
GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const {
bool GemmPyNode::allowTcgen5Mma(Target target) const {
return TargetIsSm100(target) &&
((a_.scope() == "shared.dyn" || a_.scope() == "shared" ||
a_.scope() == "shared.tmem") &&
(b_.scope() == "shared.dyn" || b_.scope() == "shared") &&
c_.scope() == "shared.tmem") &&
GetTCGEN5MMAMeta(m_, n_, k_, a_->dtype, c_->dtype).first;
}
bool GemmPyNode::allowWgmma(int block_size, Target target) const {
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
int warp_size = TargetGetWarpSize(target);
int num_warps = block_size / warp_size;
bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
(num_warps % 4 == 0) && CheckWGMMA();
if (allow_wgmma) {
return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
TargetIsHopper(target) && (this->m_ >= 64) && (num_warps % 4 == 0) &&
checkWgmma();
}
GemmInst GemmPyNode::getGemmInst(int block_size, Target target) const {
bool allow_tcgen5mma = allowTcgen5Mma(target);
bool allow_wgmma = allowWgmma(block_size, target);
if (allow_tcgen5mma) {
return GemmInst::kTCGEN5MMA;
} else if (allow_wgmma) {
return GemmInst::kWGMMA;
} else if (TargetIsCDNA(target)) {
return GemmInst::kMFMA;
} else if (TargetIsCuda(target)) {
} else if (TargetIsVolta(target) || TargetIsAmpere(target) ||
TargetIsTuring(target) || TargetIsHopper(target) ||
TargetIsSm100(target)) {
return GemmInst::kMMA;
} else {
ICHECK(0) << "Unsupported target for gemm: " << target->str();
......@@ -140,51 +258,52 @@ GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const {
* @return true if WGMMA is supported for the current buffers, dtypes, and
* transpose/shape constraints; false otherwise.
*/
bool GemmPyNode::CheckWGMMA() const {
if (B.scope() != "shared.dyn" && B.scope() != "shared") {
bool GemmPyNode::checkWgmma() const {
if (b_.scope() != "shared.dyn" && b_.scope() != "shared") {
return false;
}
if (C->dtype == DataType::Float(16)) {
if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
return K % 16 == 0;
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
if (c_->dtype == DataType::Float(16)) {
if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
return k_ % 16 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else
return false;
} else if (C->dtype == DataType::Float(32)) {
if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
return K % 16 == 0;
else if (A->dtype == DataType::BFloat(16) &&
B->dtype == DataType::BFloat(16))
return K % 16 == 0;
else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
return (!trans_A) && trans_B && K % 8 == 0;
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
return (!trans_A) && trans_B && K % 32 == 0;
} else if (c_->dtype == DataType::Float(32)) {
if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
return k_ % 16 == 0;
else if (a_->dtype == DataType::BFloat(16) &&
b_->dtype == DataType::BFloat(16))
return k_ % 16 == 0;
else if (a_->dtype == DataType::Float(32) &&
b_->dtype == DataType::Float(32))
return (!transA_) && transB_ && k_ % 8 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else
return false;
} else if (C->dtype == DataType::Int(32)) {
if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
return (!trans_A) && trans_B && K % 32 == 0;
} else if (c_->dtype == DataType::Int(32)) {
if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::Int(8))
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype == DataType::Int(8) && b_->dtype == DataType::UInt(8))
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::Int(8))
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype == DataType::UInt(8) && b_->dtype == DataType::UInt(8))
return (!transA_) && transB_ && k_ % 32 == 0;
else
return false;
} else {
......@@ -208,8 +327,8 @@ bool GemmPyNode::CheckWGMMA() const {
*/
static int GetArchInt(Target target) {
int arch_int = 0;
auto s = target->GetAttr<String>("arch");
ICHECK(s.defined());
auto s = target->GetAttr<tvm::ffi::String>("arch");
ICHECK(s.has_value());
std::string arch = s.value();
if (arch.rfind("sm_", 0) == 0) {
arch_int = std::stoi(arch.substr(3));
......@@ -221,18 +340,19 @@ static int GetArchInt(Target target) {
Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto block_size = *as_const_int(T.thread_bounds->extent);
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
GemmInst gemm_inst = getGemmInst(block_size, T.target);
auto [warp_m, warp_n] =
policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst);
policy_->computeWarpPartition(m_, n_, block_size, T.target, gemm_inst);
if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) {
auto prim_func =
Downcast<PrimFunc>((*f)(GetRef<GemmPy>(this), T.layout_map, T.target,
T.thread_bounds, T.thread_var));
Downcast<PrimFunc>((*f)(tvm::ffi::GetRef<GemmPy>(this), T.layout_map,
T.target, T.thread_bounds, T.thread_var));
ICHECK(prim_func->attrs.defined());
auto global_symbol = prim_func->attrs.GetAttr<String>("global_symbol");
ICHECK(global_symbol.defined());
auto global_symbol =
prim_func->attrs.GetAttr<tvm::ffi::String>("global_symbol");
ICHECK(global_symbol.has_value());
if (prim_func->body.as<BlockRealizeNode>()) {
BlockRealize block_realize = Downcast<BlockRealize>(prim_func->body);
auto block = block_realize->block;
......@@ -265,7 +385,15 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T,
if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.infer_layout")) {
results = Downcast<LayoutMap>(
(*f)(GetRef<GemmPy>(this), T.target, T.thread_bounds));
(*f)(tvm::ffi::GetRef<GemmPy>(this), T.target, T.thread_bounds));
// Bind all fragment layouts with the provided thread range
for (auto kv : results) {
const Buffer &buf = kv.first;
const Layout &layout = kv.second;
if (auto frag = layout.as<Fragment>()) {
results.Set(buf, frag.value()->BindThreadRange(T.thread_bounds));
}
}
} else {
LOG(FATAL) << "No infer layout function found for gemm_py";
}
......@@ -279,15 +407,41 @@ TIR_REGISTER_TL_OP(GemmPy, gemm_py)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({ GemmPyNode::RegisterReflection(); });
TVM_FFI_STATIC_INIT_BLOCK() { GemmPyNode::RegisterReflection(); }
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.GemmPyGemmInst",
[](GemmPy gemm_py, int block_size, Target target) {
return gemm_py->GetGemmInst(block_size, target);
return gemm_py->getGemmInst(block_size, target);
});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def(
"tl.get_tcgen5_mma_meta",
[](int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
auto [success, meta] = GetTCGEN5MMAMeta(M, N, K, ab_dtype, c_dtype);
Array<Integer> result;
if (success) {
result.push_back(Integer(meta.atom_m));
result.push_back(Integer(meta.atom_n));
result.push_back(Integer(meta.atom_k));
}
return result;
});
});
refl::GlobalDef().def(
"tl.get_tcgen5_instr_desc",
[](int atom_m, int atom_n, int atom_k, DataType ab_dtype,
DataType c_dtype, bool a_is_k_major, bool b_is_k_major, int scale_in_a,
int scale_in_b) {
uint32_t desc = GetTCGEN5InstrDesc(atom_m, atom_n, atom_k, ab_dtype,
c_dtype, a_is_k_major, b_is_k_major,
scale_in_a, scale_in_b);
return Integer(static_cast<int64_t>(desc));
});
}
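// Minimal usage sketch (the M/N/K values and dtypes below are hypothetical):
// the globals registered above can be queried through the FFI in the same way
// as the other entry points in this file.
//
//   if (const auto f = tvm::ffi::Function::GetGlobal("tl.get_tcgen5_mma_meta")) {
//     auto meta = Downcast<Array<Integer>>(
//         (*f)(128, 256, 64, DataType::Float(16), DataType::Float(32)));
//     // meta is {atom_m, atom_n, atom_k} when the combination is supported,
//     // and an empty array otherwise.
//   }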
} // namespace tl
} // namespace tvm
......@@ -18,87 +18,54 @@ using namespace tir;
class GemmPyNode : public TileOperatorNode {
public:
bool CheckWGMMA() const;
tir::Buffer A, B, C;
// pointer to the A, B, C
PrimExpr Aptr, Bptr, Cptr;
bool trans_A, trans_B;
int M, N, K;
int stride_A, stride_B;
int offset_A, offset_B;
PrimExpr clear_accum = const_false();
bool checkWgmma() const;
bool allowTcgen5Mma(Target target) const;
bool allowWgmma(int block_size, Target target) const;
tir::Buffer a_, b_, c_;
// BufferRegion for A, B and C
BufferRegion aRegion_, bRegion_, cRegion_;
bool transA_, transB_;
int m_, n_, k_;
int strideA_, strideB_;
int offsetA_, offsetB_;
PrimExpr clearAccum_ = const_false();
PrimExpr mbarPtr_;
std::optional<tir::Buffer> mbar_; // mbar is optional, only used for TCGEN5MMA
Array<PrimExpr> cCoords_;
// For k_pack, see bitblas/tl/mfma_macro_generator.py::k_pack.
// Only enabled for CDNA MFMA instructions.
int kPack = 1;
int wg_wait = 0;
mutable GemmWarpPolicy policy;
int kPack_ = 1;
int wgWait_ = 0;
mutable GemmWarpPolicy policy_;
static constexpr const char *_type_key = "tl.GemmPy";
TVM_DECLARE_FINAL_OBJECT_INFO(GemmPyNode, TileOperatorNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmPy", GemmPyNode, TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GemmPyNode>()
.def_ro("A", &GemmPyNode::A)
.def_ro("B", &GemmPyNode::B)
.def_ro("C", &GemmPyNode::C)
.def_ro("Aptr", &GemmPyNode::Aptr)
.def_ro("Bptr", &GemmPyNode::Bptr)
.def_ro("Cptr", &GemmPyNode::Cptr)
.def_ro("trans_A", &GemmPyNode::trans_A)
.def_ro("trans_B", &GemmPyNode::trans_B)
.def_ro("M", &GemmPyNode::M)
.def_ro("N", &GemmPyNode::N)
.def_ro("K", &GemmPyNode::K)
.def_ro("stride_A", &GemmPyNode::stride_A)
.def_ro("stride_B", &GemmPyNode::stride_B)
.def_ro("offset_A", &GemmPyNode::offset_A)
.def_ro("offset_B", &GemmPyNode::offset_B)
.def_ro("clear_accum", &GemmPyNode::clear_accum)
.def_ro("kPack", &GemmPyNode::kPack)
.def_ro("wg_wait", &GemmPyNode::wg_wait)
.def_ro("policy", &GemmPyNode::policy);
.def_ro("a", &GemmPyNode::a_)
.def_ro("b", &GemmPyNode::b_)
.def_ro("c", &GemmPyNode::c_)
.def_ro("aRegion", &GemmPyNode::aRegion_)
.def_ro("bRegion", &GemmPyNode::bRegion_)
.def_ro("cRegion", &GemmPyNode::cRegion_)
.def_ro("transA", &GemmPyNode::transA_)
.def_ro("transB", &GemmPyNode::transB_)
.def_ro("m", &GemmPyNode::m_)
.def_ro("n", &GemmPyNode::n_)
.def_ro("k", &GemmPyNode::k_)
.def_ro("strideA", &GemmPyNode::strideA_)
.def_ro("strideB", &GemmPyNode::strideB_)
.def_ro("offsetA", &GemmPyNode::offsetA_)
.def_ro("offsetB", &GemmPyNode::offsetB_)
.def_ro("clearAccum", &GemmPyNode::clearAccum_)
.def_ro("mbarPtr", &GemmPyNode::mbarPtr_)
.def_ro("cCoords", &GemmPyNode::cCoords_)
.def_ro("kPack", &GemmPyNode::kPack_)
.def_ro("wgWait", &GemmPyNode::wgWait_)
.def_ro("policy", &GemmPyNode::policy_);
}
bool SEqualReduce(const GemmPyNode *other, SEqualReducer equal) const {
return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) &&
equal(Aptr, other->Aptr) && equal(Bptr, other->Bptr) &&
equal(Cptr, other->Cptr) && equal(trans_A, other->trans_A) &&
equal(trans_B, other->trans_B) && equal(M, other->M) &&
equal(N, other->N) && equal(K, other->K) &&
equal(stride_A, other->stride_A) &&
equal(stride_B, other->stride_B) &&
equal(offset_A, other->offset_A) &&
equal(offset_B, other->offset_B) &&
equal(clear_accum, other->clear_accum) &&
equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
equal(policy, other->policy);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(A);
hash_reduce(B);
hash_reduce(C);
hash_reduce(Aptr);
hash_reduce(Bptr);
hash_reduce(Cptr);
hash_reduce(trans_A);
hash_reduce(trans_B);
hash_reduce(M);
hash_reduce(N);
hash_reduce(K);
hash_reduce(stride_A);
hash_reduce(stride_B);
hash_reduce(offset_A);
hash_reduce(offset_B);
hash_reduce(clear_accum);
hash_reduce(kPack);
hash_reduce(wg_wait);
hash_reduce(policy);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
......@@ -106,7 +73,7 @@ public:
TileOperator Clone() const;
// Target GEMM instruction
GemmInst GetGemmInst(int block_size, Target target) const;
GemmInst getGemmInst(int block_size, Target target) const;
private:
mutable bool completed_ = false;
......@@ -114,7 +81,7 @@ private:
class GemmPy : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(GemmPy, TileOperator, GemmPyNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmPy, TileOperator, GemmPyNode);
TVM_DLL GemmPy(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};
......
......@@ -18,14 +18,14 @@
namespace tvm {
namespace tl {
std::pair<int, int> GemmSPWarpPolicyNode::ComputeWarpPartition(int M, int N,
std::pair<int, int> GemmSPWarpPolicyNode::computeWarpPartition(int M, int N,
int block_size,
Target target,
bool use_wgmma,
int bits) const {
int num_warps = block_size / TargetGetWarpSize(target);
auto [m_warp, n_warp] = GemmWarpPolicyNode::ComputeWarpPartition(
auto [m_warp, n_warp] = GemmWarpPolicyNode::computeWarpPartition(
M, N, block_size, target, use_wgmma ? GemmInst::kWGMMA : GemmInst::kMMA);
// Special handling for gemm_sp when the tiling size is not a multiple
......@@ -84,26 +84,26 @@ std::pair<int, int> GemmSPWarpPolicyNode::ComputeWarpPartition(int M, int N,
* @note An ICHECK failure is raised if a provided kPack is not 1 or 2.
*/
GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<GemmSPNode> node = make_object<GemmSPNode>();
node->A = vmap[GetVarFromAccessPtr(args[0])];
node->E = vmap[GetVarFromAccessPtr(args[1])];
node->B = vmap[GetVarFromAccessPtr(args[2])];
node->C = vmap[GetVarFromAccessPtr(args[3])];
node->trans_A = args[4].as<Bool>().value();
node->trans_B = args[5].as<Bool>().value();
node->M = args[6].as<IntImm>().value()->value;
node->N = args[7].as<IntImm>().value()->value;
node->K = args[8].as<IntImm>().value()->value;
node->policy = GemmSPWarpPolicy(args[9].as<IntImm>().value()->value);
node->clear_accum = args[10].as<Bool>().value();
ObjectPtr<GemmSPNode> node = tvm::ffi::make_object<GemmSPNode>();
node->a_ = vmap[GetVarFromAccessPtr(args[0])];
node->e_ = vmap[GetVarFromAccessPtr(args[1])];
node->b_ = vmap[GetVarFromAccessPtr(args[2])];
node->c_ = vmap[GetVarFromAccessPtr(args[3])];
node->transA_ = args[4].as<Bool>().value();
node->transB_ = args[5].as<Bool>().value();
node->m_ = args[6].as<IntImm>().value()->value;
node->n_ = args[7].as<IntImm>().value()->value;
node->k_ = args[8].as<IntImm>().value()->value;
node->policy_ = GemmSPWarpPolicy(args[9].as<IntImm>().value()->value);
node->clearAccum_ = args[10].as<Bool>().value();
if (args.size() > 11) {
node->kPack = args[11].as<IntImm>().value()->value;
if (node->kPack != 1 && node->kPack != 2) {
node->kPack_ = args[11].as<IntImm>().value()->value;
if (node->kPack_ != 1 && node->kPack_ != 2) {
ICHECK(false) << "kPack must be 1 or 2";
}
}
if (args.size() > 12) {
node->wg_wait = args[12].as<IntImm>().value()->value;
node->wgWait_ = args[12].as<IntImm>().value()->value;
}
data_ = std::move(node);
}
......@@ -118,7 +118,7 @@ GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
* @return TileOperator A TileOperator holding a cloned GemmSPNode.
*/
TileOperator GemmSPNode::Clone() const {
auto op = make_object<GemmSPNode>(*this);
auto op = tvm::ffi::make_object<GemmSPNode>(*this);
return GemmSP(op);
}
......@@ -144,37 +144,37 @@ Stmt GemmSPNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
int warp_size = 32;
auto block_size = *as_const_int(T.thread_bounds->extent);
bool maybe_wgmma = TargetIsHopper(T.target) && (this->M >= 64) &&
bool maybe_wgmma = TargetIsHopper(T.target) && (this->m_ >= 64) &&
(block_size / warp_size % 4 == 0);
auto [warp_m, warp_n] = policy->ComputeWarpPartition(
M, N, block_size, T.target, maybe_wgmma, A->dtype.bits());
auto [warp_m, warp_n] = policy_->computeWarpPartition(
m_, n_, block_size, T.target, maybe_wgmma, a_->dtype.bits());
std::stringstream ss;
std::string op_name = "tl::gemm_sp_ss";
ICHECK((A.scope() == "shared" || A.scope() == "shared.dyn") &&
(B.scope() == "shared" || B.scope() == "shared.dyn"))
<< "Only support shared.dyn scope for A and B, but received " << A.scope()
<< " and " << B.scope();
ICHECK((E.scope() == "shared" || E.scope() == "shared.dyn"))
ICHECK((a_.scope() == "shared" || a_.scope() == "shared.dyn") &&
(b_.scope() == "shared" || b_.scope() == "shared.dyn"))
<< "Only support shared.dyn scope for A and B, but received "
<< a_.scope() << " and " << b_.scope();
ICHECK((e_.scope() == "shared" || e_.scope() == "shared.dyn"))
<< "Only support shared.dyn scope for E as copy from smem to rmem are "
"delegated to cute implementation, found "
<< E.scope();
ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
<< e_.scope();
ss << op_name << "<" << m_ << ", " << n_ << ", " << k_ << ", ";
ss << warp_m << ", " << warp_n << ", ";
ss << trans_A << ", " << trans_B;
ss << ", " << clear_accum;
ss << transA_ << ", " << transB_;
ss << ", " << clearAccum_;
if (TargetIsHopper(T.target)) {
ss << ", " << (maybe_wgmma ? "true" : "false");
}
if (wg_wait != 0) {
ss << ", " << wg_wait;
if (wgWait_ != 0) {
ss << ", " << wgWait_;
}
ss << ">";
auto A_buffer = T.buffer_remap.count(A) ? T.buffer_remap[A] : A;
auto B_buffer = T.buffer_remap.count(B) ? T.buffer_remap[B] : B;
auto C_buffer = T.buffer_remap[C];
auto E_buffer = T.buffer_remap.count(E) ? T.buffer_remap[E] : E;
auto A_buffer = T.buffer_remap.count(a_) ? T.buffer_remap[a_] : a_;
auto B_buffer = T.buffer_remap.count(b_) ? T.buffer_remap[b_] : b_;
auto C_buffer = T.buffer_remap[c_];
auto E_buffer = T.buffer_remap.count(e_) ? T.buffer_remap[e_] : e_;
auto new_call =
Call(DataType::Handle(), tl::tl_gemm_sp(),
......@@ -217,59 +217,59 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
if (completed_)
return {};
LayoutMap results;
ICHECK(C.scope() == "local.fragment");
ICHECK(c_.scope() == "local.fragment");
auto thread_range = T.thread_bounds;
auto block_size = *as_const_int(thread_range->extent);
if (TargetIsHopper(T.target)) {
const int warp_size = 32;
constexpr int wgmma_m = 16 * 4;
bool maybe_wgmma =
(this->M >= wgmma_m) && (block_size / warp_size % 4 == 0);
auto [warp_m, warp_n] = policy->ComputeWarpPartition(
M, N, block_size, T.target, maybe_wgmma, A->dtype.bits());
auto fragment =
maybe_wgmma
? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
C->dtype.bits())
: makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range));
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
int dim_A = A->shape.size();
const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
results.Set(A, makeGemmABLayoutHopper(mat_stride, mat_continuous,
mat_continuous, A->dtype.bits(),
trans_A ? 1 : 2));
(this->m_ >= wgmma_m) && (block_size / warp_size % 4 == 0);
auto [warp_m, warp_n] = policy_->computeWarpPartition(
m_, n_, block_size, T.target, maybe_wgmma, a_->dtype.bits());
auto fragment = maybe_wgmma
? makeGemmFragmentCHopper(m_, n_, m_ / warp_m,
n_ / warp_n, c_->dtype.bits())
: makeGemmFragmentC(m_, n_, m_ / warp_m, n_ / warp_n,
c_->dtype.bits());
results.Set(c_, fragment->BindThreadRange(thread_range));
if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
int dim_A = a_->shape.size();
const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
results.Set(a_, makeGemmABLayoutHopper(mat_stride, mat_continuous,
mat_continuous, a_->dtype.bits(),
transA_ ? 1 : 2));
} else {
ICHECK(false) << "Not implemented";
}
if (B.scope() == "shared" || B.scope() == "shared.dyn") {
int dim_B = B->shape.size();
const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
int dim_B = b_->shape.size();
const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
const int64_t continuity =
trans_B ? mat_continuous : mat_continuous / warp_n;
results.Set(B,
transB_ ? mat_continuous : mat_continuous / warp_n;
results.Set(b_,
makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
B->dtype.bits(), trans_B ? 2 : 1));
b_->dtype.bits(), transB_ ? 2 : 1));
} else {
ICHECK(false) << "WGMMA only support B in shared.";
}
} else if (TargetIsAmpere(T.target)) {
auto [warp_m, warp_n] = policy->ComputeWarpPartition(
M, N, block_size, T.target, false, A->dtype.bits());
auto fragment =
makeGemmSparseFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range));
auto [warp_m, warp_n] = policy_->computeWarpPartition(
m_, n_, block_size, T.target, false, a_->dtype.bits());
auto fragment = makeGemmSparseFragmentC(m_, n_, m_ / warp_m, n_ / warp_n,
c_->dtype.bits());
results.Set(c_, fragment->BindThreadRange(thread_range));
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
int dim_A = A->shape.size();
const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
results.Set(A, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
A->dtype.bits()));
} else if (A.scope() == "local.fragment") {
if (a_.scope() == "shared" || a_.scope() == "shared.dyn") {
int dim_A = a_->shape.size();
const int64_t mat_stride = *as_const_int(a_->shape[dim_A - 2]);
const int64_t mat_continuous = *as_const_int(a_->shape[dim_A - 1]);
results.Set(a_, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
a_->dtype.bits()));
} else if (a_.scope() == "local.fragment") {
// auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
// A->dtype.bits(), trans_A);
// results.Set(A, fragment->BindThreadRange(thread_range));
......@@ -277,13 +277,13 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
} else {
ICHECK(0);
}
if (B.scope() == "shared" || B.scope() == "shared.dyn") {
int dim_B = B->shape.size();
const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
results.Set(B, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
B->dtype.bits()));
} else if (B.scope() == "local.fragment") {
if (b_.scope() == "shared" || b_.scope() == "shared.dyn") {
int dim_B = b_->shape.size();
const int64_t mat_stride = *as_const_int(b_->shape[dim_B - 2]);
const int64_t mat_continuous = *as_const_int(b_->shape[dim_B - 1]);
results.Set(b_, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
b_->dtype.bits()));
} else if (b_.scope() == "local.fragment") {
// auto fragment =
// makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
// results.Set(B, fragment->BindThreadRange(thread_range));
......@@ -303,7 +303,7 @@ TIR_REGISTER_TL_OP(GemmSP, gemm_sp)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK({ GemmSPNode::RegisterReflection(); });
TVM_FFI_STATIC_INIT_BLOCK() { GemmSPNode::RegisterReflection(); }
} // namespace tl
} // namespace tvm
......@@ -18,30 +18,32 @@ using namespace tir;
class GemmSPWarpPolicyNode : public GemmWarpPolicyNode {
public:
std::pair<int, int> ComputeWarpPartition(int M, int N, int block_size,
std::pair<int, int> computeWarpPartition(int M, int N, int block_size,
Target target, bool use_wgmma,
int bits) const;
TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmSPWarpPolicy", GemmSPWarpPolicyNode,
GemmWarpPolicyNode);
};
class GemmSPWarpPolicy : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(GemmSPWarpPolicy, ObjectRef,
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSPWarpPolicy, ObjectRef,
GemmSPWarpPolicyNode);
explicit GemmSPWarpPolicy(GemmWarpPolicyType policy_type) {
auto node = make_object<GemmSPWarpPolicyNode>();
auto node = tvm::ffi::make_object<GemmSPWarpPolicyNode>();
node->policy_type = (int)policy_type;
data_ = std::move(node);
}
explicit GemmSPWarpPolicy(int policy_type) {
auto node = make_object<GemmSPWarpPolicyNode>();
auto node = tvm::ffi::make_object<GemmSPWarpPolicyNode>();
node->policy_type = policy_type;
data_ = std::move(node);
}
explicit GemmSPWarpPolicy(int m_warp, int n_warp) {
auto node = make_object<GemmSPWarpPolicyNode>();
auto node = tvm::ffi::make_object<GemmSPWarpPolicyNode>();
node->m_warp = m_warp;
node->n_warp = n_warp;
node->policy_type = (int)GemmWarpPolicyType::kFree;
......@@ -51,19 +53,18 @@ public:
class GemmSPNode : public TileOperatorNode {
public:
tir::Buffer A, B, C, E;
bool trans_A, trans_B;
int M, N, K;
bool clear_accum = false;
tir::Buffer a_, b_, c_, e_;
bool transA_, transB_;
int m_, n_, k_;
bool clearAccum_ = false;
// For k_pack, see bitblas/tl/mfma_macro_generator.py::k_pack.
// Only enabled for CDNA MFMA instructions.
int kPack = 1;
int wg_wait = 0;
int kPack_ = 1;
int wgWait_ = 0;
mutable GemmSPWarpPolicy policy;
mutable GemmSPWarpPolicy policy_;
static constexpr const char *_type_key = "tl.GemmSP";
TVM_DECLARE_FINAL_OBJECT_INFO(GemmSPNode, TileOperatorNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmSP", GemmSPNode, TileOperatorNode);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
LayoutMap InferLayout(const LayoutInferArgs &T,
InferLevel level) const override;
......@@ -73,44 +74,19 @@ public:
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GemmSPNode>()
.def_ro("policy", &GemmSPNode::policy)
.def_ro("A", &GemmSPNode::A)
.def_ro("B", &GemmSPNode::B)
.def_ro("C", &GemmSPNode::C)
.def_ro("E", &GemmSPNode::E)
.def_ro("trans_A", &GemmSPNode::trans_A)
.def_ro("trans_B", &GemmSPNode::trans_B)
.def_ro("M", &GemmSPNode::M)
.def_ro("N", &GemmSPNode::N)
.def_ro("K", &GemmSPNode::K)
.def_ro("clear_accum", &GemmSPNode::clear_accum)
.def_ro("kPack", &GemmSPNode::kPack)
.def_ro("wg_wait", &GemmSPNode::wg_wait);
}
bool SEqualReduce(const GemmSPNode *other, SEqualReducer equal) const {
return equal(A, other->A) && equal(B, other->B) && equal(C, other->C) &&
equal(E, other->E) && equal(trans_A, other->trans_A) &&
equal(trans_B, other->trans_B) && equal(M, other->M) &&
equal(N, other->N) && equal(K, other->K) &&
equal(clear_accum, other->clear_accum) &&
equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(policy);
hash_reduce(A);
hash_reduce(B);
hash_reduce(C);
hash_reduce(E);
hash_reduce(trans_A);
hash_reduce(trans_B);
hash_reduce(M);
hash_reduce(N);
hash_reduce(K);
hash_reduce(clear_accum);
hash_reduce(kPack);
hash_reduce(wg_wait);
.def_ro("policy", &GemmSPNode::policy_)
.def_ro("a", &GemmSPNode::a_)
.def_ro("b", &GemmSPNode::b_)
.def_ro("c", &GemmSPNode::c_)
.def_ro("e", &GemmSPNode::e_)
.def_ro("transA", &GemmSPNode::transA_)
.def_ro("transB", &GemmSPNode::transB_)
.def_ro("m", &GemmSPNode::m_)
.def_ro("n", &GemmSPNode::n_)
.def_ro("k", &GemmSPNode::k_)
.def_ro("clearAccum", &GemmSPNode::clearAccum_)
.def_ro("kPack", &GemmSPNode::kPack_)
.def_ro("wgWait", &GemmSPNode::wgWait_);
}
private:
......@@ -119,7 +95,7 @@ private:
class GemmSP : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(GemmSP, TileOperator, GemmSPNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSP, TileOperator, GemmSPNode);
TVM_DLL GemmSP(Array<PrimExpr> args, BufferMap vmap);
static const Op &Get();
};
......
......@@ -9,6 +9,8 @@
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include "../support/ffi_aliases.h"
namespace tvm {
namespace tl {
using namespace tir;
......
......@@ -9,6 +9,8 @@
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include "../support/ffi_aliases.h"
namespace tvm {
namespace tl {
using namespace tir;
......@@ -33,5 +35,31 @@ TVM_REGISTER_OP("tl.pow_of_int")
.set_attr<TScriptPrinterName>("TScriptPrinterName", "pow_of_int")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", pow_of_int_op);
PrimExpr infinity_op(PrimExpr args) {
const CallNode *call = args.as<CallNode>();
CHECK(call != nullptr);
const DataType &dtype = call->dtype;
ICHECK_EQ(dtype.lanes(), 1);
// NOTE(wt): Codegen for PrintConst:Inf will handle this based on dtype
if (dtype.is_float()) {
if (dtype.bits() == 64 || dtype.bits() == 32 || dtype.bits() == 16) {
return FloatImm(dtype, std::numeric_limits<float>::infinity(),
call->span);
}
} else if (dtype.is_bfloat16()) {
return FloatImm(dtype, std::numeric_limits<float>::infinity(), call->span);
}
LOG(FATAL) << "Cannot decide infinity for type " << dtype;
throw; // Unreachable, keeps compiler happy
}
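// For example, a call typed as float16, float32, float64, or bfloat16 lowers
// to a FloatImm holding +inf in that dtype; every other dtype is rejected via
// LOG(FATAL).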
TVM_REGISTER_OP("tl.infinity")
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure))
.set_attr<TScriptPrinterName>("TScriptPrinterName", "infinity")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", infinity_op);
} // namespace tl
} // namespace tvm
......@@ -55,7 +55,7 @@ TileOperator ParseOperator(Call call, BufferMap vmap) {
TileOperator ParseOperator(Stmt stmt, BufferMap vmap) {
if (stmt.as<Evaluate>() && stmt.as<EvaluateNode>()->value.as<CallNode>()) {
auto call = stmt.as<EvaluateNode>()->value.as<CallNode>();
return ParseOperator(GetRef<Call>(call), vmap);
return ParseOperator(tvm::ffi::GetRef<Call>(call), vmap);
}
return TileOperator();
}
......@@ -77,7 +77,7 @@ Var GetVarFromAccessPtr(const PrimExpr &expr) {
ICHECK(call->op.same_as(builtin::tvm_access_ptr()));
auto var = call->args[1].as<VarNode>();
ICHECK(var);
return GetRef<Var>(var);
return tvm::ffi::GetRef<Var>(var);
}
} // namespace tl
......
......@@ -39,7 +39,6 @@ struct LowerArgs {
AddWorkspaceCallback AddWorkspace;
LayoutMap layout_map;
Map<Buffer, Buffer> buffer_remap;
Array<Var> buffer_var_gemm;
};
struct LayoutInferArgs {
......@@ -62,14 +61,13 @@ public:
virtual TileOperator Clone() const = 0;
static constexpr const char *_type_key = "tl.TileOperator";
TVM_DECLARE_BASE_OBJECT_INFO(TileOperatorNode, Object);
TVM_FFI_DECLARE_OBJECT_INFO("tl.TileOperator", TileOperatorNode, Object);
};
class TileOperator : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TileOperator, ObjectRef, TileOperatorNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TileOperator, ObjectRef,
TileOperatorNode);
};
Var GetVarFromAccessPtr(const PrimExpr &expr);
......
......@@ -178,7 +178,7 @@ ParallelOpNode::ParallelOpNode(For root) : root_(root), V(this) {
}
TileOperator ParallelOpNode::Clone() const {
auto op = make_object<ParallelOpNode>(*this);
auto op = tvm::ffi::make_object<ParallelOpNode>(*this);
return ParallelOp(op);
}
......@@ -620,11 +620,37 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
if (IsCommonAccessIndice(buffer)) {
return loop_layout_;
}
// Prefer a simple path: if the original 2D indices already form a bijective
// map, invert them directly and avoid introducing a synthetic replicate
// dimension.
{
auto res2d =
arith::DetectIterMap(indice_map_[buffer], ToVMap(loop_vars_), 1,
arith::IterMapLevel::Bijective,
const_cast<arith::Analyzer *>(&analyzer_));
if (res2d->errors.empty()) {
Layout ind_inv2d = Layout(loop_vars_, indice_map_[buffer])->Inverse();
PrimExpr indice_rep_extent = 1;
PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
Array<PrimExpr> fwd2;
for (size_t i = 0; i < buffer->shape.size(); i++) {
fwd2.push_back(InputPlaceholder(i));
}
PrimExpr thd_b2 =
loop_layout_->ForwardThread(ind_inv2d->Forward(fwd2), std::nullopt);
return Fragment(buffer->shape, {}, thd_b2, dest_buffer_rep_extent,
std::nullopt)
->CondenseReplicateVar();
}
}
// Otherwise, infer an extra flattened iterator that captures truly-unused
// pieces of the loop space (if any), then try inversion with it.
PrimExpr rep_b = MakeFlattenedExpression(
DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
auto bijective_indice = indice_map_[buffer];
bijective_indice.push_back(rep_b);
Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();
PrimExpr indice_rep_extent =
ind_inv->InputShape().back(); // this is the size of rep_b
PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
......@@ -642,7 +668,7 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
->CondenseReplicateVar();
}
TVM_FFI_STATIC_INIT_BLOCK({ ParallelOpNode::RegisterReflection(); });
TVM_FFI_STATIC_INIT_BLOCK() { ParallelOpNode::RegisterReflection(); }
} // namespace tl
} // namespace tvm
......@@ -66,8 +66,8 @@ public:
mutable Optional<PrimExpr> predicate_;
// Type key for TVM object system.
static constexpr const char *_type_key = "tl.ParallelOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ParallelOpNode, TileOperatorNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ParallelOp", ParallelOpNode,
TileOperatorNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
......@@ -77,20 +77,6 @@ public:
.def_ro("predicate", &ParallelOpNode::predicate_);
}
bool SEqualReduce(const ParallelOpNode *other, SEqualReducer equal) const {
return equal(root_, other->root_) &&
equal(loop_layout_, other->loop_layout_) &&
equal(predicate_, other->predicate_);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(root_);
hash_reduce(loop_layout_);
hash_reduce(predicate_);
}
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
// Construct from a root For loop.
ParallelOpNode(For root);
......@@ -150,10 +136,11 @@ private:
class ParallelOp : public TileOperator {
public:
TVM_DEFINE_OBJECT_REF_METHODS(ParallelOp, TileOperator, ParallelOpNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ParallelOp, TileOperator,
ParallelOpNode);
ParallelOp(const For &root) {
auto op = make_object<ParallelOpNode>(root);
auto op = tvm::ffi::make_object<ParallelOpNode>(root);
data_ = std::move(op);
}
};
......
......@@ -14,6 +14,7 @@
#include "../op/parallel.h"
#include "../target/utils.h"
#include "../transform/loop_partition.h"
#include "region.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
......@@ -21,10 +22,54 @@ namespace tl {
using namespace tir;
// Normalize an argument (BufferRegion/BufferLoad/tl.region)
// to BufferRegion so Reduce can uniformly consume regions.
static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
const BufferMap &vmap) {
// Case 1: Already a BufferRegion
if (arg->IsInstance<BufferRegionNode>()) {
return Downcast<BufferRegion>(arg);
}
// Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else
// extent=1)
if (const auto *load = arg.as<BufferLoadNode>()) {
Array<Range> ranges;
for (const PrimExpr &index : load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
ICHECK(ramp->stride.as<IntImmNode>()) << "Ramp stride must be IntImm";
ICHECK_EQ(ramp->stride.as<IntImmNode>()->value, 1)
<< "Only stride-1 Ramp is supported in region conversion";
ICHECK(ramp->lanes.as<IntImmNode>())
<< "Scalable vector lanes not supported in region conversion";
ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else {
ranges.push_back(Range::FromMinExtent(index, 1));
}
}
return BufferRegion(load->buffer, ranges);
}
// Case 3: Call nodes (only tl.region)
if (const auto *call = arg.as<CallNode>()) {
// tl.region(...) — reconstruct via RegionOp
if (call->op.same_as(RegionOp::Get())) {
RegionOp region(call->args, vmap);
return BufferRegion(region->GetBuffer(), region->GetRanges());
}
}
LOG(FATAL) << "Unsupported argument for BufferRegion in reduce: " << arg;
throw; // Unreachable
}
ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<ReduceOpNode> node = make_object<ReduceOpNode>();
node->src = vmap[GetVarFromAccessPtr(args[0])];
node->dst = vmap[GetVarFromAccessPtr(args[1])];
ObjectPtr<ReduceOpNode> node = tvm::ffi::make_object<ReduceOpNode>();
// Accept BufferRegion/BufferLoad/tl.region for src/dst
node->srcRegion_ = NormalizeToBufferRegion(args[0], vmap);
node->dstRegion_ = NormalizeToBufferRegion(args[1], vmap);
node->src = node->srcRegion_->buffer;
node->dst = node->dstRegion_->buffer;
std::string reduce_type = args[2].as<StringImm>().value()->value;
node->dim = args[3].as<IntImm>().value()->value;
node->type = ReduceType(reduce_type);
......@@ -33,12 +78,12 @@ ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
}
TileOperator ReduceOpNode::Clone() const {
auto op = make_object<ReduceOpNode>(*this);
auto op = tvm::ffi::make_object<ReduceOpNode>(*this);
return ReduceOp(op);
}
TileOperator CumSumOpNode::Clone() const {
auto op = make_object<CumSumOpNode>(*this);
auto op = tvm::ffi::make_object<CumSumOpNode>(*this);
return CumSumOp(op);
}
......@@ -85,6 +130,7 @@ PrimExpr ReduceOpNode::MakeInitValue() const {
return make_zero(dst->dtype);
} else {
LOG(FATAL) << "Unsupported reduce type: " << type->type;
return PrimExpr();
}
}
......@@ -103,7 +149,7 @@ PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &lhs,
} else if (type->isMin()) {
return Min(lhs, rhs);
} else if (type->isAbsMax()) {
return Max(Max(lhs, rhs), -Min(lhs, rhs));
return Max(tvm::abs(lhs), tvm::abs(rhs));
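// Equivalent to the previous Max(Max(lhs, rhs), -Min(lhs, rhs)) form, since
// max(|lhs|, |rhs|) == max(max(lhs, rhs), -min(lhs, rhs)); using tvm::abs
// keeps the generated expression simpler.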
} else if (type->isBitAnd()) {
return lhs & rhs;
} else if (type->isBitOr()) {
......@@ -359,70 +405,6 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return body;
}
auto is_shared_scope = [](const std::string &scope) {
return scope == "shared" || scope == "shared.dyn";
};
if (is_shared_scope(src_scope) && is_shared_scope(dst_scope)) {
Buffer src_buffer = get_buffer(this->src);
Buffer dst_buffer = get_buffer(this->dst);
size_t src_dim = src_buffer->shape.size();
size_t dst_dim = dst_buffer->shape.size();
bool is_1d_reduce = (src_dim == dst_dim && dst_dim == 1);
if (!is_1d_reduce) {
ICHECK_EQ(src_dim, dst_dim + 1) << "Reduce dimension mismatch.";
} else {
ICHECK_EQ(dst_dim, 1U) << "Expect scalar layout for 1D reduce.";
}
auto thread_extent = as_const_int(T.thread_bounds->extent);
ICHECK(thread_extent)
<< "Shared-memory reduce requires static thread extent.";
int threads = *thread_extent;
if (TargetIsCuda(T.target)) {
ICHECK_EQ(threads % 32, 0)
<< "Shared reduce expects blockDim.x to be a multiple of 32 on CUDA.";
} else if (TargetIsRocm(T.target)) {
ICHECK_EQ(threads % 64, 0)
<< "Shared reduce expects blockDim.x to be a multiple of 64 on HIP.";
}
bool use_abs = this->type->isAbsSum() || this->type->isAbsMax();
bool need_accumulate =
(!this->clear) && (this->type->isSum() || this->type->isAbsSum() ||
this->type->isBitAnd() || this->type->isBitOr() ||
this->type->isBitXor());
PrimExpr reduce_extent = src_buffer->shape[this->dim];
PrimExpr tail_extent = make_const(DataType::Int(32), 1);
for (size_t i = this->dim + 1; i < src_dim; ++i) {
tail_extent = analyzer->Simplify(tail_extent * src_buffer->shape[i]);
}
PrimExpr total_dest = make_const(DataType::Int(32), 1);
for (size_t i = 0; i < dst_dim; ++i) {
total_dest = analyzer->Simplify(total_dest * dst_buffer->shape[i]);
}
std::stringstream ss;
std::string reducer = this->MakeCodegenReducer();
ss << "tl::SharedReduceWarp<" << reducer << ", " << threads << ", "
<< (use_abs ? "true" : "false") << ", "
<< (need_accumulate ? "true" : "false") << ">::run";
Array<PrimExpr> call_args = {StringImm(ss.str()),
src_buffer.access_ptr(1),
dst_buffer.access_ptr(3),
cast(DataType::Int(32), total_dest),
cast(DataType::Int(32), reduce_extent),
cast(DataType::Int(32), tail_extent),
this->MakeInitValue()};
return Evaluate(Call(dst_buffer->dtype, builtin::call_extern(), call_args));
}
LOG(FATAL) << "Reduce for buffers in scope (" << src_scope << ", "
<< dst_scope << ") is not implemented.";
return Stmt();
......@@ -432,6 +414,7 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (level >= InferLevel::kStrict)
return {};
if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
T.layout_map.count(src)) {
auto src_layout = T.layout_map[src].as<Fragment>().value();
......@@ -452,10 +435,40 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
}
auto thd = src_layout->ForwardThread(
fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
// Ensure the thread count is divisible by the replicate extent.
// Otherwise, we cannot infer a valid fragment<->fragment layout.
{
arith::Analyzer analyzer;
PrimExpr num_threads = T.thread_bounds->extent;
// Although dest_buffer_rep_extent may later be compressed by
// CondenseReplicateVar, check divisibility here so that an incompatible
// thread count / replicate extent combination fails early with a clear
// error.
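// For example, with num_threads = 128, a replicate extent of 2 or 4 passes
// this check, while an extent of 3 fails (neither 128 % 3 == 0 nor
// 3 % 128 == 0 holds) and the ICHECK below aborts with a diagnostic.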
if (!analyzer.CanProve(FloorMod(num_threads, dest_buffer_rep_extent) ==
0) &&
!analyzer.CanProve(FloorMod(dest_buffer_rep_extent, num_threads) ==
0)) {
ICHECK(false) << "ReduceOp fragment layout inference failed: "
"num_threads % replicate_extent != 0. "
<< "This mapping requires the block's thread count to be "
"divisible by the "
<< "replicate extent. "
<< "Try one of: (1) choose a thread block size divisible "
"by replicate_extent; "
<< "(2) pick a different reduce dimension or adjust the "
"source fragment layout; "
<< "Details: num_threads=" << num_threads
<< ", replicate_extent=" << indice_rep_extent
<< ", src=" << src << ", dst=" << dst;
}
}
Fragment dst_layout =
Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt)
->CondenseReplicateVar()
->BindThreadRange(T.thread_bounds);
if (!T.layout_map.count(dst))
return {{dst, dst_layout}};
else {
......@@ -512,7 +525,7 @@ CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
/// - dim: dimension to cumsum
/// - reverse: whether to cumsum in reverse order
CHECK_EQ(args.size(), 4);
ObjectPtr<CumSumOpNode> node = make_object<CumSumOpNode>();
ObjectPtr<CumSumOpNode> node = tvm::ffi::make_object<CumSumOpNode>();
node->src = vmap[GetVarFromAccessPtr(args[0])];
node->dst = vmap[GetVarFromAccessPtr(args[1])];
node->dim = args[2].as<IntImm>().value()->value;
......@@ -567,5 +580,12 @@ TIR_REGISTER_TL_OP(CumSumOp, cumsum)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TVM_FFI_STATIC_INIT_BLOCK() {
ReduceOpNode::RegisterReflection();
CumSumOpNode::RegisterReflection();
ReduceTypeNode::RegisterReflection();
}
} // namespace tl
} // namespace tvm