"vscode:/vscode.git/clone" did not exist on "e2a7f03b80fc0e9e6a6f36acb43776509486a6d4"
Unverified Commit 667632cc authored by guchaoyang's avatar guchaoyang Committed by GitHub
Browse files

Merge branch 'main' into dcu

parents d6dd2ddf a874e4e8
...@@ -23,6 +23,14 @@ public: ...@@ -23,6 +23,14 @@ public:
int bits) const; int bits) const;
TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmSPWarpPolicy", GemmSPWarpPolicyNode, TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmSPWarpPolicy", GemmSPWarpPolicyNode,
GemmWarpPolicyNode); GemmWarpPolicyNode);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GemmSPWarpPolicyNode>()
.def_ro("policy_type", &GemmSPWarpPolicyNode::policy_type)
.def_ro("m_warp", &GemmSPWarpPolicyNode::m_warp)
.def_ro("n_warp", &GemmSPWarpPolicyNode::n_warp);
}
}; };
class GemmSPWarpPolicy : public ObjectRef { class GemmSPWarpPolicy : public ObjectRef {
...@@ -53,6 +61,7 @@ public: ...@@ -53,6 +61,7 @@ public:
class GemmSPNode : public TileOperatorNode { class GemmSPNode : public TileOperatorNode {
public: public:
BufferRegion aRegion_, bRegion_, cRegion_, eRegion_;
tir::Buffer a_, b_, c_, e_; tir::Buffer a_, b_, c_, e_;
bool transA_, transB_; bool transA_, transB_;
int m_, n_, k_; int m_, n_, k_;
...@@ -75,6 +84,10 @@ public: ...@@ -75,6 +84,10 @@ public:
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GemmSPNode>() refl::ObjectDef<GemmSPNode>()
.def_ro("policy", &GemmSPNode::policy_) .def_ro("policy", &GemmSPNode::policy_)
.def_ro("aRegion", &GemmSPNode::aRegion_)
.def_ro("bRegion", &GemmSPNode::bRegion_)
.def_ro("cRegion", &GemmSPNode::cRegion_)
.def_ro("eRegion", &GemmSPNode::eRegion_)
.def_ro("a", &GemmSPNode::a_) .def_ro("a", &GemmSPNode::a_)
.def_ro("b", &GemmSPNode::b_) .def_ro("b", &GemmSPNode::b_)
.def_ro("c", &GemmSPNode::c_) .def_ro("c", &GemmSPNode::c_)
...@@ -96,7 +109,7 @@ private: ...@@ -96,7 +109,7 @@ private:
class GemmSP : public TileOperator { class GemmSP : public TileOperator {
public: public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSP, TileOperator, GemmSPNode); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSP, TileOperator, GemmSPNode);
TVM_DLL GemmSP(Array<PrimExpr> args, BufferMap vmap); TVM_DLL GemmSP(Array<PrimExpr> args);
static const Op &Get(); static const Op &Get();
}; };
......
/*!
* \file tl/op/gemm_sp_py.cc
* \brief Implementation of Sparse General Matrix Multiplication (GEMM_SP)
* operators
*/
#include "gemm_sp_py.h"
#include "utils.h"
#include "builtin.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>
#include "../target/utils.h"
#include "tvm/ffi/string.h"
namespace tvm {
namespace tl {
using namespace tir;
/**
 * @brief Construct a GemmSPPy operator from serialized TL arguments.
 *
 * This constructor deserializes operator parameters from `args`, populating
 * an internal GemmSPPyNode with:
 * - BufferRegions for A, E, B, C (via NormalizeToBufferRegion) and the
 *   corresponding Buffer objects extracted from those regions,
 * - transpose flags for A, B and the sparsity-metadata buffer E,
 * - matrix dimensions M, N, K,
 * - warp allocation policy and the clear_accum expression,
 * - strides and memory offsets for A and B,
 * - optional kPack (must be 1 or 2) and optional wg_wait.
 *
 * The populated GemmSPPyNode is stored into the wrapper's internal `data_`.
 *
 * @param args Positional serialized arguments produced by the TL frontend:
 *             expected layout is:
 *             [A, E, B, C (region-like exprs), trans_A (Bool), trans_B (Bool),
 *              trans_E (Bool), M (Int), N (Int), K (Int), policy (Int),
 *              clear_accum (PrimExpr), stride_A (Int), stride_B (Int),
 *              offset_A (Int), offset_B (Int),
 *              (optional) kPack (Int), (optional) wg_wait (Int)]
 *
 * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
 *       fails with an ICHECK (runtime assertion). No other validation is
 *       performed here.
 */
GemmSPPy::GemmSPPy(Array<PrimExpr> args) {
  ObjectPtr<GemmSPPyNode> node = tvm::ffi::make_object<GemmSPPyNode>();
  // Normalize the four operand arguments into BufferRegions, then cache the
  // underlying Buffers for convenient access.
  node->aRegion_ = NormalizeToBufferRegion(args[0]);
  node->eRegion_ = NormalizeToBufferRegion(args[1]);
  node->bRegion_ = NormalizeToBufferRegion(args[2]);
  node->cRegion_ = NormalizeToBufferRegion(args[3]);
  node->A = node->aRegion_->buffer;
  node->E = node->eRegion_->buffer;
  node->B = node->bRegion_->buffer;
  node->C = node->cRegion_->buffer;
  node->trans_A = args[4].as<Bool>().value();
  node->trans_B = args[5].as<Bool>().value();
  node->trans_E = args[6].as<Bool>().value();
  node->M = args[7].as<IntImm>().value()->value;
  node->N = args[8].as<IntImm>().value()->value;
  node->K = args[9].as<IntImm>().value()->value;
  node->policy = GemmWarpPolicy(args[10].as<IntImm>().value()->value);
  node->clear_accum = args[11].as<PrimExpr>().value();
  node->stride_A = args[12].as<IntImm>().value()->value;
  node->stride_B = args[13].as<IntImm>().value()->value;
  node->offset_A = args[14].as<IntImm>().value()->value;
  node->offset_B = args[15].as<IntImm>().value()->value;
  // Optional trailing arguments.
  if (args.size() > 16) {
    node->kPack = args[16].as<IntImm>().value()->value;
    if (node->kPack != 1 && node->kPack != 2) {
      ICHECK(false) << "kPack must be 1 or 2";
    }
  }
  if (args.size() > 17) {
    node->wg_wait = args[17].as<IntImm>().value()->value;
  }
  data_ = std::move(node);
}
/**
 * @brief Deep-copy this node and wrap the copy as a TileOperator.
 *
 * Allocates a fresh GemmSPPyNode copy-constructed from *this and hands
 * ownership to a new GemmSPPy reference.
 *
 * @return TileOperator A GemmSPPy operator owning the copied node.
 */
TileOperator GemmSPPyNode::Clone() const {
  return GemmSPPy(tvm::ffi::make_object<GemmSPPyNode>(*this));
}
/**
 * @brief Select the GEMM instruction family for this launch configuration.
 *
 * Chooses WGMMA when the target is Hopper, M >= 64, the warp count is a
 * multiple of 4 (warp-group granularity) and CheckWGMMA() passes; otherwise
 * falls back to MFMA on CDNA targets and MMA on CUDA targets.
 *
 * @param block_size Number of threads in the thread block.
 * @param target Compilation target used to query warp size and arch family.
 * @return GemmInst The instruction family to lower with.
 */
GemmInst GemmSPPyNode::GetGemmInst(int block_size, Target target) const {
  int warp_size = TargetGetWarpSize(target);
  int num_warps = block_size / warp_size;
  bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
                     (num_warps % 4 == 0) && CheckWGMMA();
  if (allow_wgmma) {
    return GemmInst::kWGMMA;
  }
  if (TargetIsCDNA(target)) {
    return GemmInst::kMFMA;
  }
  if (TargetIsCuda(target)) {
    return GemmInst::kMMA;
  }
  ICHECK(0) << "Unsupported target for gemm: " << target->str();
  // ICHECK(0) above always throws; this return only exists so the compiler
  // does not warn about control reaching the end of a non-void function.
  return GemmInst::kMMA;
}
/**
 * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM.
 *
 * NOTE(review): WGMMA is not yet supported for the Python-lowered sparse
 * GEMM, so this currently always returns false. The commented-out logic
 * below sketches the intended Hopper checks — shared-memory placement of B,
 * supported (C, A, B) dtype combinations, transpose orientation
 * (!trans_A && trans_B where required), and K alignment (K % 8/16/32) —
 * and is kept as a reference for when support is enabled.
 *
 * @return false Always, until WGMMA support is implemented.
 */
bool GemmSPPyNode::CheckWGMMA() const {
  return false; // not supported yet
  // if (B.scope() != "shared.dyn" && B.scope() != "shared") {
  //   return false;
  // }
  // if (C->dtype == DataType::Float(16)) {
  //   if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
  //     return K % 16 == 0;
  //   else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else
  //     return false;
  // } else if (C->dtype == DataType::Float(32)) {
  //   if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
  //     return K % 16 == 0;
  //   else if (A->dtype == DataType::BFloat(16) &&
  //            B->dtype == DataType::BFloat(16))
  //     return K % 16 == 0;
  //   else if (A->dtype == DataType::Float(32) && B->dtype ==
  //   DataType::Float(32))
  //     return (!trans_A) && trans_B && K % 8 == 0;
  //   else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else
  //     return false;
  // } else if (C->dtype == DataType::Int(32)) {
  //   if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else
  //     return false;
  // } else {
  //   return false;
  // }
}
/**
 * @brief Parse and return the numeric GPU architecture from a Target's "arch"
 * attribute.
 *
 * Examines the target's "arch" string and, if it has the form "sm_<num>",
 * returns <num> as an int. Any other form — including a bare "sm_" prefix
 * with no digits after it — returns 0 instead of throwing.
 *
 * Preconditions: the target must have an "arch" attribute (checked via
 * ICHECK).
 *
 * @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if
 * the arch string does not match "sm_<num>".
 */
static int GetArchInt(Target target) {
  auto s = target->GetAttr<String>("arch");
  ICHECK(s.has_value());
  std::string arch = s.value();
  // std::stoi throws std::invalid_argument when no digits follow the prefix
  // (e.g. a malformed "sm_" arch string); verify the first character after
  // "sm_" is a digit before converting.
  if (arch.rfind("sm_", 0) == 0 && arch.size() > 3 && arch[3] >= '0' &&
      arch[3] <= '9') {
    return std::stoi(arch.substr(3));
  }
  return 0;
}
/**
 * @brief Lower this sparse GEMM to TIR by delegating to a Python-side hook.
 *
 * Looks up the globally registered function "tl.gemm_sp_py.lower" and calls
 * it with this operator, the target, the thread bounds and the thread var;
 * the hook returns a PrimFunc whose body is then rewrapped in a BlockRealize
 * named after the PrimFunc's "global_symbol" attribute. If the hook is not
 * registered, lowering aborts with a fatal error.
 *
 * @param T Lowering arguments (target, thread bounds/var, etc.).
 * @param analyzer Arithmetic analyzer (unused here).
 * @return Stmt The lowered BlockRealize statement.
 */
Stmt GemmSPPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  auto block_size = *as_const_int(T.thread_bounds->extent);
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  // NOTE(review): warp_m/warp_n are unused locally; presumably
  // computeWarpPartition caches the partition on the (mutable) policy for
  // later use — confirm against GemmWarpPolicy.
  auto [warp_m, warp_n] =
      policy->computeWarpPartition(M, N, block_size, T.target, gemm_inst);
  if (const auto f = ffi::Function::GetGlobal("tl.gemm_sp_py.lower")) {
    auto prim_func =
        Downcast<PrimFunc>((*f)(tvm::ffi::GetRef<GemmSPPy>(this), T.target,
                                T.thread_bounds, T.thread_var));
    ICHECK(prim_func->attrs.defined());
    auto global_symbol = prim_func->attrs.GetAttr<String>("global_symbol");
    ICHECK(global_symbol.has_value());
    // If the hook already produced a BlockRealize, just rename its block to
    // the global symbol and return it.
    if (prim_func->body.as<BlockRealizeNode>()) {
      BlockRealize block_realize = Downcast<BlockRealize>(prim_func->body);
      auto block = block_realize->block;
      {
        BlockNode *n = block.CopyOnWrite();
        n->name_hint = global_symbol.value();
      }
      return BlockRealize(block_realize->iter_values, block_realize->predicate,
                          block);
    }
    // Otherwise wrap the body with a fresh BlockRealize node.
    return BlockRealize(
        /*iter_values=*/Array<PrimExpr>(),
        /*predicate=*/const_true(),
        /*block=*/
        Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
              /*name_hint=*/global_symbol.value(), prim_func->body));
  } else {
    LOG(FATAL) << "No lower function found for gemm_sp_py";
  }
}
/**
 * @brief Infer buffer layouts by delegating to the Python-side hook.
 *
 * Runs at most once per node: after the first successful call `completed_`
 * is set and subsequent calls return an empty map. Requires the global
 * function "tl.gemm_sp_py.infer_layout" to be registered; aborts with a
 * fatal error otherwise.
 *
 * @param T Layout-inference arguments (target, thread bounds, etc.).
 * @param level Inference level (unused here).
 * @return LayoutMap The inferred layouts, or an empty map on repeat calls.
 */
LayoutMap GemmSPPyNode::InferLayout(const LayoutInferArgs &T,
                                    InferLevel level) const {
  if (completed_) {
    return {};
  }
  const auto infer_fn = ffi::Function::GetGlobal("tl.gemm_sp_py.infer_layout");
  if (!infer_fn) {
    LOG(FATAL) << "No infer layout function found for gemm_sp_py";
  }
  auto inferred = Downcast<LayoutMap>((*infer_fn)(
      tvm::ffi::GetRef<GemmSPPy>(this), T.target, T.thread_bounds));
  completed_ = true;
  return inferred;
}
// Register the tile operator "tl.tileop.gemm_sp_py": five inputs and an
// opaque call effect (the call may read/write memory arbitrarily).
TIR_REGISTER_TL_TILE_OP(GemmSPPy, gemm_sp_py)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

// Hook up FFI reflection for GemmSPPyNode at static-initialization time.
TVM_FFI_STATIC_INIT_BLOCK() { GemmSPPyNode::RegisterReflection(); }
} // namespace tl
} // namespace tvm
/*!
* \file tl/op/gemm_sp_py.h
* \brief Define gemm_sp_py operator.
*
*/
// TODO: @botbw: remove redundant code with gemm_py.h
#ifndef TVM_TL_OP_GEMM_SP_PY_H_
#define TVM_TL_OP_GEMM_SP_PY_H_
#include "gemm_sp.h"
#include "operator.h"
namespace tvm {
namespace tl {
using namespace tir;
/*!
 * \brief Operator node for the Python-lowered sparse GEMM (gemm_sp_py).
 *
 * Holds the operand buffers/regions and GEMM configuration; lowering and
 * layout inference are delegated to globally registered Python hooks (see
 * the .cc implementation).
 */
class GemmSPPyNode : public TileOperatorNode {
public:
  // Whether the Hopper WGMMA path may be used (see .cc; currently disabled).
  bool CheckWGMMA() const;
  // Operand buffers: A/B inputs, E sparsity metadata, C accumulator.
  tir::Buffer A, E, B, C;
  // pointer to the A, E, B, C
  BufferRegion aRegion_, eRegion_, bRegion_, cRegion_;
  // Transpose flags for A, B and the metadata buffer E.
  bool trans_A, trans_B, trans_E;
  // GEMM problem dimensions.
  int M, N, K;
  // Leading-dimension strides and element offsets for A and B.
  int stride_A, stride_B;
  int offset_A, offset_B;
  // Whether to zero the accumulator before the GEMM (as a PrimExpr).
  PrimExpr clear_accum = const_false();
  // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
  // only will be enabled under cdna mfma instructions
  int kPack = 1;
  // Warp-group wait count forwarded to the lowering hook.
  int wg_wait = 0;
  // use GemmWarp Policy here as the atom size are flexible in v2
  mutable GemmWarpPolicy policy;
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmSPPy", GemmSPPyNode,
                                    TileOperatorNode);
  // Expose all fields to the TVM FFI reflection machinery.
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<GemmSPPyNode>()
        .def_ro("A", &GemmSPPyNode::A)
        .def_ro("E", &GemmSPPyNode::E)
        .def_ro("B", &GemmSPPyNode::B)
        .def_ro("C", &GemmSPPyNode::C)
        .def_ro("aRegion", &GemmSPPyNode::aRegion_)
        .def_ro("eRegion", &GemmSPPyNode::eRegion_)
        .def_ro("bRegion", &GemmSPPyNode::bRegion_)
        .def_ro("cRegion", &GemmSPPyNode::cRegion_)
        .def_ro("trans_A", &GemmSPPyNode::trans_A)
        .def_ro("trans_B", &GemmSPPyNode::trans_B)
        .def_ro("trans_E", &GemmSPPyNode::trans_E)
        .def_ro("M", &GemmSPPyNode::M)
        .def_ro("N", &GemmSPPyNode::N)
        .def_ro("K", &GemmSPPyNode::K)
        .def_ro("stride_A", &GemmSPPyNode::stride_A)
        .def_ro("stride_B", &GemmSPPyNode::stride_B)
        .def_ro("offset_A", &GemmSPPyNode::offset_A)
        .def_ro("offset_B", &GemmSPPyNode::offset_B)
        .def_ro("clear_accum", &GemmSPPyNode::clear_accum)
        .def_ro("kPack", &GemmSPPyNode::kPack)
        .def_ro("wg_wait", &GemmSPPyNode::wg_wait)
        .def_ro("policy", &GemmSPPyNode::policy);
  }
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;
  TileOperator Clone() const;

private:
  // Target GEMM instruction
  GemmInst GetGemmInst(int block_size, Target target) const;
  // Set after the first InferLayout call so inference runs only once.
  mutable bool completed_ = false;
};
/*!
 * \brief Reference wrapper for GemmSPPyNode.
 */
class GemmSPPy : public TileOperator {
public:
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSPPy, TileOperator,
                                             GemmSPPyNode);
  // Deserializing constructor; see the .cc for the expected args layout.
  TVM_DLL GemmSPPy(Array<PrimExpr> args);
  // The registered "tl.tileop.gemm_sp_py" Op singleton.
  static const Op &Get();
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_GEMM_SP_PY_H_
\ No newline at end of file
...@@ -33,6 +33,7 @@ TVM_REGISTER_OP("tl.pow_of_int") ...@@ -33,6 +33,7 @@ TVM_REGISTER_OP("tl.pow_of_int")
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure)) Integer(CallEffectKind::kPure))
.set_attr<TScriptPrinterName>("TScriptPrinterName", "pow_of_int") .set_attr<TScriptPrinterName>("TScriptPrinterName", "pow_of_int")
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic", pow_of_int_op)
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", pow_of_int_op); .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", pow_of_int_op);
PrimExpr infinity_op(PrimExpr args) { PrimExpr infinity_op(PrimExpr args) {
...@@ -59,7 +60,8 @@ TVM_REGISTER_OP("tl.infinity") ...@@ -59,7 +60,8 @@ TVM_REGISTER_OP("tl.infinity")
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure)) Integer(CallEffectKind::kPure))
.set_attr<TScriptPrinterName>("TScriptPrinterName", "infinity") .set_attr<TScriptPrinterName>("TScriptPrinterName", "infinity")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", infinity_op); .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", infinity_op)
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic", infinity_op);
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -24,16 +24,14 @@ using namespace tir; ...@@ -24,16 +24,14 @@ using namespace tir;
* *
* @param call The TIR Call whose operator and arguments will be used to build * @param call The TIR Call whose operator and arguments will be used to build
* the TileOperator. * the TileOperator.
* @param vmap Buffer mapping passed through to the builder to resolve buffer
* references.
* @return TileOperator The constructed TileOperator, or a default (empty) * @return TileOperator The constructed TileOperator, or a default (empty)
* TileOperator if no builder exists. * TileOperator if no builder exists.
*/ */
TileOperator ParseOperator(Call call, BufferMap vmap) { TileOperator ParseOperator(Call call) {
auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder"); auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder");
Op op = call->op.as<Op>().value(); Op op = call->op.as<Op>().value();
if (op_map.count(op)) { if (op_map.count(op)) {
auto tile_op = op_map[op](call->args, vmap); auto tile_op = op_map[op](call->args);
ICHECK(tile_op.defined()); ICHECK(tile_op.defined());
return tile_op; return tile_op;
} }
...@@ -48,14 +46,13 @@ TileOperator ParseOperator(Call call, BufferMap vmap) { ...@@ -48,14 +46,13 @@ TileOperator ParseOperator(Call call, BufferMap vmap) {
* Otherwise returns a default-constructed (empty) TileOperator. * Otherwise returns a default-constructed (empty) TileOperator.
* *
* @param stmt TIR statement to inspect; expected to be an Evaluate of a Call. * @param stmt TIR statement to inspect; expected to be an Evaluate of a Call.
* @param vmap Mapping of buffer variables used when building the operator.
* @return TileOperator Parsed operator on success, or a default (empty) * @return TileOperator Parsed operator on success, or a default (empty)
* TileOperator if `stmt` is not an Evaluate(Call). * TileOperator if `stmt` is not an Evaluate(Call).
*/ */
TileOperator ParseOperator(Stmt stmt, BufferMap vmap) { TileOperator ParseOperator(Stmt stmt) {
if (stmt.as<Evaluate>() && stmt.as<EvaluateNode>()->value.as<CallNode>()) { if (stmt.as<Evaluate>() && stmt.as<EvaluateNode>()->value.as<CallNode>()) {
auto call = stmt.as<EvaluateNode>()->value.as<CallNode>(); auto call = stmt.as<EvaluateNode>()->value.as<CallNode>();
return ParseOperator(tvm::ffi::GetRef<Call>(call), vmap); return ParseOperator(tvm::ffi::GetRef<Call>(call));
} }
return TileOperator(); return TileOperator();
} }
......
...@@ -39,6 +39,9 @@ struct LowerArgs { ...@@ -39,6 +39,9 @@ struct LowerArgs {
AddWorkspaceCallback AddWorkspace; AddWorkspaceCallback AddWorkspace;
LayoutMap layout_map; LayoutMap layout_map;
Map<Buffer, Buffer> buffer_remap; Map<Buffer, Buffer> buffer_remap;
// Map from LetStmt variable to its bound expression, for resolving
// fragment buffer accesses through let bindings
Map<Var, PrimExpr> let_var_to_expr;
}; };
struct LayoutInferArgs { struct LayoutInferArgs {
...@@ -48,6 +51,9 @@ struct LayoutInferArgs { ...@@ -48,6 +51,9 @@ struct LayoutInferArgs {
arith::Analyzer *analyzer; arith::Analyzer *analyzer;
bool buffer_oob = false; bool buffer_oob = false;
Map<Buffer, Buffer> buffer_remap; Map<Buffer, Buffer> buffer_remap;
// Map from LetStmt variable to its bound expression, for resolving
// fragment buffer accesses through let bindings
Map<Var, PrimExpr> let_var_to_expr;
}; };
class TileOperator; class TileOperator;
...@@ -72,23 +78,20 @@ public: ...@@ -72,23 +78,20 @@ public:
Var GetVarFromAccessPtr(const PrimExpr &expr); Var GetVarFromAccessPtr(const PrimExpr &expr);
TileOperator ParseOperator(Call call, BufferMap vmap); TileOperator ParseOperator(Call call);
TileOperator ParseOperator(Stmt stmt, BufferMap vmap); TileOperator ParseOperator(Stmt stmt);
using OpBuilderFunc = using OpBuilderFunc = ffi::TypedFunction<TileOperator(Array<PrimExpr>)>;
ffi::TypedFunction<TileOperator(Array<PrimExpr>, BufferMap)>;
#define TIR_REGISTER_TL_OP(Entry, OpName) \ #define TIR_REGISTER_TL_TILE_OP(Entry, OpName) \
const Op &Entry::Get() { \ const Op &Entry::Get() { \
static const Op &op = Op::Get("tl." #OpName); \ static const Op &op = Op::Get("tl.tileop." #OpName); \
return op; \ return op; \
} \ } \
TVM_REGISTER_OP("tl." #OpName) \ TVM_REGISTER_OP("tl.tileop." #OpName) \
.set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName) \ .set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName) \
.set_attr<OpBuilderFunc>("TLOpBuilder", \ .set_attr<OpBuilderFunc>( \
[](Array<PrimExpr> args, BufferMap vmap) { \ "TLOpBuilder", [](Array<PrimExpr> args) { return Entry(args); })
return Entry(args, vmap); \
})
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
......
...@@ -182,6 +182,34 @@ TileOperator ParallelOpNode::Clone() const { ...@@ -182,6 +182,34 @@ TileOperator ParallelOpNode::Clone() const {
return ParallelOp(op); return ParallelOp(op);
} }
void ParallelOpNode::ExpandLetBindings(
const Map<Var, PrimExpr> &let_var_to_expr) {
if (let_var_to_expr.empty())
return;
// Helper function to recursively find BufferLoads through let bindings
std::function<void(const PrimExpr &)> expand = [&](const PrimExpr &expr) {
PostOrderVisit(expr, [&](const ObjectRef &node) {
if (auto bl = node.as<BufferLoadNode>()) {
if (bl->buffer.scope() == "local.fragment" &&
!indice_map_.count(bl->buffer)) {
indice_map_.Set(bl->buffer, bl->indices);
}
} else if (auto var_node = node.as<VarNode>()) {
auto var = tvm::ffi::GetRef<Var>(var_node);
if (let_var_to_expr.count(var)) {
expand(let_var_to_expr[var]);
}
}
});
};
// Scan all let bindings
for (const auto &[var, expr] : let_var_to_expr) {
expand(expr);
}
}
Stmt ParallelOpNode::Lower(const LowerArgs &T, Stmt ParallelOpNode::Lower(const LowerArgs &T,
arith::Analyzer *analyzer) const { arith::Analyzer *analyzer) const {
return root_; return root_;
...@@ -214,6 +242,12 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -214,6 +242,12 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const { InferLevel level) const {
if (loop_layout_.defined()) if (loop_layout_.defined())
return {}; return {};
// Expand let bindings to find fragment buffer accesses
if (!T.let_var_to_expr.empty()) {
const_cast<ParallelOpNode *>(this)->ExpandLetBindings(T.let_var_to_expr);
}
if (level == InferLevel::kStrict) { if (level == InferLevel::kStrict) {
LayoutMap results; LayoutMap results;
// Deduce buffers that should be complicated replicated. // Deduce buffers that should be complicated replicated.
...@@ -252,17 +286,18 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -252,17 +286,18 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
forward_vars.push_back( forward_vars.push_back(
IterVar(Range(0, s), Var(), IterVarType::kDataPar)); IterVar(Range(0, s), Var(), IterVarType::kDataPar));
} }
Array<PrimExpr> forward_index;
for (const auto &iv : forward_vars) {
forward_index.push_back(iv->var);
}
Var rep; Var rep;
auto rep_iter = auto rep_iter =
IterVar({0, T.thread_bounds->extent}, rep, IterVarType::kDataPar); IterVar({0, T.thread_bounds->extent}, rep, IterVarType::kDataPar);
// Use default fragment indexing (single output dim) to
// stay consistent with other ops (e.g., ReduceOp), and
// bind the thread range for comparability.
const PrimExpr &forward_thread = rep; const PrimExpr &forward_thread = rep;
results.Set(buffer, Fragment(forward_vars, forward_index, auto frag = Fragment(forward_vars, /*forward_index=*/{}, forward_thread,
forward_thread, rep_iter)); rep_iter)
->BindThreadRange(T.thread_bounds);
results.Set(buffer, frag);
} }
} }
return results; return results;
...@@ -452,8 +487,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -452,8 +487,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
// As the pass will do post processing to the layout // As the pass will do post processing to the layout
auto maybe_remapped_root_ = auto maybe_remapped_root_ =
IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map); IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
int vector_size = GetVectorizeSize(maybe_remapped_root_); int vector_size = GetVectorizeSize(maybe_remapped_root_, T.analyzer);
DLOG(INFO) << "[PlanLoopPartition] vector_size = " << vector_size << '\n'; DLOG(INFO) << "[PlanLoopPartition] vector_size = " << vector_size << '\n';
PrimExpr loop_total_size = 1; PrimExpr loop_total_size = 1;
...@@ -562,6 +596,16 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -562,6 +596,16 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
} else { } else {
return {}; return {};
} }
// check loop_layout_ is injective
auto injective_res = loop_layout_->DetectInjective();
if (!injective_res->errors.empty()) {
std::ostringstream oss;
oss << "Loop layout is not injective: " << loop_layout_->DebugOutput()
<< '\n'
<< " errors: " << injective_res->errors << '\n'
<< " loop AST: " << root_;
throw LoopLayoutInjectiveException(oss.str());
}
PrimExpr loop_thread_extent = loop_layout_->ThreadExtent(); PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();
......
...@@ -24,15 +24,6 @@ namespace tl { ...@@ -24,15 +24,6 @@ namespace tl {
using namespace tir; using namespace tir;
class LayoutConflictException : public std::exception {
public:
const char *what() const noexcept override { return msg_.c_str(); }
LayoutConflictException(const std::string &msg) : msg_(msg) {}
private:
std::string msg_;
};
bool ProveFragmentContains(Fragment small_frag, Fragment large_frag, bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
Array<PrimExpr> small_frag_indices, Array<PrimExpr> small_frag_indices,
Array<PrimExpr> large_frag_indices, Array<PrimExpr> large_frag_indices,
...@@ -114,6 +105,10 @@ private: ...@@ -114,6 +105,10 @@ private:
void AddPredicate(const PrimExpr &expr) const { void AddPredicate(const PrimExpr &expr) const {
predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr; predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
} }
// Expand let bindings to find fragment buffer accesses and add them to
// indice_map_. This handles cases like: a = block_mask_f[i]; T.copy(A[a, 0],
// ...)
void ExpandLetBindings(const Map<Var, PrimExpr> &let_var_to_expr);
// Allow ParallelLoopNestVisitor to access private members. // Allow ParallelLoopNestVisitor to access private members.
friend class ParallelLoopNestVisitor; friend class ParallelLoopNestVisitor;
......
...@@ -14,60 +14,24 @@ ...@@ -14,60 +14,24 @@
#include "../op/parallel.h" #include "../op/parallel.h"
#include "../target/utils.h" #include "../target/utils.h"
#include "../transform/loop_partition.h" #include "../transform/loop_partition.h"
#include "region.h"
#include "tir/transforms/ir_utils.h" #include "tir/transforms/ir_utils.h"
#include "tvm/tir/stmt.h"
#include "utils.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
// Normalize an argument (BufferRegion/BufferLoad/tl.region) // NormalizeToBufferRegion moved to src/op/utils.{h,cc}
// to BufferRegion so Reduce can uniformly consume regions.
static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
const BufferMap &vmap) {
// Case 1: Already a BufferRegion
if (arg->IsInstance<BufferRegionNode>()) {
return Downcast<BufferRegion>(arg);
}
// Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else // MakeAccessPtrFromRegion moved to src/op/utils.{h,cc}
// extent=1)
if (const auto *load = arg.as<BufferLoadNode>()) {
Array<Range> ranges;
for (const PrimExpr &index : load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
ICHECK(ramp->stride.as<IntImmNode>()) << "Ramp stride must be IntImm";
ICHECK_EQ(ramp->stride.as<IntImmNode>()->value, 1)
<< "Only stride-1 Ramp is supported in region conversion";
ICHECK(ramp->lanes.as<IntImmNode>())
<< "Scalable vector lanes not supported in region conversion";
ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else {
ranges.push_back(Range::FromMinExtent(index, 1));
}
}
return BufferRegion(load->buffer, ranges);
}
// Case 3: Call nodes (only tl.region) ReduceOp::ReduceOp(Array<PrimExpr> args) {
if (const auto *call = arg.as<CallNode>()) {
// tl.region(...) — reconstruct via RegionOp
if (call->op.same_as(RegionOp::Get())) {
RegionOp region(call->args, vmap);
return BufferRegion(region->GetBuffer(), region->GetRanges());
}
}
LOG(FATAL) << "Unsupported argument for BufferRegion in reduce: " << arg;
throw; // Unreachable
}
ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<ReduceOpNode> node = tvm::ffi::make_object<ReduceOpNode>(); ObjectPtr<ReduceOpNode> node = tvm::ffi::make_object<ReduceOpNode>();
// Accept BufferRegion/BufferLoad/tl.region for src/dst // Accept BufferRegion/BufferLoad for src/dst
node->srcRegion_ = NormalizeToBufferRegion(args[0], vmap); node->srcRegion_ = NormalizeToBufferRegion(args[0]);
node->dstRegion_ = NormalizeToBufferRegion(args[1], vmap); node->dstRegion_ = NormalizeToBufferRegion(args[1]);
node->src = node->srcRegion_->buffer; node->src = node->srcRegion_->buffer;
node->dst = node->dstRegion_->buffer; node->dst = node->dstRegion_->buffer;
std::string reduce_type = args[2].as<StringImm>().value()->value; std::string reduce_type = args[2].as<StringImm>().value()->value;
...@@ -231,6 +195,7 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -231,6 +195,7 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto dst_scope = this->dst.scope(); auto dst_scope = this->dst.scope();
if (src_scope == "local.fragment" && dst_scope == "local.fragment") { if (src_scope == "local.fragment" && dst_scope == "local.fragment") {
Buffer src_buffer = get_buffer(this->src); Buffer src_buffer = get_buffer(this->src);
Buffer dst_buffer = get_buffer(this->dst); Buffer dst_buffer = get_buffer(this->dst);
Fragment src_layout = T.layout_map[this->src].as<Fragment>().value(); Fragment src_layout = T.layout_map[this->src].as<Fragment>().value();
...@@ -513,12 +478,22 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -513,12 +478,22 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
return {}; return {};
} }
TIR_REGISTER_TL_OP(ReduceOp, reduce) TIR_REGISTER_TL_TILE_OP(ReduceOp, reduce)
.set_num_inputs(4) .set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) { // Normalize "Buffer" to BufferRegion. Use the shape of the buffer as the
// ranges.
static BufferRegion ConvertBufferToBufferRegion(const Buffer &buf) {
Array<Range> ranges;
for (PrimExpr extent : buf->shape) {
ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
}
return BufferRegion(buf, ranges);
}
CumSumOp::CumSumOp(Array<PrimExpr> args) {
/// CumSum constructor arguments: /// CumSum constructor arguments:
/// - src: input buffer /// - src: input buffer
/// - dst: output buffer /// - dst: output buffer
...@@ -526,11 +501,19 @@ CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) { ...@@ -526,11 +501,19 @@ CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
/// - reverse: whether to cumsum in reverse order /// - reverse: whether to cumsum in reverse order
CHECK_EQ(args.size(), 4); CHECK_EQ(args.size(), 4);
ObjectPtr<CumSumOpNode> node = tvm::ffi::make_object<CumSumOpNode>(); ObjectPtr<CumSumOpNode> node = tvm::ffi::make_object<CumSumOpNode>();
node->src = vmap[GetVarFromAccessPtr(args[0])]; // node->src = vmap[GetVarFromAccessPtr(args[0])];
node->dst = vmap[GetVarFromAccessPtr(args[1])]; // node->dst = vmap[GetVarFromAccessPtr(args[1])];
node->srcRegion_ = NormalizeToBufferRegion(args[0]);
node->dstRegion_ = NormalizeToBufferRegion(args[1]);
node->src = node->srcRegion_->buffer;
node->dst = node->dstRegion_->buffer;
node->dim = args[2].as<IntImm>().value()->value; node->dim = args[2].as<IntImm>().value()->value;
node->reverse = args[3].as<Bool>().value(); node->reverse = args[3].as<Bool>().value();
CHECK_LT(node->dim, static_cast<int>(node->src->shape.size())); CHECK_LT(node->dim, static_cast<int>(node->src->shape.size()))
<< "The dim of cumsum should be less than the number of dimensions. Got "
"dim="
<< node->dim << ", but src has " << node->src->shape.size() << " dims.";
data_ = std::move(node); data_ = std::move(node);
} }
...@@ -545,19 +528,29 @@ Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -545,19 +528,29 @@ Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
std::stringstream ss; std::stringstream ss;
auto threads = T.thread_bounds->extent; auto threads = T.thread_bounds->extent;
Array<PrimExpr> args; Array<PrimExpr> args;
int ndim = static_cast<int>(src->shape.size());
// Build access pointers from regions locally
PrimExpr srcPtr = MakeAccessPtrFromRegion(srcRegion_, 1);
PrimExpr dstPtr = MakeAccessPtrFromRegion(dstRegion_, 2);
// Use region extents instead of buffer shape for correct slice handling
Array<PrimExpr> src_extents;
for (const auto &range : srcRegion_->region) {
src_extents.push_back(range->extent);
}
int ndim = static_cast<int>(src_extents.size());
if (ndim == 1) { if (ndim == 1) {
ICHECK_EQ(dim, 0) << "Cumulative sum over a 1D buffer only supports dim " ICHECK_EQ(dim, 0) << "Cumulative sum over a 1D buffer only supports dim "
"= 0."; "= 0.";
ss << "tl::CumSum1D<" << threads << ", " << (reverse ? "true" : "false") ss << "tl::CumSum1D<" << threads << ", " << (reverse ? "true" : "false")
<< ">::run"; << ">::run";
args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3), args = {StringImm(ss.str()), srcPtr, dstPtr, src_extents[0]};
src->shape[0]};
} else if (ndim == 2) { } else if (ndim == 2) {
ss << "tl::CumSum2D<" << threads << ", " << dim << ", " ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
<< (reverse ? "true" : "false") << ">::run"; << (reverse ? "true" : "false") << ">::run";
args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3), args = {StringImm(ss.str()), srcPtr, dstPtr, src_extents[0],
src->shape[0], src->shape[1]}; src_extents[1]};
} else { } else {
LOG(FATAL) << "CumSum currently supports only 1D or 2D buffers, got " LOG(FATAL) << "CumSum currently supports only 1D or 2D buffers, got "
<< ndim << "D."; << ndim << "D.";
...@@ -576,7 +569,7 @@ LayoutMap CumSumOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -576,7 +569,7 @@ LayoutMap CumSumOpNode::InferLayout(const LayoutInferArgs &T,
return {}; return {};
} }
TIR_REGISTER_TL_OP(CumSumOp, cumsum) TIR_REGISTER_TL_TILE_OP(CumSumOp, cumsum)
.set_num_inputs(4) .set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
......
...@@ -125,7 +125,7 @@ class ReduceOp : public TileOperator { ...@@ -125,7 +125,7 @@ class ReduceOp : public TileOperator {
public: public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReduceOp, TileOperator, TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReduceOp, TileOperator,
ReduceOpNode); ReduceOpNode);
TVM_DLL ReduceOp(Array<PrimExpr> args, BufferMap vmap); TVM_DLL ReduceOp(Array<PrimExpr> args);
static const Op &Get(); static const Op &Get();
}; };
...@@ -133,8 +133,10 @@ public: ...@@ -133,8 +133,10 @@ public:
class CumSumOpNode : public TileOperatorNode { class CumSumOpNode : public TileOperatorNode {
public: public:
tir::Buffer src, dst; ///< Source and destination buffers tir::Buffer src, dst; ///< Source and destination buffers
int dim; ///< Dimension along which to compute cumulative sum // Optional: keep the original regions used to construct this op
bool reverse; ///< Whether to compute in reverse order BufferRegion srcRegion_, dstRegion_;
int dim; ///< Dimension along which to compute cumulative sum
bool reverse; ///< Whether to compute in reverse order
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.CumSumOp", CumSumOpNode, TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.CumSumOp", CumSumOpNode,
TileOperatorNode); TileOperatorNode);
...@@ -143,6 +145,8 @@ public: ...@@ -143,6 +145,8 @@ public:
refl::ObjectDef<CumSumOpNode>() refl::ObjectDef<CumSumOpNode>()
.def_ro("src", &CumSumOpNode::src) .def_ro("src", &CumSumOpNode::src)
.def_ro("dst", &CumSumOpNode::dst) .def_ro("dst", &CumSumOpNode::dst)
.def_ro("srcRegion", &CumSumOpNode::srcRegion_)
.def_ro("dstRegion", &CumSumOpNode::dstRegion_)
.def_ro("dim", &CumSumOpNode::dim) .def_ro("dim", &CumSumOpNode::dim)
.def_ro("reverse", &CumSumOpNode::reverse); .def_ro("reverse", &CumSumOpNode::reverse);
} }
...@@ -159,7 +163,7 @@ class CumSumOp : public TileOperator { ...@@ -159,7 +163,7 @@ class CumSumOp : public TileOperator {
public: public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CumSumOp, TileOperator, TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CumSumOp, TileOperator,
CumSumOpNode); CumSumOpNode);
TVM_DLL CumSumOp(Array<PrimExpr> args, BufferMap vmap); TVM_DLL CumSumOp(Array<PrimExpr> args);
static const Op &Get(); static const Op &Get();
}; };
......
/*! /*!
* \file tl/op/region.cc * \file tl/op/region.cc
* \brief Define region operator. * \brief Define region operator (bridge to carry BufferRegion via Call args).
* *
* Notes:
* - BufferLoad/Ramp cannot represent a general PrimExpr as a vector lane
* count. Dynamic extents like (H1 - H0) cannot be encoded as
* Ramp(lanes = H1 - H0), and lowering BufferRegion to BufferLoad loses the
* explicit extent information.
* - tl.region carries both mins and extents in Call args and lets the backend
* reconstruct a BufferRegion faithfully.
*/ */
#include "region.h" #include "region.h"
...@@ -11,27 +18,7 @@ namespace tvm { ...@@ -11,27 +18,7 @@ namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
/** RegionOp::RegionOp(Array<PrimExpr> args) {
* @brief Construct a RegionOp from TL operator arguments.
*
* Parses the TL `region` operator call arguments to populate the RegionOpNode:
* - Expects args[0] to be a `BufferLoad` whose `indices` are the per-dimension
* minima.
* - args[1] must be a constant integer used as the access mask.
* - args[2 + i] provides the extent for dimension `i`.
*
* The constructor validates that the number of load indices equals `args.size()
* - 2` and will abort via ICHECK on mismatch or if args[0] is not a
* `BufferLoad`.
*
* Parameters:
* - args: TL operator call arguments in the form
* [BufferLoad(min_i...), access_mask, extent_0, extent_1, ...,
* extent_{n-1}] where n = number of dimensions.
* - vmap: BufferMap passed through by the caller (not documented here as a
* generic utility).
*/
RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
size_t n = args.size(); size_t n = args.size();
size_t ndim = n - 2; size_t ndim = n - 2;
auto load = args[0].as<BufferLoadNode>(); auto load = args[0].as<BufferLoadNode>();
...@@ -39,10 +26,24 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) { ...@@ -39,10 +26,24 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
ICHECK(load->indices.size() == ndim) ICHECK(load->indices.size() == ndim)
<< "load->indices.size() = " << load->indices << " ndim = " << ndim; << "load->indices.size() = " << load->indices << " ndim = " << ndim;
Array<Range> ranges; Array<Range> ranges;
// Rebuild per-axis ranges from mins (BufferLoad indices) and provided extents
for (size_t i = 0; i < ndim; i++) { for (size_t i = 0; i < ndim; i++) {
PrimExpr min = load->indices[i]; PrimExpr index = load->indices[i];
PrimExpr extent = args[2 + i]; PrimExpr extent = args[2 + i];
ranges.push_back(Range::FromMinExtent(min, extent)); if (const auto *ramp = index.as<RampNode>()) {
const auto *stride_imm = ramp->stride.as<IntImmNode>();
ICHECK(stride_imm && stride_imm->value == 1)
<< "RegionOp expects stride-1 Ramp for index";
if (const auto *lanes_imm = ramp->lanes.as<IntImmNode>()) {
if (const auto *ext_imm = extent.as<IntImmNode>()) {
ICHECK_EQ(lanes_imm->value, ext_imm->value)
<< "Ramp lanes and provided extent must match";
}
}
ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else {
ranges.push_back(Range::FromMinExtent(index, extent));
}
} }
ObjectPtr<RegionOpNode> node = tvm::ffi::make_object<RegionOpNode>(); ObjectPtr<RegionOpNode> node = tvm::ffi::make_object<RegionOpNode>();
node->buffer_ = load->buffer; node->buffer_ = load->buffer;
...@@ -51,26 +52,11 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) { ...@@ -51,26 +52,11 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
data_ = std::move(node); data_ = std::move(node);
} }
/**
* @brief Create a copy of this RegionOpNode and return it as a TileOperator.
*
* @return TileOperator A new TileOperator that owns a copied RegionOpNode.
*/
TileOperator RegionOpNode::Clone() const { TileOperator RegionOpNode::Clone() const {
auto op = tvm::ffi::make_object<RegionOpNode>(*this); auto op = tvm::ffi::make_object<RegionOpNode>(*this);
return RegionOp(op); return RegionOp(op);
} }
/**
* @brief Check whether the region spans the entire underlying buffer.
*
* Returns true if for every dimension the range minimum is zero and the
* range extent is structurally equal to the corresponding buffer shape
* dimension. Otherwise returns false.
*
* @return true if the region covers the full buffer in all dimensions; false
* otherwise.
*/
bool RegionOpNode::IsFullRegion() const { bool RegionOpNode::IsFullRegion() const {
for (size_t i = 0; i < ranges_.size(); i++) { for (size_t i = 0; i < ranges_.size(); i++) {
if (!is_zero(ranges_[i]->min)) if (!is_zero(ranges_[i]->min))
...@@ -81,39 +67,16 @@ bool RegionOpNode::IsFullRegion() const { ...@@ -81,39 +67,16 @@ bool RegionOpNode::IsFullRegion() const {
return true; return true;
} }
/**
* @brief Lower the region operator to a TIR statement.
*
* Lowers this RegionOpNode into a TIR Stmt by delegating to the operator's
* evaluation path (currently `Evaluate(0)`).
*
* @param T Lowering context (provides buffers, producers/consumers and other
* environment required for lowering).
* @param analyzer Optional arithmetic analyzer used for simplification during
* lowering.
* @return Stmt The lowered TIR statement representing this region operation.
*/
Stmt RegionOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Stmt RegionOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return Evaluate(0); return Evaluate(0);
} }
/**
* @brief Infers data layout for the region operator.
*
* This operator does not provide any layout inference; the function always
* returns an empty LayoutMap regardless of the provided arguments or inference
* level.
*
* @param T Layout inference arguments (ignored).
* @param level Inference granularity level (ignored).
* @return LayoutMap Empty map indicating no inferred layouts.
*/
LayoutMap RegionOpNode::InferLayout(const LayoutInferArgs &T, LayoutMap RegionOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const { InferLevel level) const {
return {}; return {};
} }
TIR_REGISTER_TL_OP(RegionOp, region) TIR_REGISTER_TL_TILE_OP(RegionOp, region)
.set_num_inputs(-1) .set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure)); Integer(CallEffectKind::kPure));
......
/*! /*!
* \file tl/op/op.h * \file tl/op/region.h
* \brief Tile library operations. * \brief Tile memory region descriptor op (bridge to carry BufferRegion via
* Call args).
* *
* Why tl.region instead of passing BufferRegion directly?
*
* - While TIR can represent a BufferRegion, when a BufferRegion is passed as a
* call argument through call_intrin/FFI, the Python->C++ conversion lowers it
* to a BufferLoad(indices). To encode an interval inside indices, the FFI
* typically uses Ramp(base, stride, lanes) to represent a contiguous slice.
* - Ramp(lanes) may only be a constant or vscale*k (scalable vector). A general
* PrimExpr (e.g., H1 - H0) is not allowed as lanes, so dynamic extents would
* make the lowered BufferLoad invalid.
* - Moreover, BufferLoad only carries indices, not per-axis extents. Downstream
* tile operators (e.g., tl.copy, tl.reduce) that require both min and extent
* cannot losslessly recover dynamic extents from a BufferLoad alone.
*
* tl.region is a small transport-only op that solves this:
* - The frontend packs buffer + mins (from BufferLoad.indices) + extents into
* Call args, allowing dynamic extents to be expressed explicitly.
* - The backend (NormalizeToBufferRegion) reconstructs a BufferRegion from the
* tl.region call without losing information.
* - The op itself carries no semantics in Lower/InferLayout and is only used as
* a bridge for argument passing.
*/ */
#ifndef TVM_TL_OP_REGION_H_ #ifndef TVM_TL_OP_REGION_H_
#define TVM_TL_OP_REGION_H_ #define TVM_TL_OP_REGION_H_
#include "./operator.h" #include "./operator.h"
#include <tvm/arith/analyzer.h>
#include <tvm/ir/op.h>
#include <tvm/target/target.h>
#include <tvm/tir/buffer.h> #include <tvm/tir/buffer.h>
/**
* Tile operator representing a memory region (buffer + ranges) used by TL
* passes.
*
* Encapsulates the target tir::Buffer, the region extents as an Array<Range>,
* and an access mask that indicates permitted or intended accesses for lowering
* and layout inference.
*/
/**
* Lower this RegionOp into a TIR statement representing the region access.
*
* @param T Lowering-time arguments (e.g., loop/build context and value
* mappings).
* @param analyzer Arithmetic analyzer used to simplify and reason about
* expressions.
* @return A tir::Stmt that implements the region access/mutation described by
* this operator.
*/
/**
* Infer the layout mapping for this region operator.
*
* Produces a LayoutMap describing how loop/axis indices map to buffer axes for
* layout-aware scheduling and subsequent operators.
*
* @param T Layout inference arguments (e.g., input layouts and shapes).
* @param level The inference detail level to use.
* @return A LayoutMap describing inferred mappings for the operator.
*/
/**
* Return true when this RegionOp represents the full buffer region (i.e.,
* ranges cover the entire buffer extent).
*/
/**
* Create a shallow copy of this operator as a TileOperator handle.
*
* @return A TileOperator that references a cloned RegionOpNode.
*/
/**
* Construct a RegionOp from argument expressions and a buffer map.
*
* @param args Positional expressions used to instantiate the operator
* (semantics depend on how RegionOp is invoked in TL pipelines).
* @param vmap Mapping from Buffer to replacement Buffer or buffer metadata used
* during creation.
*/
/**
* Return the global Op registration for RegionOp.
*
* @return Reference to the registered tvm::Op describing the RegionOp.
*/
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -80,6 +42,12 @@ public: ...@@ -80,6 +42,12 @@ public:
Array<Range> ranges_; Array<Range> ranges_;
int access_mask_; int access_mask_;
/*!
* access_mask_ encodes the intended access type when the region is used as
* an argument to tile operators: 1=read, 2=write, 3=read-write. The mask is
* transport metadata only and does not affect lowering.
*/
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.RegionOp", RegionOpNode, TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.RegionOp", RegionOpNode,
TileOperatorNode); TileOperatorNode);
...@@ -107,8 +75,13 @@ class RegionOp : public TileOperator { ...@@ -107,8 +75,13 @@ class RegionOp : public TileOperator {
public: public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(RegionOp, TileOperator, TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(RegionOp, TileOperator,
RegionOpNode); RegionOpNode);
TVM_DLL RegionOp(Array<PrimExpr> args, BufferMap vmap); /*!
* Build a RegionOp from call arguments:
* - args[0]: BufferLoad whose indices are per-axis minima.
* - args[1]: Integer access mask (1=r, 2=w, 3=rw).
* - args[2 + i]: Extent of axis i (supports dynamic PrimExpr).
*/
TVM_DLL RegionOp(Array<PrimExpr> args);
static const Op &Get(); static const Op &Get();
}; };
......
...@@ -15,16 +15,19 @@ using runtime::DataType; ...@@ -15,16 +15,19 @@ using runtime::DataType;
struct TCGEN5MMAMeta { struct TCGEN5MMAMeta {
int atom_m, atom_n, atom_k; int atom_m, atom_n, atom_k;
bool enable_ws, enable_2cta;
}; };
inline std::pair<bool, TCGEN5MMAMeta> inline std::pair<bool, TCGEN5MMAMeta>
GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) { GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA. // TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
#define FAIL \ #define FAIL \
return { false, TCGEN5MMAMeta{0, 0, 0} }
#define SUCCESS(atom_m, atom_n, atom_k) \
return { \ return { \
true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \ false, TCGEN5MMAMeta { 0, 0, 0, false, false } \
}
#define SUCCESS(atom_m, atom_n, atom_k, use_ws, use_2cta) \
return { \
true, TCGEN5MMAMeta { atom_m, atom_n, atom_k, use_ws, use_2cta } \
} }
std::vector<int> ws_valid_atom_ns = {256, 128, 64}; std::vector<int> ws_valid_atom_ns = {256, 128, 64};
if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) && if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
...@@ -34,39 +37,50 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) { ...@@ -34,39 +37,50 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
if (M % 128 == 0) { if (M % 128 == 0) {
for (int atom_n = 256; atom_n >= 16; atom_n -= 16) for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
if (N % atom_n == 0) if (N % atom_n == 0)
SUCCESS(128, atom_n, 16); SUCCESS(128, atom_n, 16, false, false);
FAIL; FAIL;
} else if (M % 64 == 0) { } else if (M % 64 == 0) {
for (int atom_n : ws_valid_atom_ns) for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0) if (N % atom_n == 0)
SUCCESS(64, atom_n, 16); SUCCESS(64, atom_n, 16, true, false);
FAIL; FAIL;
} else if (M % 32 == 0) { } else if (M % 32 == 0) {
for (int atom_n : ws_valid_atom_ns) for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0) if (N % atom_n == 0)
SUCCESS(32, atom_n, 16); SUCCESS(32, atom_n, 16, true, false);
FAIL; FAIL;
} else { } else {
FAIL; FAIL;
} }
} else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) && } else if ((ab_dtype.is_float8() || ab_dtype.is_float6_e2m3fn() ||
(c_dtype.is_float() && c_dtype.bits() == 32)) { ab_dtype.is_float6_e3m2fn() || ab_dtype.is_float4_e2m1fn()) &&
((c_dtype.is_float() && c_dtype.bits() == 32) ||
(c_dtype.is_float16() && c_dtype.bits() == 16))) {
if (K % 32 != 0) if (K % 32 != 0)
FAIL; FAIL;
if (M % 128 == 0) { if (M % 128 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 32, true, false);
for (int atom_n = 256; atom_n >= 16; atom_n -= 16) for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
if (N % atom_n == 0) if (N % atom_n == 0)
SUCCESS(128, atom_n, 32); SUCCESS(128, atom_n, 32, false, true);
for (int atom_n = 256; atom_n >= 8; atom_n -= 8)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 32, false, false);
FAIL; FAIL;
} else if (M % 64 == 0) { } else if (M % 64 == 0) {
for (int atom_n : ws_valid_atom_ns) for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0) if (N % atom_n == 0)
SUCCESS(64, atom_n, 32); SUCCESS(64, atom_n, 32, true, false);
for (int atom_n = 256; atom_n >= 8; atom_n -= 8)
if (N % atom_n == 0)
SUCCESS(64, atom_n, 32, false, false);
FAIL; FAIL;
} else if (M % 32 == 0) { } else if (M % 32 == 0) {
for (int atom_n : ws_valid_atom_ns) for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0) if (N % atom_n == 0)
SUCCESS(32, atom_n, 32); SUCCESS(32, atom_n, 32, true, false);
FAIL; FAIL;
} else { } else {
FAIL; FAIL;
......
/*!
* \file tl/op/utils.cc
* \brief Common utilities implementation for TL ops.
*/
#include "utils.h"
#include <tvm/tir/builtin.h>
namespace tvm {
namespace tl {
using namespace tir;
BufferRegion NormalizeToBufferRegion(const PrimExpr &arg) {
  // Case 1: the argument is already a BufferRegion — pass it through.
  if (arg->IsInstance<BufferRegionNode>()) {
    return Downcast<BufferRegion>(arg);
  }
  // Case 2: a BufferLoad — each index becomes a per-axis range. A stride-1
  // Ramp encodes a contiguous slice [base, base + lanes); any other index is
  // a point access with extent 1.
  if (const auto *load = arg.as<BufferLoadNode>()) {
    Array<Range> ranges;
    for (const PrimExpr &idx : load->indices) {
      const auto *ramp = idx.as<RampNode>();
      if (ramp == nullptr) {
        ranges.push_back(Range::FromMinExtent(idx, 1));
        continue;
      }
      ICHECK(ramp->stride.as<IntImmNode>()) << "Ramp stride must be IntImm";
      ICHECK_EQ(ramp->stride.as<IntImmNode>()->value, 1)
          << "Only stride-1 Ramp is supported in region conversion";
      // Scalable lanes (vscale-dependent) cannot be turned into a fixed
      // extent here.
      ICHECK(ramp->lanes.as<IntImmNode>())
          << "Scalable vector lanes not supported in region conversion";
      ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
    }
    return BufferRegion(load->buffer, ranges);
  }
  // Case 3: a tl.region(...) call — let RegionOp (the transport bridge)
  // rebuild buffer + ranges from the call arguments.
  if (const auto *call = arg.as<CallNode>()) {
    if (call->op.same_as(RegionOp::Get())) {
      RegionOp region(call->args);
      return BufferRegion(region->GetBuffer(), region->GetRanges());
    }
    LOG(FATAL) << "Unsupported argument for BufferRegion (expect "
                  "BufferLoad/BufferRegion/tl.region): "
               << arg;
  }
  LOG(FATAL) << "Unsupported argument for BufferRegion: " << arg;
  throw; // Unreachable: LOG(FATAL) aborts before control reaches here.
}
PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region, int rw_mask,
                                 bool require_2d) {
  Buffer buf = region->buffer;
  const int ndim = static_cast<int>(buf->shape.size());
  if (require_2d) {
    ICHECK(ndim >= 2) << "Expect buffers with at least 2 dims";
  }
  PrimExpr offset, extent;
  if (ndim == 1) {
    // 1D region maps directly: offset = min, extent = extent.
    const auto axis = region->region[0];
    offset = axis->min;
    extent = axis->extent;
  } else {
    // Row-major strides: stride[ndim-1] = 1, stride[i] = stride[i+1] *
    // shape[i+1].
    std::vector<PrimExpr> strides(ndim);
    PrimExpr running = make_const(buf->shape[0].dtype(), 1);
    for (int i = ndim - 1; i >= 0; --i) {
      strides[i] = running;
      running = running * buf->shape[i];
    }
    // Offset accumulates only the leading (batch) dims; the mins of the last
    // two dims are intentionally not folded in — the extent below spans them
    // (see the contract documented in utils.h).
    offset = make_const(buf->shape[0].dtype(), 0);
    for (int i = 0; i + 2 < ndim; ++i) {
      offset = offset + region->region[i]->min * strides[i];
    }
    // Extent in elements: product of the trailing two extents.
    extent =
        region->region[ndim - 2]->extent * region->region[ndim - 1]->extent;
  }
  // Assemble tvm_access_ptr(dtype, data, offset, extent, rw_mask) -> handle.
  PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
  Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
                           IntImm(DataType::Int(32), rw_mask)};
  return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
}
} // namespace tl
} // namespace tvm
/*!
 * \file tl/op/utils.h
 * \brief Common utilities for TL ops.
 */
#ifndef TVM_TL_OP_UTILS_H_
#define TVM_TL_OP_UTILS_H_

#include "./operator.h"
#include "region.h"

#include <tvm/tir/buffer.h>
#include <tvm/tir/op.h>

namespace tvm {
namespace tl {

using namespace tir;

// Normalize an argument (BufferRegion/BufferLoad/tl.region)
// to BufferRegion so ops can uniformly consume regions.
// Note: tvm_access_ptr is no longer supported here.
TVM_DLL BufferRegion NormalizeToBufferRegion(const PrimExpr &arg);

// Build a tvm_access_ptr(handle) from a BufferRegion.
// - `rw_mask` is forwarded verbatim as the access mask argument of
//   tvm_access_ptr (1=read, 2=write, 3=read-write).
// - If `require_2d` is true, checks buffer ndim >= 2.
// - For 1D regions (when allowed), offset=min, extent=extent.
// - For ndim >= 2, offset sums all but last two dims using row-major strides,
//   extent is product of the last two extents. The mins of the last two dims
//   are deliberately ignored.
TVM_DLL PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
                                         int rw_mask, bool require_2d = false);

} // namespace tl
} // namespace tvm

#endif // TVM_TL_OP_UTILS_H_
/*
* Helper functions for nicer runtime error messages.
*/
#include "error_helpers.h"
#include <tvm/ffi/c_api.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/device_api.h>
#include <sstream>
#include <string>
namespace tvm {
namespace tl {
// Return non-zero so that tvm_call_packed sites treat it as failure and return
// -1.
// Report a dtype mismatch for a named kernel input. Builds both DataTypes
// from their packed (code, bits, lanes) triples, formats a descriptive
// message, and raises a RuntimeError through the FFI. Returns non-zero so
// that tvm_call_packed sites treat it as failure and return -1.
static int DTypeMismatch(const tvm::ffi::String &kernel_name,
                         const tvm::ffi::String &buffer_name,
                         int64_t actual_code, int64_t actual_bits,
                         int64_t actual_lanes, int64_t expect_code,
                         int64_t expect_bits, int64_t expect_lanes) {
  const tvm::runtime::DataType got(static_cast<int>(actual_code),
                                   static_cast<int>(actual_bits),
                                   static_cast<int>(actual_lanes));
  const tvm::runtime::DataType want(static_cast<int>(expect_code),
                                    static_cast<int>(expect_bits),
                                    static_cast<int>(expect_lanes));
  std::ostringstream msg;
  msg << "kernel " << std::string(kernel_name) << " input "
      << std::string(buffer_name) << " dtype expected " << want
      << ", but got " << got;
  TVMFFIErrorSetRaisedFromCStr("RuntimeError", msg.str().c_str());
  return -1;
}
// Variant without names, to avoid passing extra raw strings through packed
// args.
static int DTypeMismatchNoNames(int64_t actual_code, int64_t actual_bits,
int64_t actual_lanes, int64_t expect_code,
int64_t expect_bits, int64_t expect_lanes) {
tvm::runtime::DataType actual(static_cast<int>(actual_code),
static_cast<int>(actual_bits),
static_cast<int>(actual_lanes));
tvm::runtime::DataType expect(static_cast<int>(expect_code),
static_cast<int>(expect_bits),
static_cast<int>(expect_lanes));
std::ostringstream os;
os << "dtype mismatch: expected " << expect << ", but got " << actual;
TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
return -1;
}
// Register packed versions, following the design in runtime.cc
// Registers the TileLang runtime error-helper packed functions. Every handler
// follows the same pattern: validate the packed argument count, format a
// human-readable message, raise a RuntimeError via TVMFFIErrorSetRaisedFromCStr,
// set *ret = -1 for completeness, then throw EnvErrorAlreadySet so the FFI
// layer propagates the already-raised error.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  // Packed: __tvm_error_dtype_mismatch(kernel_name, buffer_name,
  //                                    actual_code, actual_bits, actual_lanes,
  //                                    expect_code, expect_bits, expect_lanes)
  refl::GlobalDef().def_packed(
      tl::tvm_error_dtype_mismatch,
      [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
        ICHECK(args.size() == 8) << "Expected 8 args: kernel, buffer, "
                                    "actual_code, actual_bits, actual_lanes, "
                                 << "expect_code, expect_bits, expect_lanes";
        auto kernel_name = args[0].cast<tvm::ffi::String>();
        auto buffer_name = args[1].cast<tvm::ffi::String>();
        int64_t actual_code = args[2].cast<int64_t>();
        int64_t actual_bits = args[3].cast<int64_t>();
        int64_t actual_lanes = args[4].cast<int64_t>();
        int64_t expect_code = args[5].cast<int64_t>();
        int64_t expect_bits = args[6].cast<int64_t>();
        int64_t expect_lanes = args[7].cast<int64_t>();
        // Reuse the helper to format the message
        (void)DTypeMismatch(kernel_name, buffer_name, actual_code, actual_bits,
                            actual_lanes, expect_code, expect_bits,
                            expect_lanes);
        // Provide a return value for completeness, then signal the error
        *ret = -1;
        throw ::tvm::ffi::EnvErrorAlreadySet();
      });
  // kernel, buffer, expect:int64, got:int64
  // Reports a rank (ndim) mismatch for a kernel input buffer.
  refl::GlobalDef().def_packed(
      tl::tvm_error_ndim_mismatch,
      [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
        ICHECK(args.size() == 4)
            << "__tvm_error_ndim_mismatch(kernel, buffer, expect, got)";
        auto kernel = args[0].cast<tvm::ffi::String>();
        auto buffer = args[1].cast<tvm::ffi::String>();
        int64_t expect = args[2].cast<int64_t>();
        int64_t got = args[3].cast<int64_t>();
        std::ostringstream os;
        os << "kernel " << std::string(kernel) << " input "
           << std::string(buffer) << " ndim expected " << expect << ", but got "
           << got;
        TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
        *ret = -1;
        throw ::tvm::ffi::EnvErrorAlreadySet();
      });
  // kernel, buffer, expect:int64, got:int64
  // Reports a byte_offset mismatch for a kernel input buffer.
  refl::GlobalDef().def_packed(
      tl::tvm_error_byte_offset_mismatch,
      [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
        ICHECK(args.size() == 4)
            << "__tvm_error_byte_offset_mismatch(kernel, buffer, expect, got)";
        auto kernel = args[0].cast<tvm::ffi::String>();
        auto buffer = args[1].cast<tvm::ffi::String>();
        int64_t expect = args[2].cast<int64_t>();
        int64_t got = args[3].cast<int64_t>();
        std::ostringstream os;
        os << "kernel " << std::string(kernel) << " input "
           << std::string(buffer) << " byte_offset expected " << expect
           << ", but got " << got;
        TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
        *ret = -1;
        throw ::tvm::ffi::EnvErrorAlreadySet();
      });
  // kernel, buffer, expect:int64, got:int64
  // Reports a device_type mismatch; codes are rendered as readable device
  // names via DLDeviceType2Str.
  refl::GlobalDef().def_packed(
      tl::tvm_error_device_type_mismatch,
      [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
        ICHECK(args.size() == 4)
            << "__tvm_error_device_type_mismatch(kernel, buffer, expect, got)";
        auto kernel = args[0].cast<tvm::ffi::String>();
        auto buffer = args[1].cast<tvm::ffi::String>();
        int64_t expect = args[2].cast<int64_t>();
        int64_t got = args[3].cast<int64_t>();
        const char *expect_str =
            tvm::runtime::DLDeviceType2Str(static_cast<int>(expect));
        const char *got_str =
            tvm::runtime::DLDeviceType2Str(static_cast<int>(got));
        std::ostringstream os;
        os << "kernel " << std::string(kernel) << " input "
           << std::string(buffer) << " device_type expected " << expect_str
           << ", but got " << got_str;
        TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
        *ret = -1;
        throw ::tvm::ffi::EnvErrorAlreadySet();
      });
  // kernel, buffer, field:String
  // Reports an unexpected NULL pointer in the named buffer field.
  refl::GlobalDef().def_packed(
      tl::tvm_error_null_ptr,
      [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
        ICHECK(args.size() == 3)
            << "__tvm_error_null_ptr(kernel, buffer, field)";
        auto kernel = args[0].cast<tvm::ffi::String>();
        auto buffer = args[1].cast<tvm::ffi::String>();
        auto field = args[2].cast<tvm::ffi::String>();
        std::ostringstream os;
        os << "kernel " << std::string(kernel) << " input "
           << std::string(buffer) << ' ' << std::string(field)
           << " expected non-NULL, but got NULL";
        TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
        *ret = -1;
        throw ::tvm::ffi::EnvErrorAlreadySet();
      });
  // kernel, buffer, field:String, expect:int64, got:int64
  // Generic equality-expectation failure on a named buffer field.
  refl::GlobalDef().def_packed(
      tl::tvm_error_expect_eq,
      [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
        ICHECK(args.size() == 5)
            << "__tvm_error_expect_eq(kernel, buffer, field, expect, got)";
        auto kernel = args[0].cast<tvm::ffi::String>();
        auto buffer = args[1].cast<tvm::ffi::String>();
        auto field = args[2].cast<tvm::ffi::String>();
        int64_t expect = args[3].cast<int64_t>();
        int64_t got = args[4].cast<int64_t>();
        std::ostringstream os;
        os << "kernel " << std::string(kernel) << " input "
           << std::string(buffer) << ' ' << std::string(field) << " expected "
           << expect << ", but got " << got;
        TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
        *ret = -1;
        throw ::tvm::ffi::EnvErrorAlreadySet();
      });
  // kernel, buffer, field:String [, reason:String]
  // Generic constraint violation; the optional 4th arg appends a reason.
  refl::GlobalDef().def_packed(
      tl::tvm_error_constraint_violation,
      [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
        ICHECK(args.size() == 3 || args.size() == 4)
            << "__tvm_error_constraint_violation(kernel, buffer, field[, "
               "reason])";
        auto kernel = args[0].cast<tvm::ffi::String>();
        auto buffer = args[1].cast<tvm::ffi::String>();
        auto field = args[2].cast<tvm::ffi::String>();
        std::string reason;
        if (args.size() == 4) {
          reason = args[3].cast<tvm::ffi::String>();
        }
        std::ostringstream os;
        os << "kernel " << std::string(kernel) << " input "
           << std::string(buffer) << ' ' << std::string(field)
           << " constraint not satisfied";
        if (!reason.empty()) {
          os << ": " << reason;
        }
        TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
        *ret = -1;
        throw ::tvm::ffi::EnvErrorAlreadySet();
      });
  // Legacy typed registrations for backward compatibility
  refl::GlobalDef().def("tilelang_error_dtype_mismatch",
                        &tvm::tl::DTypeMismatch);
  refl::GlobalDef().def("tilelang_error_dtype_mismatch2",
                        &tvm::tl::DTypeMismatchNoNames);
}
} // namespace tl
} // namespace tvm
/*!
 * \file tl/runtime/error_helpers.h
 * \brief Error helper FFI names for TileLang runtime.
 */
#ifndef TVM_TL_RUNTIME_ERROR_HELPERS_H_
#define TVM_TL_RUNTIME_ERROR_HELPERS_H_

namespace tvm {
namespace tl {

// Error helper packed functions. Each constant is the global registration
// name of a packed function (see error_helpers.cc) that formats a
// descriptive RuntimeError and signals failure to the calling site.
constexpr const char *tvm_error_dtype_mismatch = "__tvm_error_dtype_mismatch";
constexpr const char *tvm_error_ndim_mismatch = "__tvm_error_ndim_mismatch";
constexpr const char *tvm_error_byte_offset_mismatch =
    "__tvm_error_byte_offset_mismatch";
constexpr const char *tvm_error_device_type_mismatch =
    "__tvm_error_device_type_mismatch";
constexpr const char *tvm_error_null_ptr = "__tvm_error_null_ptr";
constexpr const char *tvm_error_expect_eq = "__tvm_error_expect_eq";
constexpr const char *tvm_error_constraint_violation =
    "__tvm_error_constraint_violation";

} // namespace tl
} // namespace tvm

#endif // TVM_TL_RUNTIME_ERROR_HELPERS_H_
...@@ -13,6 +13,12 @@ ...@@ -13,6 +13,12 @@
namespace tvm { namespace tvm {
namespace tl { namespace tl {
// Thread-local storage for restoring the L2 persisting cache limit after a
// stream access-policy window is reset: the "set" helper saves the previous
// CU_LIMIT_PERSISTING_L2_CACHE_SIZE here, the "reset" helper restores it.
// (The dead `#if 1` / `#endif` wrapper previously around these lines removed.)
// NOTE(review): identifiers starting with a double underscore are reserved to
// the implementation in C++; consider renaming (kept as-is so the set/reset
// helpers referencing these names keep compiling).
static thread_local size_t __tl_prev_persisting_l2_cache_size = 0;
static thread_local bool __tl_prev_persisting_l2_cache_saved = false;
#if (CUDA_MAJOR_VERSION >= 12) #if (CUDA_MAJOR_VERSION >= 12)
template <typename T> static std::string ArrayToStr(const T *ptr, size_t n) { template <typename T> static std::string ArrayToStr(const T *ptr, size_t n) {
std::stringstream ss; std::stringstream ss;
...@@ -91,19 +97,21 @@ struct TensorMapArgs { ...@@ -91,19 +97,21 @@ struct TensorMapArgs {
// set device api // set device api
TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed("tvm_tensormap_create_tiled", [](PackedArgs args, // Register using the canonical names defined in runtime.h
Any *ret) { refl::GlobalDef().def_packed(
TensorMapArgs T = TensorMapArgs::Extract(args); tl::tvm_tensormap_create_tiled, [](PackedArgs args, Any *ret) {
CUresult result = cuTensorMapEncodeTiled( TensorMapArgs T = TensorMapArgs::Extract(args);
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim, CUresult result = cuTensorMapEncodeTiled(
T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave, T.swizzle, T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
T.l2Promotion, T.oobFill); T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave,
if (result != CUDA_SUCCESS) { T.swizzle, T.l2Promotion, T.oobFill);
LOG_FATAL << "Failed to initialize the TMA descriptor " << result << '\n' if (result != CUDA_SUCCESS) {
<< T.ToDebugString(); LOG_FATAL << "Failed to initialize the TMA descriptor " << result
} << '\n'
*ret = static_cast<int>(result); << T.ToDebugString();
}); }
*ret = static_cast<int>(result);
});
} }
struct TensorMapIm2ColArgs { struct TensorMapIm2ColArgs {
...@@ -183,7 +191,7 @@ struct TensorMapIm2ColArgs { ...@@ -183,7 +191,7 @@ struct TensorMapIm2ColArgs {
TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed( refl::GlobalDef().def_packed(
"tvm_tensormap_create_im2col", [](PackedArgs args, Any *ret) { tl::tvm_tensormap_create_im2col, [](PackedArgs args, Any *ret) {
TensorMapIm2ColArgs T = TensorMapIm2ColArgs::Extract(args); TensorMapIm2ColArgs T = TensorMapIm2ColArgs::Extract(args);
CUresult result = cuTensorMapEncodeIm2col( CUresult result = cuTensorMapEncodeIm2col(
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim, T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
...@@ -201,5 +209,141 @@ TVM_FFI_STATIC_INIT_BLOCK() { ...@@ -201,5 +209,141 @@ TVM_FFI_STATIC_INIT_BLOCK() {
#endif // (CUDA_MAJOR_VERSION >= 12) #endif // (CUDA_MAJOR_VERSION >= 12)
//
// CUDA L2 Persisting Cache Access Policy Window helpers.
// Exposed as TVM FFI packed functions similar to TMA initialization.
//
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  // Set stream access policy window and adjust persisting L2 cache size
  // Args:
  //   [0]: void* base_ptr (required)
  //   [1]: int64 num_bytes (required)
  //   [2]: float hit_ratio (optional, default 0.8)
  //   [3]: void* stream (optional, default 0 => default stream)
  //   [4]: int64 l2_limit_bytes (optional, default = num_bytes)
  // NOTE(review): the previous limit is saved in thread_local storage, so the
  // matching reset call must run on the same thread as the set call.
  refl::GlobalDef().def_packed(
      tl::tvm_cuda_stream_set_access_policy_window,
      [](PackedArgs args, Any *ret) {
        ICHECK(args.size() >= 2) << "Expected at least base_ptr and num_bytes";
        void *base_ptr = args[0].cast<void *>();
        size_t num_bytes = static_cast<size_t>(args[1].cast<int64_t>());
        float hit_ratio = 0.8f;
        if (args.size() >= 3) {
          // Accept double/float
          hit_ratio = static_cast<float>(args[2].cast<double>());
        }
        CUstream stream = nullptr;
        if (args.size() >= 4) {
          stream = reinterpret_cast<CUstream>(args[3].cast<void *>());
        }
        size_t l2_limit_bytes = num_bytes;
        if (args.size() >= 5) {
          l2_limit_bytes = static_cast<size_t>(args[4].cast<int64_t>());
        }
        // Clamp requested limit to device capability
        CUdevice device;
        CUresult result = cuCtxGetDevice(&device);
        if (result != CUDA_SUCCESS) {
          LOG_FATAL << "Failed to get current CUDA device: " << result;
        }
        int max_persisting = 0;
        result = cuDeviceGetAttribute(
            &max_persisting, CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE,
            device);
        if (result != CUDA_SUCCESS) {
          LOG_FATAL << "Failed to query MAX_PERSISTING_L2_CACHE_SIZE: "
                    << result;
        }
        if (max_persisting > 0 &&
            l2_limit_bytes > static_cast<size_t>(max_persisting)) {
          l2_limit_bytes = static_cast<size_t>(max_persisting);
        }
        // Save current limit to restore later
        size_t init_persisting_l2_cache_size = 0;
        result = cuCtxGetLimit(&init_persisting_l2_cache_size,
                               CU_LIMIT_PERSISTING_L2_CACHE_SIZE);
        if (result != CUDA_SUCCESS) {
          LOG_FATAL << "Failed to get current persisting L2 cache size limit: "
                    << result;
        }
        __tl_prev_persisting_l2_cache_size = init_persisting_l2_cache_size;
        __tl_prev_persisting_l2_cache_saved = true;
        // Set new limit
        result =
            cuCtxSetLimit(CU_LIMIT_PERSISTING_L2_CACHE_SIZE, l2_limit_bytes);
        if (result != CUDA_SUCCESS) {
          LOG_FATAL << "Failed to set persisting L2 cache size limit: "
                    << result;
        }
        // Apply access policy window to stream.
        // FIX: the window must cover the caller's buffer (num_bytes), not the
        // L2 cache limit, and it must be clamped against the device's maximum
        // access-policy window size — a separate limit from the persisting-L2
        // capacity clamped above. An oversized window makes
        // cuStreamSetAttribute fail.
        size_t window_bytes = num_bytes;
        int max_window = 0;
        result = cuDeviceGetAttribute(
            &max_window, CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE,
            device);
        if (result != CUDA_SUCCESS) {
          LOG_FATAL << "Failed to query MAX_ACCESS_POLICY_WINDOW_SIZE: "
                    << result;
        }
        if (max_window > 0 && window_bytes > static_cast<size_t>(max_window)) {
          window_bytes = static_cast<size_t>(max_window);
        }
        CUstreamAttrValue stream_attribute;
        memset(&stream_attribute, 0, sizeof(stream_attribute));
        stream_attribute.accessPolicyWindow.base_ptr = base_ptr;
        stream_attribute.accessPolicyWindow.num_bytes = window_bytes;
        stream_attribute.accessPolicyWindow.hitRatio = hit_ratio;
        stream_attribute.accessPolicyWindow.hitProp =
            CU_ACCESS_PROPERTY_PERSISTING;
        stream_attribute.accessPolicyWindow.missProp =
            CU_ACCESS_PROPERTY_STREAMING;
        result = cuStreamSetAttribute(stream,
                                      CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW,
                                      &stream_attribute);
        if (result != CUDA_SUCCESS) {
          LOG_FATAL << "Failed to set stream access policy window: " << result;
        }
        *ret = static_cast<int>(result);
      });
  // Reset stream access policy window and restore the previous L2 cache size
  // Args:
  //   [0]: void* stream (optional, default 0)
  refl::GlobalDef().def_packed(
      tl::tvm_cuda_stream_reset_access_policy_window,
      [](PackedArgs args, Any *ret) {
        CUstream stream = nullptr;
        if (args.size() >= 1) {
          stream = reinterpret_cast<CUstream>(args[0].cast<void *>());
        }
        CUstreamAttrValue stream_attribute;
        memset(&stream_attribute, 0, sizeof(stream_attribute));
        // num_bytes = 0 disables the access policy window on the stream
        stream_attribute.accessPolicyWindow.num_bytes = 0;
        CUresult result = cuStreamSetAttribute(
            stream, CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW,
            &stream_attribute);
        if (result != CUDA_SUCCESS) {
          LOG_FATAL << "Failed to reset stream access policy window: "
                    << result;
        }
        // Evict any lines still marked persisting before restoring the limit.
        result = cuCtxResetPersistingL2Cache();
        if (result != CUDA_SUCCESS) {
          LOG_FATAL << "Failed to reset persisting L2 cache lines: " << result;
        }
        // Restore the limit saved by the matching "set" call on this thread.
        if (__tl_prev_persisting_l2_cache_saved) {
          result = cuCtxSetLimit(CU_LIMIT_PERSISTING_L2_CACHE_SIZE,
                                 __tl_prev_persisting_l2_cache_size);
          if (result != CUDA_SUCCESS) {
            LOG_FATAL << "Failed to restore persisting L2 cache size limit: "
                      << result;
          }
          __tl_prev_persisting_l2_cache_saved = false;
        }
        *ret = static_cast<int>(result);
      });
}
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -16,7 +16,13 @@ constexpr const char *tvm_tensormap_create_tiled = ...@@ -16,7 +16,13 @@ constexpr const char *tvm_tensormap_create_tiled =
constexpr const char *tvm_tensormap_create_im2col = constexpr const char *tvm_tensormap_create_im2col =
"__tvm_tensormap_create_im2col"; "__tvm_tensormap_create_im2col";
#endif // (CUDA_MAJOR_VERSION >= 12) #endif // (CUDA_MAJOR_VERSION >= 12)
// CUDA stream access policy window helpers — canonical FFI names for the
// packed functions that set/reset a stream's L2 persisting-cache window.
// `inline constexpr` (C++17) keeps a single external definition across
// translation units instead of one internal-linkage copy per TU.
inline constexpr const char *tvm_cuda_stream_set_access_policy_window =
    "__tvm_cuda_stream_set_access_policy_window";
inline constexpr const char *tvm_cuda_stream_reset_access_policy_window =
    "__tvm_cuda_stream_reset_access_policy_window";
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
#endif // TVM_TL_RUNTIME_RUNTIME_H_ #endif // TVM_TL_RUNTIME_RUNTIME_H_
\ No newline at end of file
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <tvm/ffi/cast.h> #include <tvm/ffi/cast.h>
#include <tvm/ffi/container/array.h> #include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h> #include <tvm/ffi/container/map.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/memory.h> #include <tvm/ffi/memory.h>
#include <tvm/ffi/optional.h> #include <tvm/ffi/optional.h>
#include <tvm/ffi/string.h> #include <tvm/ffi/string.h>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment