Unverified commit 667632cc, authored by guchaoyang; committed by GitHub
Browse files

Merge branch 'main' into dcu

parents d6dd2ddf a874e4e8
......@@ -23,6 +23,14 @@ public:
int bits) const;
TVM_FFI_DECLARE_OBJECT_INFO("tl.GemmSPWarpPolicy", GemmSPWarpPolicyNode,
GemmWarpPolicyNode);
// Registers the warp-policy fields with the FFI reflection system so they
// are visible (read-only) to serialization and Python-side introspection.
static void RegisterReflection() {
  namespace refl = tvm::ffi::reflection;
  refl::ObjectDef<GemmSPWarpPolicyNode>()
      .def_ro("policy_type", &GemmSPWarpPolicyNode::policy_type)
      .def_ro("m_warp", &GemmSPWarpPolicyNode::m_warp)
      .def_ro("n_warp", &GemmSPWarpPolicyNode::n_warp);
}
};
class GemmSPWarpPolicy : public ObjectRef {
......@@ -53,6 +61,7 @@ public:
class GemmSPNode : public TileOperatorNode {
public:
BufferRegion aRegion_, bRegion_, cRegion_, eRegion_;
tir::Buffer a_, b_, c_, e_;
bool transA_, transB_;
int m_, n_, k_;
......@@ -75,6 +84,10 @@ public:
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GemmSPNode>()
.def_ro("policy", &GemmSPNode::policy_)
.def_ro("aRegion", &GemmSPNode::aRegion_)
.def_ro("bRegion", &GemmSPNode::bRegion_)
.def_ro("cRegion", &GemmSPNode::cRegion_)
.def_ro("eRegion", &GemmSPNode::eRegion_)
.def_ro("a", &GemmSPNode::a_)
.def_ro("b", &GemmSPNode::b_)
.def_ro("c", &GemmSPNode::c_)
......@@ -96,7 +109,7 @@ private:
class GemmSP : public TileOperator {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSP, TileOperator, GemmSPNode);
TVM_DLL GemmSP(Array<PrimExpr> args, BufferMap vmap);
TVM_DLL GemmSP(Array<PrimExpr> args);
static const Op &Get();
};
......
/*!
* \file tl/op/gemm_sp_py.cc
* \brief Implementation of Sparse General Matrix Multiplication (GEMM_SP)
* operators
*/
#include "gemm_sp_py.h"
#include "utils.h"
#include "builtin.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>
#include "../target/utils.h"
#include "tvm/ffi/string.h"
namespace tvm {
namespace tl {
using namespace tir;
/**
 * @brief Construct a GemmSPPy operator from serialized TL arguments.
 *
 * Deserializes operator parameters from `args`, populating an internal
 * GemmSPPyNode with:
 * - BufferRegions (and their backing Buffers) for A, E, B, C,
 * - transpose flags for A, B and E,
 * - matrix dimensions M, N, K,
 * - warp allocation policy and the clear_accum expression,
 * - strides and memory offsets for A and B,
 * - optional kPack (must be 1 or 2) and optional wg_wait.
 *
 * The populated GemmSPPyNode is stored into the wrapper's internal `data_`.
 *
 * @param args Positional serialized arguments produced by the TL frontend:
 *   expected layout is:
 *   [A, E, B, C (region-like exprs), trans_A (Bool), trans_B (Bool),
 *    trans_E (Bool), M (Int), N (Int), K (Int), policy (Int),
 *    clear_accum (PrimExpr), stride_A (Int), stride_B (Int),
 *    offset_A (Int), offset_B (Int),
 *    (optional) kPack (Int), (optional) wg_wait (Int)]
 *
 * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor
 *       fails with an ICHECK (runtime assertion). At least 16 positional
 *       arguments are required.
 */
GemmSPPy::GemmSPPy(Array<PrimExpr> args) {
  // Guard before indexing: args[0..15] are accessed unconditionally below.
  ICHECK(args.size() >= 16) << "gemm_sp_py expects at least 16 arguments, got "
                            << args.size();
  ObjectPtr<GemmSPPyNode> node = tvm::ffi::make_object<GemmSPPyNode>();
  // Operand regions; the underlying Buffers are cached for convenience.
  node->aRegion_ = NormalizeToBufferRegion(args[0]);
  node->eRegion_ = NormalizeToBufferRegion(args[1]);
  node->bRegion_ = NormalizeToBufferRegion(args[2]);
  node->cRegion_ = NormalizeToBufferRegion(args[3]);
  node->A = node->aRegion_->buffer;
  node->E = node->eRegion_->buffer;
  node->B = node->bRegion_->buffer;
  node->C = node->cRegion_->buffer;
  node->trans_A = args[4].as<Bool>().value();
  node->trans_B = args[5].as<Bool>().value();
  node->trans_E = args[6].as<Bool>().value();
  node->M = args[7].as<IntImm>().value()->value;
  node->N = args[8].as<IntImm>().value()->value;
  node->K = args[9].as<IntImm>().value()->value;
  node->policy = GemmWarpPolicy(args[10].as<IntImm>().value()->value);
  node->clear_accum = args[11].as<PrimExpr>().value();
  node->stride_A = args[12].as<IntImm>().value()->value;
  node->stride_B = args[13].as<IntImm>().value()->value;
  node->offset_A = args[14].as<IntImm>().value()->value;
  node->offset_B = args[15].as<IntImm>().value()->value;
  if (args.size() > 16) {
    node->kPack = args[16].as<IntImm>().value()->value;
    if (node->kPack != 1 && node->kPack != 2) {
      ICHECK(false) << "kPack must be 1 or 2";
    }
  }
  if (args.size() > 17) {
    node->wg_wait = args[17].as<IntImm>().value()->value;
  }
  data_ = std::move(node);
}
/**
 * @brief Deep-copy this node and wrap the copy in a GemmSPPy TileOperator.
 *
 * @return TileOperator owning a fresh copy of this node's state.
 */
TileOperator GemmSPPyNode::Clone() const {
  return GemmSPPy(tvm::ffi::make_object<GemmSPPyNode>(*this));
}
/**
 * @brief Select the GEMM instruction family for the given launch and target.
 *
 * WGMMA is chosen only on Hopper when M >= 64, the block uses a multiple of
 * four warps, and CheckWGMMA() passes; otherwise MFMA on CDNA targets and
 * MMA on CUDA targets. Any other target aborts via ICHECK.
 *
 * @param block_size Number of threads in the thread block.
 * @param target Compilation target used to pick the instruction family.
 * @return The selected GemmInst.
 */
GemmInst GemmSPPyNode::GetGemmInst(int block_size, Target target) const {
  const int warp_size = TargetGetWarpSize(target);
  const int num_warps = block_size / warp_size;
  const bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
                           (num_warps % 4 == 0) && CheckWGMMA();
  if (allow_wgmma) {
    return GemmInst::kWGMMA;
  }
  if (TargetIsCDNA(target)) {
    return GemmInst::kMFMA;
  }
  if (TargetIsCuda(target)) {
    return GemmInst::kMMA;
  }
  ICHECK(0) << "Unsupported target for gemm: " << target->str();
  // Unreachable: ICHECK aborts above. The return silences "control reaches
  // end of non-void function" and avoids UB should the check be compiled out.
  return GemmInst::kMMA;
}
/**
 * @brief Checks whether WGMMA (warp-group MMA) can be used for this GEMM.
 *
 * Currently the WGMMA path is NOT supported for sparse GEMM, so this always
 * returns false. The commented-out logic below documents the intended
 * constraints (B resident in shared memory, supported (C, A, B) dtype
 * combinations, required transpose orientation, and K alignment) and should
 * be restored and re-validated when the path is enabled.
 *
 * @return false — WGMMA is unsupported for gemm_sp_py at present.
 */
bool GemmSPPyNode::CheckWGMMA() const {
  return false; // not supported yet
  // if (B.scope() != "shared.dyn" && B.scope() != "shared") {
  //   return false;
  // }
  // if (C->dtype == DataType::Float(16)) {
  //   if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
  //     return K % 16 == 0;
  //   else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else
  //     return false;
  // } else if (C->dtype == DataType::Float(32)) {
  //   if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
  //     return K % 16 == 0;
  //   else if (A->dtype == DataType::BFloat(16) &&
  //            B->dtype == DataType::BFloat(16))
  //     return K % 16 == 0;
  //   else if (A->dtype == DataType::Float(32) && B->dtype ==
  //            DataType::Float(32))
  //     return (!trans_A) && trans_B && K % 8 == 0;
  //   else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e4m3())
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else if (A->dtype.is_float8_e4m3() && B->dtype.is_float8_e5m2())
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e4m3())
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else if (A->dtype.is_float8_e5m2() && B->dtype.is_float8_e5m2())
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else
  //     return false;
  // } else if (C->dtype == DataType::Int(32)) {
  //   if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
  //     return (!trans_A) && trans_B && K % 32 == 0;
  //   else
  //     return false;
  // } else {
  //   return false;
  // }
}
/**
 * @brief Parse the numeric GPU architecture from a Target's "arch" attribute.
 *
 * The target must carry an "arch" attribute (enforced via ICHECK). When the
 * attribute matches "sm_<num>", <num> is returned (e.g. 80 for "sm_80");
 * any other arch string yields 0.
 *
 * @return int Parsed architecture number, or 0 for non-"sm_" arch strings.
 */
static int GetArchInt(Target target) {
  auto arch_attr = target->GetAttr<String>("arch");
  ICHECK(arch_attr.has_value());
  const std::string arch_str = arch_attr.value();
  if (arch_str.rfind("sm_", 0) != 0) {
    return 0;
  }
  return std::stoi(arch_str.substr(3));
}
// Lower the sparse GEMM by delegating to the Python-registered global
// function "tl.gemm_sp_py.lower", which returns a PrimFunc whose body is
// re-wrapped as a BlockRealize named after the PrimFunc's global_symbol.
Stmt GemmSPPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  auto block_size = *as_const_int(T.thread_bounds->extent);
  GemmInst gemm_inst = GetGemmInst(block_size, T.target);
  // NOTE(review): warp_m/warp_n are unused below; presumably
  // computeWarpPartition caches the partition on the mutable `policy` —
  // confirm before removing this call.
  auto [warp_m, warp_n] =
      policy->computeWarpPartition(M, N, block_size, T.target, gemm_inst);
  if (const auto f = ffi::Function::GetGlobal("tl.gemm_sp_py.lower")) {
    auto prim_func =
        Downcast<PrimFunc>((*f)(tvm::ffi::GetRef<GemmSPPy>(this), T.target,
                                T.thread_bounds, T.thread_var));
    ICHECK(prim_func->attrs.defined());
    auto global_symbol = prim_func->attrs.GetAttr<String>("global_symbol");
    ICHECK(global_symbol.has_value());
    if (prim_func->body.as<BlockRealizeNode>()) {
      // Body is already a BlockRealize: just rename its block in place.
      BlockRealize block_realize = Downcast<BlockRealize>(prim_func->body);
      auto block = block_realize->block;
      {
        BlockNode *n = block.CopyOnWrite();
        n->name_hint = global_symbol.value();
      }
      return BlockRealize(block_realize->iter_values, block_realize->predicate,
                          block);
    }
    // Otherwise wrap the body with a fresh BlockRealize node.
    return BlockRealize(
        /*iter_values=*/Array<PrimExpr>(),
        /*predicate=*/const_true(),
        /*block=*/
        Block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
              /*name_hint=*/global_symbol.value(), prim_func->body));
  } else {
    // LOG(FATAL) aborts, so falling off the end here is unreachable.
    LOG(FATAL) << "No lower function found for gemm_sp_py";
  }
}
/**
 * @brief Infer operand layouts by delegating to the Python-registered global
 *        function "tl.gemm_sp_py.infer_layout".
 *
 * Inference runs at most once per node: after the first successful call,
 * `completed_` is set and subsequent calls return an empty map.
 */
LayoutMap GemmSPPyNode::InferLayout(const LayoutInferArgs &T,
                                    InferLevel level) const {
  if (completed_) {
    return {};
  }
  const auto infer_fn = ffi::Function::GetGlobal("tl.gemm_sp_py.infer_layout");
  if (!infer_fn) {
    LOG(FATAL) << "No infer layout function found for gemm_sp_py";
  }
  LayoutMap inferred = Downcast<LayoutMap>(
      (*infer_fn)(tvm::ffi::GetRef<GemmSPPy>(this), T.target, T.thread_bounds));
  completed_ = true;
  return inferred;
}
// Register the tile operator under "tl.tileop.gemm_sp_py" with opaque call
// effect (the op reads/writes buffers through regions, not pure values).
// NOTE(review): set_num_inputs(5) — presumably the four operand regions plus
// one more; confirm against the frontend call convention, since the
// constructor consumes 16-18 positional args.
TIR_REGISTER_TL_TILE_OP(GemmSPPy, gemm_sp_py)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

// Hook reflection registration into FFI static initialization.
TVM_FFI_STATIC_INIT_BLOCK() { GemmSPPyNode::RegisterReflection(); }
} // namespace tl
} // namespace tvm
/*!
* \file tl/op/gemm_sp_py.h
* \brief Define gemm_sp_py operator.
*
*/
// TODO: @botbw: remove redundant code with gemm_py.h
#ifndef TVM_TL_OP_GEMM_SP_PY_H_
#define TVM_TL_OP_GEMM_SP_PY_H_
#include "gemm_sp.h"
#include "operator.h"
namespace tvm {
namespace tl {
using namespace tir;
// Node payload for the Python-lowered sparse GEMM tile operator.
// Lowering and layout inference are delegated to globally registered
// Python functions ("tl.gemm_sp_py.lower" / "tl.gemm_sp_py.infer_layout").
class GemmSPPyNode : public TileOperatorNode {
public:
  bool CheckWGMMA() const;
  // Operand buffers: A, B (inputs), C (accumulator), E (presumably the
  // sparsity metadata for sparse GEMM — TODO confirm).
  tir::Buffer A, E, B, C;
  // pointer to the A, E, B, C
  BufferRegion aRegion_, eRegion_, bRegion_, cRegion_;
  // Transpose flags for A, B and the metadata E.
  bool trans_A, trans_B, trans_E;
  // Problem shape.
  int M, N, K;
  // Leading-dimension strides of A and B.
  int stride_A, stride_B;
  // Element offsets into A and B.
  int offset_A, offset_B;
  // Whether to clear the accumulator before the GEMM (as an expression).
  PrimExpr clear_accum = const_false();
  // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
  // only will be enabled under cdna mfma instructions
  int kPack = 1;
  int wg_wait = 0;
  // use GemmWarp Policy here as the atom size are flexible in v2
  mutable GemmWarpPolicy policy;
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.GemmSPPy", GemmSPPyNode,
                                    TileOperatorNode);
  // Expose all fields (read-only) to the FFI reflection system.
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<GemmSPPyNode>()
        .def_ro("A", &GemmSPPyNode::A)
        .def_ro("E", &GemmSPPyNode::E)
        .def_ro("B", &GemmSPPyNode::B)
        .def_ro("C", &GemmSPPyNode::C)
        .def_ro("aRegion", &GemmSPPyNode::aRegion_)
        .def_ro("eRegion", &GemmSPPyNode::eRegion_)
        .def_ro("bRegion", &GemmSPPyNode::bRegion_)
        .def_ro("cRegion", &GemmSPPyNode::cRegion_)
        .def_ro("trans_A", &GemmSPPyNode::trans_A)
        .def_ro("trans_B", &GemmSPPyNode::trans_B)
        .def_ro("trans_E", &GemmSPPyNode::trans_E)
        .def_ro("M", &GemmSPPyNode::M)
        .def_ro("N", &GemmSPPyNode::N)
        .def_ro("K", &GemmSPPyNode::K)
        .def_ro("stride_A", &GemmSPPyNode::stride_A)
        .def_ro("stride_B", &GemmSPPyNode::stride_B)
        .def_ro("offset_A", &GemmSPPyNode::offset_A)
        .def_ro("offset_B", &GemmSPPyNode::offset_B)
        .def_ro("clear_accum", &GemmSPPyNode::clear_accum)
        .def_ro("kPack", &GemmSPPyNode::kPack)
        .def_ro("wg_wait", &GemmSPPyNode::wg_wait)
        .def_ro("policy", &GemmSPPyNode::policy);
  }
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
  LayoutMap InferLayout(const LayoutInferArgs &T,
                        InferLevel level) const override;
  TileOperator Clone() const;

private:
  // Target GEMM instruction
  GemmInst GetGemmInst(int block_size, Target target) const;
  // Set once InferLayout has run; later calls become no-ops.
  mutable bool completed_ = false;
};
// Reference wrapper for GemmSPPyNode; constructed from the serialized
// positional arguments produced by the TL frontend.
class GemmSPPy : public TileOperator {
public:
  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSPPy, TileOperator,
                                             GemmSPPyNode);
  TVM_DLL GemmSPPy(Array<PrimExpr> args);
  static const Op &Get();
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_GEMM_SP_PY_H_
\ No newline at end of file
......@@ -33,6 +33,7 @@ TVM_REGISTER_OP("tl.pow_of_int")
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure))
.set_attr<TScriptPrinterName>("TScriptPrinterName", "pow_of_int")
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic", pow_of_int_op)
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", pow_of_int_op);
PrimExpr infinity_op(PrimExpr args) {
......@@ -59,7 +60,8 @@ TVM_REGISTER_OP("tl.infinity")
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure))
.set_attr<TScriptPrinterName>("TScriptPrinterName", "infinity")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", infinity_op);
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", infinity_op)
.set_attr<FLowerIntrinsic>("hip.FLowerIntrinsic", infinity_op);
} // namespace tl
} // namespace tvm
......@@ -24,16 +24,14 @@ using namespace tir;
*
* @param call The TIR Call whose operator and arguments will be used to build
* the TileOperator.
* @param vmap Buffer mapping passed through to the builder to resolve buffer
* references.
* @return TileOperator The constructed TileOperator, or a default (empty)
* TileOperator if no builder exists.
*/
TileOperator ParseOperator(Call call, BufferMap vmap) {
TileOperator ParseOperator(Call call) {
auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder");
Op op = call->op.as<Op>().value();
if (op_map.count(op)) {
auto tile_op = op_map[op](call->args, vmap);
auto tile_op = op_map[op](call->args);
ICHECK(tile_op.defined());
return tile_op;
}
......@@ -48,14 +46,13 @@ TileOperator ParseOperator(Call call, BufferMap vmap) {
* Otherwise returns a default-constructed (empty) TileOperator.
*
* @param stmt TIR statement to inspect; expected to be an Evaluate of a Call.
* @param vmap Mapping of buffer variables used when building the operator.
* @return TileOperator Parsed operator on success, or a default (empty)
* TileOperator if `stmt` is not an Evaluate(Call).
*/
TileOperator ParseOperator(Stmt stmt, BufferMap vmap) {
TileOperator ParseOperator(Stmt stmt) {
if (stmt.as<Evaluate>() && stmt.as<EvaluateNode>()->value.as<CallNode>()) {
auto call = stmt.as<EvaluateNode>()->value.as<CallNode>();
return ParseOperator(tvm::ffi::GetRef<Call>(call), vmap);
return ParseOperator(tvm::ffi::GetRef<Call>(call));
}
return TileOperator();
}
......
......@@ -39,6 +39,9 @@ struct LowerArgs {
AddWorkspaceCallback AddWorkspace;
LayoutMap layout_map;
Map<Buffer, Buffer> buffer_remap;
// Map from LetStmt variable to its bound expression, for resolving
// fragment buffer accesses through let bindings
Map<Var, PrimExpr> let_var_to_expr;
};
struct LayoutInferArgs {
......@@ -48,6 +51,9 @@ struct LayoutInferArgs {
arith::Analyzer *analyzer;
bool buffer_oob = false;
Map<Buffer, Buffer> buffer_remap;
// Map from LetStmt variable to its bound expression, for resolving
// fragment buffer accesses through let bindings
Map<Var, PrimExpr> let_var_to_expr;
};
class TileOperator;
......@@ -72,23 +78,20 @@ public:
Var GetVarFromAccessPtr(const PrimExpr &expr);
TileOperator ParseOperator(Call call, BufferMap vmap);
TileOperator ParseOperator(Stmt stmt, BufferMap vmap);
TileOperator ParseOperator(Call call);
TileOperator ParseOperator(Stmt stmt);
using OpBuilderFunc =
ffi::TypedFunction<TileOperator(Array<PrimExpr>, BufferMap)>;
using OpBuilderFunc = ffi::TypedFunction<TileOperator(Array<PrimExpr>)>;
#define TIR_REGISTER_TL_OP(Entry, OpName) \
#define TIR_REGISTER_TL_TILE_OP(Entry, OpName) \
const Op &Entry::Get() { \
static const Op &op = Op::Get("tl." #OpName); \
static const Op &op = Op::Get("tl.tileop." #OpName); \
return op; \
} \
TVM_REGISTER_OP("tl." #OpName) \
TVM_REGISTER_OP("tl.tileop." #OpName) \
.set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName) \
.set_attr<OpBuilderFunc>("TLOpBuilder", \
[](Array<PrimExpr> args, BufferMap vmap) { \
return Entry(args, vmap); \
})
.set_attr<OpBuilderFunc>( \
"TLOpBuilder", [](Array<PrimExpr> args) { return Entry(args); })
} // namespace tl
} // namespace tvm
......
......@@ -182,6 +182,34 @@ TileOperator ParallelOpNode::Clone() const {
return ParallelOp(op);
}
// Walk LetStmt bindings transitively to discover fragment-buffer accesses
// hidden behind let variables (e.g. `a = frag[i]; T.copy(A[a, 0], ...)`),
// and record their index patterns into indice_map_.
void ParallelOpNode::ExpandLetBindings(
    const Map<Var, PrimExpr> &let_var_to_expr) {
  if (let_var_to_expr.empty())
    return;
  // Helper function to recursively find BufferLoads through let bindings
  std::function<void(const PrimExpr &)> expand = [&](const PrimExpr &expr) {
    PostOrderVisit(expr, [&](const ObjectRef &node) {
      if (auto bl = node.as<BufferLoadNode>()) {
        // Record the first index pattern seen per fragment buffer; later
        // occurrences are ignored.
        if (bl->buffer.scope() == "local.fragment" &&
            !indice_map_.count(bl->buffer)) {
          indice_map_.Set(bl->buffer, bl->indices);
        }
      } else if (auto var_node = node.as<VarNode>()) {
        // Chase vars bound by enclosing lets. NOTE(review): assumes the
        // binding graph is acyclic (TIR lets are) — a cycle would recurse
        // forever.
        auto var = tvm::ffi::GetRef<Var>(var_node);
        if (let_var_to_expr.count(var)) {
          expand(let_var_to_expr[var]);
        }
      }
    });
  };
  // Scan all let bindings
  for (const auto &[var, expr] : let_var_to_expr) {
    expand(expr);
  }
}
Stmt ParallelOpNode::Lower(const LowerArgs &T,
arith::Analyzer *analyzer) const {
return root_;
......@@ -214,6 +242,12 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (loop_layout_.defined())
return {};
// Expand let bindings to find fragment buffer accesses
if (!T.let_var_to_expr.empty()) {
const_cast<ParallelOpNode *>(this)->ExpandLetBindings(T.let_var_to_expr);
}
if (level == InferLevel::kStrict) {
LayoutMap results;
// Deduce buffers that should be complicated replicated.
......@@ -252,17 +286,18 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
forward_vars.push_back(
IterVar(Range(0, s), Var(), IterVarType::kDataPar));
}
Array<PrimExpr> forward_index;
for (const auto &iv : forward_vars) {
forward_index.push_back(iv->var);
}
Var rep;
auto rep_iter =
IterVar({0, T.thread_bounds->extent}, rep, IterVarType::kDataPar);
// Use default fragment indexing (single output dim) to
// stay consistent with other ops (e.g., ReduceOp), and
// bind the thread range for comparability.
const PrimExpr &forward_thread = rep;
results.Set(buffer, Fragment(forward_vars, forward_index,
forward_thread, rep_iter));
auto frag = Fragment(forward_vars, /*forward_index=*/{}, forward_thread,
rep_iter)
->BindThreadRange(T.thread_bounds);
results.Set(buffer, frag);
}
}
return results;
......@@ -452,8 +487,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
// As the pass will do post processing to the layout
auto maybe_remapped_root_ =
IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
int vector_size = GetVectorizeSize(maybe_remapped_root_);
int vector_size = GetVectorizeSize(maybe_remapped_root_, T.analyzer);
DLOG(INFO) << "[PlanLoopPartition] vector_size = " << vector_size << '\n';
PrimExpr loop_total_size = 1;
......@@ -562,6 +596,16 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
} else {
return {};
}
// check loop_layout_ is injective
auto injective_res = loop_layout_->DetectInjective();
if (!injective_res->errors.empty()) {
std::ostringstream oss;
oss << "Loop layout is not injective: " << loop_layout_->DebugOutput()
<< '\n'
<< " errors: " << injective_res->errors << '\n'
<< " loop AST: " << root_;
throw LoopLayoutInjectiveException(oss.str());
}
PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();
......
......@@ -24,15 +24,6 @@ namespace tl {
using namespace tir;
// Exception carrying a human-readable description of a conflict detected
// during layout inference.
class LayoutConflictException : public std::exception {
public:
  LayoutConflictException(const std::string &msg) : msg_(msg) {}
  // Returns the stored conflict description.
  const char *what() const noexcept override { return msg_.c_str(); }

private:
  std::string msg_;
};
bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
Array<PrimExpr> small_frag_indices,
Array<PrimExpr> large_frag_indices,
......@@ -114,6 +105,10 @@ private:
void AddPredicate(const PrimExpr &expr) const {
predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
}
// Expand let bindings to find fragment buffer accesses and add them to
// indice_map_. This handles cases like: a = block_mask_f[i]; T.copy(A[a, 0],
// ...)
void ExpandLetBindings(const Map<Var, PrimExpr> &let_var_to_expr);
// Allow ParallelLoopNestVisitor to access private members.
friend class ParallelLoopNestVisitor;
......
......@@ -14,60 +14,24 @@
#include "../op/parallel.h"
#include "../target/utils.h"
#include "../transform/loop_partition.h"
#include "region.h"
#include "tir/transforms/ir_utils.h"
#include "tvm/tir/stmt.h"
#include "utils.h"
namespace tvm {
namespace tl {
using namespace tir;
// Normalize an argument (BufferRegion/BufferLoad/tl.region)
// to BufferRegion so Reduce can uniformly consume regions.
static BufferRegion NormalizeToBufferRegion(const PrimExpr &arg,
const BufferMap &vmap) {
// Case 1: Already a BufferRegion
if (arg->IsInstance<BufferRegionNode>()) {
return Downcast<BufferRegion>(arg);
}
// NormalizeToBufferRegion moved to src/op/utils.{h,cc}
// Case 2: BufferLoad — convert indices to ranges (Ramp -> lanes, else
// extent=1)
if (const auto *load = arg.as<BufferLoadNode>()) {
Array<Range> ranges;
for (const PrimExpr &index : load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
ICHECK(ramp->stride.as<IntImmNode>()) << "Ramp stride must be IntImm";
ICHECK_EQ(ramp->stride.as<IntImmNode>()->value, 1)
<< "Only stride-1 Ramp is supported in region conversion";
ICHECK(ramp->lanes.as<IntImmNode>())
<< "Scalable vector lanes not supported in region conversion";
ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else {
ranges.push_back(Range::FromMinExtent(index, 1));
}
}
return BufferRegion(load->buffer, ranges);
}
// MakeAccessPtrFromRegion moved to src/op/utils.{h,cc}
// Case 3: Call nodes (only tl.region)
if (const auto *call = arg.as<CallNode>()) {
// tl.region(...) — reconstruct via RegionOp
if (call->op.same_as(RegionOp::Get())) {
RegionOp region(call->args, vmap);
return BufferRegion(region->GetBuffer(), region->GetRanges());
}
}
LOG(FATAL) << "Unsupported argument for BufferRegion in reduce: " << arg;
throw; // Unreachable
}
ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
ReduceOp::ReduceOp(Array<PrimExpr> args) {
ObjectPtr<ReduceOpNode> node = tvm::ffi::make_object<ReduceOpNode>();
// Accept BufferRegion/BufferLoad/tl.region for src/dst
node->srcRegion_ = NormalizeToBufferRegion(args[0], vmap);
node->dstRegion_ = NormalizeToBufferRegion(args[1], vmap);
// Accept BufferRegion/BufferLoad for src/dst
node->srcRegion_ = NormalizeToBufferRegion(args[0]);
node->dstRegion_ = NormalizeToBufferRegion(args[1]);
node->src = node->srcRegion_->buffer;
node->dst = node->dstRegion_->buffer;
std::string reduce_type = args[2].as<StringImm>().value()->value;
......@@ -231,6 +195,7 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto dst_scope = this->dst.scope();
if (src_scope == "local.fragment" && dst_scope == "local.fragment") {
Buffer src_buffer = get_buffer(this->src);
Buffer dst_buffer = get_buffer(this->dst);
Fragment src_layout = T.layout_map[this->src].as<Fragment>().value();
......@@ -513,12 +478,22 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
return {};
}
TIR_REGISTER_TL_OP(ReduceOp, reduce)
TIR_REGISTER_TL_TILE_OP(ReduceOp, reduce)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
// Normalize "Buffer" to BufferRegion. Use the shape of the buffer as the
// ranges.
// Wrap a whole Buffer as a BufferRegion spanning [0, extent) on every axis,
// so callers can treat plain buffers and explicit regions uniformly.
static BufferRegion ConvertBufferToBufferRegion(const Buffer &buf) {
  Array<Range> full_ranges;
  for (const PrimExpr &extent : buf->shape) {
    full_ranges.push_back(Range(IntImm(extent->dtype, 0), extent));
  }
  return BufferRegion(buf, full_ranges);
}
/**
 * @brief Construct a CumSum operator from serialized TL arguments.
 *
 * Argument layout:
 * - args[0]: src — input buffer/region
 * - args[1]: dst — output buffer/region
 * - args[2]: dim — dimension along which to accumulate (Int)
 * - args[3]: reverse — whether to cumsum in reverse order (Bool)
 *
 * @note `dim` must satisfy 0 <= dim < src rank; both bounds are checked
 *       (the original only checked the upper bound, letting a negative dim
 *       slip through).
 */
CumSumOp::CumSumOp(Array<PrimExpr> args) {
  CHECK_EQ(args.size(), 4);
  ObjectPtr<CumSumOpNode> node = tvm::ffi::make_object<CumSumOpNode>();
  node->srcRegion_ = NormalizeToBufferRegion(args[0]);
  node->dstRegion_ = NormalizeToBufferRegion(args[1]);
  node->src = node->srcRegion_->buffer;
  node->dst = node->dstRegion_->buffer;
  node->dim = args[2].as<IntImm>().value()->value;
  node->reverse = args[3].as<Bool>().value();
  CHECK_GE(node->dim, 0)
      << "The dim of cumsum must be non-negative. Got dim=" << node->dim << ".";
  CHECK_LT(node->dim, static_cast<int>(node->src->shape.size()))
      << "The dim of cumsum should be less than the number of dimensions. Got "
         "dim="
      << node->dim << ", but src has " << node->src->shape.size() << " dims.";
  data_ = std::move(node);
}
......@@ -545,19 +528,29 @@ Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
std::stringstream ss;
auto threads = T.thread_bounds->extent;
Array<PrimExpr> args;
int ndim = static_cast<int>(src->shape.size());
// Build access pointers from regions locally
PrimExpr srcPtr = MakeAccessPtrFromRegion(srcRegion_, 1);
PrimExpr dstPtr = MakeAccessPtrFromRegion(dstRegion_, 2);
// Use region extents instead of buffer shape for correct slice handling
Array<PrimExpr> src_extents;
for (const auto &range : srcRegion_->region) {
src_extents.push_back(range->extent);
}
int ndim = static_cast<int>(src_extents.size());
if (ndim == 1) {
ICHECK_EQ(dim, 0) << "Cumulative sum over a 1D buffer only supports dim "
"= 0.";
ss << "tl::CumSum1D<" << threads << ", " << (reverse ? "true" : "false")
<< ">::run";
args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3),
src->shape[0]};
args = {StringImm(ss.str()), srcPtr, dstPtr, src_extents[0]};
} else if (ndim == 2) {
ss << "tl::CumSum2D<" << threads << ", " << dim << ", "
<< (reverse ? "true" : "false") << ">::run";
args = {StringImm(ss.str()), src.access_ptr(1), dst.access_ptr(3),
src->shape[0], src->shape[1]};
args = {StringImm(ss.str()), srcPtr, dstPtr, src_extents[0],
src_extents[1]};
} else {
LOG(FATAL) << "CumSum currently supports only 1D or 2D buffers, got "
<< ndim << "D.";
......@@ -576,7 +569,7 @@ LayoutMap CumSumOpNode::InferLayout(const LayoutInferArgs &T,
return {};
}
TIR_REGISTER_TL_OP(CumSumOp, cumsum)
TIR_REGISTER_TL_TILE_OP(CumSumOp, cumsum)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
......
......@@ -125,7 +125,7 @@ class ReduceOp : public TileOperator {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReduceOp, TileOperator,
ReduceOpNode);
TVM_DLL ReduceOp(Array<PrimExpr> args, BufferMap vmap);
TVM_DLL ReduceOp(Array<PrimExpr> args);
static const Op &Get();
};
......@@ -133,8 +133,10 @@ public:
class CumSumOpNode : public TileOperatorNode {
public:
tir::Buffer src, dst; ///< Source and destination buffers
int dim; ///< Dimension along which to compute cumulative sum
bool reverse; ///< Whether to compute in reverse order
// Optional: keep the original regions used to construct this op
BufferRegion srcRegion_, dstRegion_;
int dim; ///< Dimension along which to compute cumulative sum
bool reverse; ///< Whether to compute in reverse order
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.CumSumOp", CumSumOpNode,
TileOperatorNode);
......@@ -143,6 +145,8 @@ public:
refl::ObjectDef<CumSumOpNode>()
.def_ro("src", &CumSumOpNode::src)
.def_ro("dst", &CumSumOpNode::dst)
.def_ro("srcRegion", &CumSumOpNode::srcRegion_)
.def_ro("dstRegion", &CumSumOpNode::dstRegion_)
.def_ro("dim", &CumSumOpNode::dim)
.def_ro("reverse", &CumSumOpNode::reverse);
}
......@@ -159,7 +163,7 @@ class CumSumOp : public TileOperator {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CumSumOp, TileOperator,
CumSumOpNode);
TVM_DLL CumSumOp(Array<PrimExpr> args, BufferMap vmap);
TVM_DLL CumSumOp(Array<PrimExpr> args);
static const Op &Get();
};
......
/*!
* \file tl/op/region.cc
* \brief Define region operator.
* \brief Define region operator (bridge to carry BufferRegion via Call args).
*
* Notes:
* - BufferLoad/Ramp cannot represent a general PrimExpr as a vector lane
* count. Dynamic extents like (H1 - H0) cannot be encoded as
* Ramp(lanes = H1 - H0), and lowering BufferRegion to BufferLoad loses the
* explicit extent information.
* - tl.region carries both mins and extents in Call args and lets the backend
* reconstruct a BufferRegion faithfully.
*/
#include "region.h"
......@@ -11,27 +18,7 @@ namespace tvm {
namespace tl {
using namespace tir;
/**
* @brief Construct a RegionOp from TL operator arguments.
*
* Parses the TL `region` operator call arguments to populate the RegionOpNode:
* - Expects args[0] to be a `BufferLoad` whose `indices` are the per-dimension
* minima.
* - args[1] must be a constant integer used as the access mask.
* - args[2 + i] provides the extent for dimension `i`.
*
* The constructor validates that the number of load indices equals `args.size()
* - 2` and will abort via ICHECK on mismatch or if args[0] is not a
* `BufferLoad`.
*
* Parameters:
* - args: TL operator call arguments in the form
* [BufferLoad(min_i...), access_mask, extent_0, extent_1, ...,
* extent_{n-1}] where n = number of dimensions.
* - vmap: BufferMap passed through by the caller (not documented here as a
* generic utility).
*/
RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
RegionOp::RegionOp(Array<PrimExpr> args) {
size_t n = args.size();
size_t ndim = n - 2;
auto load = args[0].as<BufferLoadNode>();
......@@ -39,10 +26,24 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
ICHECK(load->indices.size() == ndim)
<< "load->indices.size() = " << load->indices << " ndim = " << ndim;
Array<Range> ranges;
// Rebuild per-axis ranges from mins (BufferLoad indices) and provided extents
for (size_t i = 0; i < ndim; i++) {
PrimExpr min = load->indices[i];
PrimExpr index = load->indices[i];
PrimExpr extent = args[2 + i];
ranges.push_back(Range::FromMinExtent(min, extent));
if (const auto *ramp = index.as<RampNode>()) {
const auto *stride_imm = ramp->stride.as<IntImmNode>();
ICHECK(stride_imm && stride_imm->value == 1)
<< "RegionOp expects stride-1 Ramp for index";
if (const auto *lanes_imm = ramp->lanes.as<IntImmNode>()) {
if (const auto *ext_imm = extent.as<IntImmNode>()) {
ICHECK_EQ(lanes_imm->value, ext_imm->value)
<< "Ramp lanes and provided extent must match";
}
}
ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else {
ranges.push_back(Range::FromMinExtent(index, extent));
}
}
ObjectPtr<RegionOpNode> node = tvm::ffi::make_object<RegionOpNode>();
node->buffer_ = load->buffer;
......@@ -51,26 +52,11 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
data_ = std::move(node);
}
/**
 * \brief Duplicate this RegionOpNode into a fresh TileOperator handle.
 *
 * Allocates a new RegionOpNode as a copy of *this, so the returned
 * operator shares no node state with the original.
 *
 * @return TileOperator owning the copied RegionOpNode.
 */
TileOperator RegionOpNode::Clone() const {
  ObjectPtr<RegionOpNode> copied = tvm::ffi::make_object<RegionOpNode>(*this);
  return RegionOp(std::move(copied));
}
/**
* @brief Check whether the region spans the entire underlying buffer.
*
* Returns true if for every dimension the range minimum is zero and the
* range extent is structurally equal to the corresponding buffer shape
* dimension. Otherwise returns false.
*
* @return true if the region covers the full buffer in all dimensions; false
* otherwise.
*/
bool RegionOpNode::IsFullRegion() const {
for (size_t i = 0; i < ranges_.size(); i++) {
if (!is_zero(ranges_[i]->min))
......@@ -81,39 +67,16 @@ bool RegionOpNode::IsFullRegion() const {
return true;
}
/**
 * \brief Lower the region operator to a TIR statement.
 *
 * tl.region is a transport-only bridge for argument passing; it carries no
 * runtime semantics, so lowering emits a no-op evaluate statement.
 *
 * @param T Lowering context (unused).
 * @param analyzer Arithmetic analyzer (unused).
 * @return A no-op TIR statement.
 */
Stmt RegionOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  PrimExpr nop = 0;
  return Evaluate(nop);
}
/**
 * \brief Layout inference for the region operator.
 *
 * A region descriptor imposes no layout constraints of its own, so this
 * always yields an empty layout map, irrespective of the arguments or the
 * requested inference level.
 *
 * @param T Layout inference arguments (ignored).
 * @param level Inference granularity level (ignored).
 * @return An empty LayoutMap.
 */
LayoutMap RegionOpNode::InferLayout(const LayoutInferArgs &T,
                                    InferLevel level) const {
  LayoutMap none;
  return none;
}
TIR_REGISTER_TL_OP(RegionOp, region)
TIR_REGISTER_TL_TILE_OP(RegionOp, region)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
......
/*!
* \file tl/op/op.h
* \brief Tile library operations.
* \file tl/op/region.h
* \brief Tile memory region descriptor op (bridge to carry BufferRegion via
* Call args).
*
* Why tl.region instead of passing BufferRegion directly?
*
* - While TIR can represent a BufferRegion, when a BufferRegion is passed as a
* call argument through call_intrin/FFI, the Python->C++ conversion lowers it
* to a BufferLoad(indices). To encode an interval inside indices, the FFI
* typically uses Ramp(base, stride, lanes) to represent a contiguous slice.
* - Ramp(lanes) may only be a constant or vscale*k (scalable vector). A general
* PrimExpr (e.g., H1 - H0) is not allowed as lanes, so dynamic extents would
* make the lowered BufferLoad invalid.
* - Moreover, BufferLoad only carries indices, not per-axis extents. Downstream
* tile operators (e.g., tl.copy, tl.reduce) that require both min and extent
* cannot losslessly recover dynamic extents from a BufferLoad alone.
*
* tl.region is a small transport-only op that solves this:
* - The frontend packs buffer + mins (from BufferLoad.indices) + extents into
* Call args, allowing dynamic extents to be expressed explicitly.
* - The backend (NormalizeToBufferRegion) reconstructs a BufferRegion from the
* tl.region call without losing information.
* - The op itself carries no semantics in Lower/InferLayout and is only used as
* a bridge for argument passing.
*/
#ifndef TVM_TL_OP_REGION_H_
#define TVM_TL_OP_REGION_H_
#include "./operator.h"
#include <tvm/arith/analyzer.h>
#include <tvm/ir/op.h>
#include <tvm/target/target.h>
#include <tvm/tir/buffer.h>
/**
* Tile operator representing a memory region (buffer + ranges) used by TL
* passes.
*
* Encapsulates the target tir::Buffer, the region extents as an Array<Range>,
* and an access mask that indicates permitted or intended accesses for lowering
* and layout inference.
*/
/**
* Lower this RegionOp into a TIR statement representing the region access.
*
* @param T Lowering-time arguments (e.g., loop/build context and value
* mappings).
* @param analyzer Arithmetic analyzer used to simplify and reason about
* expressions.
* @return A tir::Stmt that implements the region access/mutation described by
* this operator.
*/
/**
* Infer the layout mapping for this region operator.
*
* Produces a LayoutMap describing how loop/axis indices map to buffer axes for
* layout-aware scheduling and subsequent operators.
*
* @param T Layout inference arguments (e.g., input layouts and shapes).
* @param level The inference detail level to use.
* @return A LayoutMap describing inferred mappings for the operator.
*/
/**
* Return true when this RegionOp represents the full buffer region (i.e.,
* ranges cover the entire buffer extent).
*/
/**
* Create a shallow copy of this operator as a TileOperator handle.
*
* @return A TileOperator that references a cloned RegionOpNode.
*/
/**
* Construct a RegionOp from call-argument expressions.
*
* @param args Positional expressions used to instantiate the operator
* (semantics depend on how RegionOp is invoked in TL pipelines).
*/
/**
* Return the global Op registration for RegionOp.
*
* @return Reference to the registered tvm::Op describing the RegionOp.
*/
namespace tvm {
namespace tl {
......@@ -80,6 +42,12 @@ public:
Array<Range> ranges_;
int access_mask_;
/*!
* access_mask_ encodes the intended access type when the region is used as
* an argument to tile operators: 1=read, 2=write, 3=read-write. The mask is
* transport metadata only and does not affect lowering.
*/
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.RegionOp", RegionOpNode,
TileOperatorNode);
......@@ -107,8 +75,13 @@ class RegionOp : public TileOperator {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(RegionOp, TileOperator,
RegionOpNode);
TVM_DLL RegionOp(Array<PrimExpr> args, BufferMap vmap);
/*!
* Build a RegionOp from call arguments:
* - args[0]: BufferLoad whose indices are per-axis minima.
* - args[1]: Integer access mask (1=r, 2=w, 3=rw).
* - args[2 + i]: Extent of axis i (supports dynamic PrimExpr).
*/
TVM_DLL RegionOp(Array<PrimExpr> args);
static const Op &Get();
};
......
......@@ -15,16 +15,19 @@ using runtime::DataType;
struct TCGEN5MMAMeta {
int atom_m, atom_n, atom_k;
bool enable_ws, enable_2cta;
};
inline std::pair<bool, TCGEN5MMAMeta>
GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
#define FAIL \
return { false, TCGEN5MMAMeta{0, 0, 0} }
#define SUCCESS(atom_m, atom_n, atom_k) \
return { \
true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \
false, TCGEN5MMAMeta { 0, 0, 0, false, false } \
}
#define SUCCESS(atom_m, atom_n, atom_k, use_ws, use_2cta) \
return { \
true, TCGEN5MMAMeta { atom_m, atom_n, atom_k, use_ws, use_2cta } \
}
std::vector<int> ws_valid_atom_ns = {256, 128, 64};
if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
......@@ -34,39 +37,50 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
if (M % 128 == 0) {
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 16);
SUCCESS(128, atom_n, 16, false, false);
FAIL;
} else if (M % 64 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(64, atom_n, 16);
SUCCESS(64, atom_n, 16, true, false);
FAIL;
} else if (M % 32 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(32, atom_n, 16);
SUCCESS(32, atom_n, 16, true, false);
FAIL;
} else {
FAIL;
}
} else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) &&
(c_dtype.is_float() && c_dtype.bits() == 32)) {
} else if ((ab_dtype.is_float8() || ab_dtype.is_float6_e2m3fn() ||
ab_dtype.is_float6_e3m2fn() || ab_dtype.is_float4_e2m1fn()) &&
((c_dtype.is_float() && c_dtype.bits() == 32) ||
(c_dtype.is_float16() && c_dtype.bits() == 16))) {
if (K % 32 != 0)
FAIL;
if (M % 128 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 32, true, false);
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 32);
SUCCESS(128, atom_n, 32, false, true);
for (int atom_n = 256; atom_n >= 8; atom_n -= 8)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 32, false, false);
FAIL;
} else if (M % 64 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(64, atom_n, 32);
SUCCESS(64, atom_n, 32, true, false);
for (int atom_n = 256; atom_n >= 8; atom_n -= 8)
if (N % atom_n == 0)
SUCCESS(64, atom_n, 32, false, false);
FAIL;
} else if (M % 32 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(32, atom_n, 32);
SUCCESS(32, atom_n, 32, true, false);
FAIL;
} else {
FAIL;
......
/*!
* \file tl/op/utils.cc
* \brief Common utilities implementation for TL ops.
*/
#include "utils.h"
#include <tvm/tir/builtin.h>
namespace tvm {
namespace tl {
using namespace tir;
/**
 * \brief Canonicalize a call argument into a BufferRegion.
 *
 * Accepted forms:
 *  - BufferRegion: returned unchanged.
 *  - BufferLoad: every index becomes a range; a stride-1 Ramp with
 *    constant lanes maps to [base, base + lanes), any other index to a
 *    unit-extent range.
 *  - tl.region(...) call: rebuilt through RegionOp, which carries explicit
 *    per-axis mins and extents.
 * Any other expression aborts with a fatal error.
 */
BufferRegion NormalizeToBufferRegion(const PrimExpr &arg) {
  // Already a region: nothing to do.
  if (arg->IsInstance<BufferRegionNode>()) {
    return Downcast<BufferRegion>(arg);
  }
  // BufferLoad: derive one Range per index.
  if (const auto *load = arg.as<BufferLoadNode>()) {
    Array<Range> axes;
    for (const PrimExpr &idx : load->indices) {
      const auto *ramp = idx.as<RampNode>();
      if (ramp == nullptr) {
        // Scalar index: a single element along this axis.
        axes.push_back(Range::FromMinExtent(idx, 1));
        continue;
      }
      const auto *stride = ramp->stride.as<IntImmNode>();
      ICHECK(stride) << "Ramp stride must be IntImm";
      ICHECK_EQ(stride->value, 1)
          << "Only stride-1 Ramp is supported in region conversion";
      ICHECK(ramp->lanes.as<IntImmNode>())
          << "Scalable vector lanes not supported in region conversion";
      axes.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
    }
    return BufferRegion(load->buffer, axes);
  }
  // tl.region bridge: reconstruct mins/extents losslessly via RegionOp.
  if (const auto *call = arg.as<CallNode>()) {
    if (call->op.same_as(RegionOp::Get())) {
      RegionOp region(call->args);
      return BufferRegion(region->GetBuffer(), region->GetRanges());
    }
    LOG(FATAL) << "Unsupported argument for BufferRegion (expect "
                  "BufferLoad/BufferRegion/tl.region): "
               << arg;
  }
  LOG(FATAL) << "Unsupported argument for BufferRegion: " << arg;
  throw; // Unreachable
}
/**
 * \brief Build a tvm_access_ptr(handle) call describing `region`.
 *
 * For 1-D buffers the pointer covers [min, min + extent) along the sole
 * axis. For buffers with >= 2 dims, the offset accumulates the mins of all
 * but the last two axes using row-major strides, and the extent is the
 * product of the last two axes' extents (the mins of the last two axes are
 * not folded into the offset, matching the documented contract in utils.h).
 *
 * \param region Buffer region to expose.
 * \param rw_mask Access mask forwarded to tvm_access_ptr (1=r, 2=w, 3=rw).
 * \param require_2d When true, fail unless the buffer has >= 2 dims.
 * \return A handle-typed Call to builtin::tvm_access_ptr.
 */
PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region, int rw_mask,
                                 bool require_2d) {
  Buffer buf = region->buffer;
  int ndim = static_cast<int>(buf->shape.size());
  // Guard against scalar (0-d) buffers: the stride/offset math below
  // unconditionally indexes shape[0] and region->region[ndim - 2].
  ICHECK_GE(ndim, 1) << "Expect buffers with at least 1 dim";
  // The per-axis loops assume the region describes every buffer axis.
  ICHECK_EQ(static_cast<int>(region->region.size()), ndim)
      << "Region rank must match buffer rank";
  if (require_2d) {
    ICHECK(ndim >= 2) << "Expect buffers with at least 2 dims";
  }
  PrimExpr offset, extent;
  if (ndim == 1) {
    // 1D: straightforward
    auto axis = region->region[0];
    offset = axis->min;
    extent = axis->extent;
  } else {
    // Compute row-major strides
    std::vector<PrimExpr> strides(ndim);
    PrimExpr one = make_const(buf->shape[0].dtype(), 1);
    PrimExpr cur = one;
    for (int i = ndim - 1; i >= 0; --i) {
      strides[i] = cur;
      cur = cur * buf->shape[i];
    }
    // Offset: sum_{i in [0..ndim-3]} min_i * stride_i
    offset = make_const(buf->shape[0].dtype(), 0);
    for (int i = 0; i < ndim - 2; ++i) {
      offset = offset + region->region[i]->min * strides[i];
    }
    // Extent: last two extents product (elements)
    extent =
        region->region[ndim - 2]->extent * region->region[ndim - 1]->extent;
  }
  // ptype and return handle
  PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
  Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
                           IntImm(DataType::Int(32), rw_mask)};
  return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
}
} // namespace tl
} // namespace tvm
/*!
* \file tl/op/utils.h
* \brief Common utilities for TL ops.
*/
#ifndef TVM_TL_OP_UTILS_H_
#define TVM_TL_OP_UTILS_H_
#include "./operator.h"
#include "region.h"
#include <tvm/tir/buffer.h>
#include <tvm/tir/op.h>
namespace tvm {
namespace tl {
using namespace tir;
// Normalize an argument (BufferRegion/BufferLoad/tl.region)
// to BufferRegion so ops can uniformly consume regions.
// Note: tvm_access_ptr is no longer supported here.
TVM_DLL BufferRegion NormalizeToBufferRegion(const PrimExpr &arg);
// Build a tvm_access_ptr(handle) from a BufferRegion.
// - If `require_2d` is true, checks buffer ndim >= 2.
// - For 1D regions (when allowed), offset=min, extent=extent.
// - For ndim >= 2, offset sums all but last two dims using row-major strides,
// extent is product of the last two extents.
TVM_DLL PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region,
int rw_mask, bool require_2d = false);
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_UTILS_H_
/*
* Helper functions for nicer runtime error messages.
*/
#include "error_helpers.h"
#include <tvm/ffi/c_api.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/device_api.h>
#include <sstream>
#include <string>
namespace tvm {
namespace tl {
// Formats a per-buffer dtype mismatch message, raises it as a RuntimeError
// through the FFI error channel, and returns non-zero so that
// tvm_call_packed sites treat it as failure and return -1.
static int DTypeMismatch(const tvm::ffi::String &kernel_name,
                         const tvm::ffi::String &buffer_name,
                         int64_t actual_code, int64_t actual_bits,
                         int64_t actual_lanes, int64_t expect_code,
                         int64_t expect_bits, int64_t expect_lanes) {
  // Rebuild DataType values from the (code, bits, lanes) triples passed
  // through packed args.
  const auto to_dtype = [](int64_t code, int64_t bits, int64_t lanes) {
    return tvm::runtime::DataType(static_cast<int>(code),
                                  static_cast<int>(bits),
                                  static_cast<int>(lanes));
  };
  std::ostringstream os;
  os << "kernel " << std::string(kernel_name) << " input "
     << std::string(buffer_name) << " dtype expected "
     << to_dtype(expect_code, expect_bits, expect_lanes) << ", but got "
     << to_dtype(actual_code, actual_bits, actual_lanes);
  TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
  return -1;
}
// Variant without names, to avoid passing extra raw strings through packed
// args.
static int DTypeMismatchNoNames(int64_t actual_code, int64_t actual_bits,
int64_t actual_lanes, int64_t expect_code,
int64_t expect_bits, int64_t expect_lanes) {
tvm::runtime::DataType actual(static_cast<int>(actual_code),
static_cast<int>(actual_bits),
static_cast<int>(actual_lanes));
tvm::runtime::DataType expect(static_cast<int>(expect_code),
static_cast<int>(expect_bits),
static_cast<int>(expect_lanes));
std::ostringstream os;
os << "dtype mismatch: expected " << expect << ", but got " << actual;
TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
return -1;
}
// Register packed versions, following the design in runtime.cc
//
// Every handler below follows the same contract: validate the packed
// argument count, format a human-readable message, raise it through
// TVMFFIErrorSetRaisedFromCStr as a RuntimeError, set *ret = -1 for
// completeness, and then throw EnvErrorAlreadySet so the FFI boundary
// propagates the already-raised error to the caller.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  // Packed: __tvm_error_dtype_mismatch(kernel_name, buffer_name,
  //                                    actual_code, actual_bits, actual_lanes,
  //                                    expect_code, expect_bits, expect_lanes)
  refl::GlobalDef().def_packed(
      tl::tvm_error_dtype_mismatch,
      [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
        ICHECK(args.size() == 8) << "Expected 8 args: kernel, buffer, "
                                    "actual_code, actual_bits, actual_lanes, "
                                 << "expect_code, expect_bits, expect_lanes";
        auto kernel_name = args[0].cast<tvm::ffi::String>();
        auto buffer_name = args[1].cast<tvm::ffi::String>();
        // dtype triples are passed as separate int64 fields (code/bits/lanes).
        int64_t actual_code = args[2].cast<int64_t>();
        int64_t actual_bits = args[3].cast<int64_t>();
        int64_t actual_lanes = args[4].cast<int64_t>();
        int64_t expect_code = args[5].cast<int64_t>();
        int64_t expect_bits = args[6].cast<int64_t>();
        int64_t expect_lanes = args[7].cast<int64_t>();
        // Reuse the helper to format the message
        (void)DTypeMismatch(kernel_name, buffer_name, actual_code, actual_bits,
                            actual_lanes, expect_code, expect_bits,
                            expect_lanes);
        // Provide a return value for completeness, then signal the error
        *ret = -1;
        throw ::tvm::ffi::EnvErrorAlreadySet();
      });
  // kernel, buffer, expect:int64, got:int64
  refl::GlobalDef().def_packed(
      tl::tvm_error_ndim_mismatch,
      [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
        ICHECK(args.size() == 4)
            << "__tvm_error_ndim_mismatch(kernel, buffer, expect, got)";
        auto kernel = args[0].cast<tvm::ffi::String>();
        auto buffer = args[1].cast<tvm::ffi::String>();
        int64_t expect = args[2].cast<int64_t>();
        int64_t got = args[3].cast<int64_t>();
        std::ostringstream os;
        os << "kernel " << std::string(kernel) << " input "
           << std::string(buffer) << " ndim expected " << expect << ", but got "
           << got;
        TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
        *ret = -1;
        throw ::tvm::ffi::EnvErrorAlreadySet();
      });
  // kernel, buffer, expect:int64, got:int64
  refl::GlobalDef().def_packed(
      tl::tvm_error_byte_offset_mismatch,
      [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
        ICHECK(args.size() == 4)
            << "__tvm_error_byte_offset_mismatch(kernel, buffer, expect, got)";
        auto kernel = args[0].cast<tvm::ffi::String>();
        auto buffer = args[1].cast<tvm::ffi::String>();
        int64_t expect = args[2].cast<int64_t>();
        int64_t got = args[3].cast<int64_t>();
        std::ostringstream os;
        os << "kernel " << std::string(kernel) << " input "
           << std::string(buffer) << " byte_offset expected " << expect
           << ", but got " << got;
        TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
        *ret = -1;
        throw ::tvm::ffi::EnvErrorAlreadySet();
      });
  // kernel, buffer, expect:int64, got:int64
  refl::GlobalDef().def_packed(
      tl::tvm_error_device_type_mismatch,
      [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
        ICHECK(args.size() == 4)
            << "__tvm_error_device_type_mismatch(kernel, buffer, expect, got)";
        auto kernel = args[0].cast<tvm::ffi::String>();
        auto buffer = args[1].cast<tvm::ffi::String>();
        int64_t expect = args[2].cast<int64_t>();
        int64_t got = args[3].cast<int64_t>();
        // Render device types symbolically (e.g. "cuda") instead of raw codes.
        const char *expect_str =
            tvm::runtime::DLDeviceType2Str(static_cast<int>(expect));
        const char *got_str =
            tvm::runtime::DLDeviceType2Str(static_cast<int>(got));
        std::ostringstream os;
        os << "kernel " << std::string(kernel) << " input "
           << std::string(buffer) << " device_type expected " << expect_str
           << ", but got " << got_str;
        TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
        *ret = -1;
        throw ::tvm::ffi::EnvErrorAlreadySet();
      });
  // kernel, buffer, field:String
  refl::GlobalDef().def_packed(
      tl::tvm_error_null_ptr,
      [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
        ICHECK(args.size() == 3)
            << "__tvm_error_null_ptr(kernel, buffer, field)";
        auto kernel = args[0].cast<tvm::ffi::String>();
        auto buffer = args[1].cast<tvm::ffi::String>();
        auto field = args[2].cast<tvm::ffi::String>();
        std::ostringstream os;
        os << "kernel " << std::string(kernel) << " input "
           << std::string(buffer) << ' ' << std::string(field)
           << " expected non-NULL, but got NULL";
        TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
        *ret = -1;
        throw ::tvm::ffi::EnvErrorAlreadySet();
      });
  // kernel, buffer, field:String, expect:int64, got:int64
  refl::GlobalDef().def_packed(
      tl::tvm_error_expect_eq,
      [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
        ICHECK(args.size() == 5)
            << "__tvm_error_expect_eq(kernel, buffer, field, expect, got)";
        auto kernel = args[0].cast<tvm::ffi::String>();
        auto buffer = args[1].cast<tvm::ffi::String>();
        auto field = args[2].cast<tvm::ffi::String>();
        int64_t expect = args[3].cast<int64_t>();
        int64_t got = args[4].cast<int64_t>();
        std::ostringstream os;
        os << "kernel " << std::string(kernel) << " input "
           << std::string(buffer) << ' ' << std::string(field) << " expected "
           << expect << ", but got " << got;
        TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
        *ret = -1;
        throw ::tvm::ffi::EnvErrorAlreadySet();
      });
  // kernel, buffer, field:String [, reason:String]
  refl::GlobalDef().def_packed(
      tl::tvm_error_constraint_violation,
      [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
        // The trailing reason string is optional; accept both arities.
        ICHECK(args.size() == 3 || args.size() == 4)
            << "__tvm_error_constraint_violation(kernel, buffer, field[, "
               "reason])";
        auto kernel = args[0].cast<tvm::ffi::String>();
        auto buffer = args[1].cast<tvm::ffi::String>();
        auto field = args[2].cast<tvm::ffi::String>();
        std::string reason;
        if (args.size() == 4) {
          reason = args[3].cast<tvm::ffi::String>();
        }
        std::ostringstream os;
        os << "kernel " << std::string(kernel) << " input "
           << std::string(buffer) << ' ' << std::string(field)
           << " constraint not satisfied";
        if (!reason.empty()) {
          os << ": " << reason;
        }
        TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
        *ret = -1;
        throw ::tvm::ffi::EnvErrorAlreadySet();
      });
  // Legacy typed registrations for backward compatibility
  refl::GlobalDef().def("tilelang_error_dtype_mismatch",
                        &tvm::tl::DTypeMismatch);
  refl::GlobalDef().def("tilelang_error_dtype_mismatch2",
                        &tvm::tl::DTypeMismatchNoNames);
}
} // namespace tl
} // namespace tvm
/*!
* \file tl/runtime/error_helpers.h
* \brief Error helper FFI names for TileLang runtime.
*/
#ifndef TVM_TL_RUNTIME_ERROR_HELPERS_H_
#define TVM_TL_RUNTIME_ERROR_HELPERS_H_
namespace tvm {
namespace tl {
// Error helper packed functions
// These constants are the FFI registry names of the packed error-reporting
// handlers registered in error_helpers.cc; generated tvm_call_packed sites
// refer to the handlers by these strings. Keep the two lists in sync.
constexpr const char *tvm_error_dtype_mismatch = "__tvm_error_dtype_mismatch";
constexpr const char *tvm_error_ndim_mismatch = "__tvm_error_ndim_mismatch";
constexpr const char *tvm_error_byte_offset_mismatch =
    "__tvm_error_byte_offset_mismatch";
constexpr const char *tvm_error_device_type_mismatch =
    "__tvm_error_device_type_mismatch";
constexpr const char *tvm_error_null_ptr = "__tvm_error_null_ptr";
constexpr const char *tvm_error_expect_eq = "__tvm_error_expect_eq";
constexpr const char *tvm_error_constraint_violation =
    "__tvm_error_constraint_violation";
} // namespace tl
} // namespace tvm
#endif // TVM_TL_RUNTIME_ERROR_HELPERS_H_
......@@ -13,6 +13,12 @@
namespace tvm {
namespace tl {
#if 1
// Thread-local storage for restoring the L2 persisting cache limit.
// Saved by the __tvm_cuda_stream_set_access_policy_window handler and
// consumed by the matching reset handler; the save/restore pairing assumes
// both entry points run on the same thread.
// NOTE(review): identifiers beginning with a double underscore are reserved
// in C++ — renaming would touch every use site, so flagged here only.
static thread_local size_t __tl_prev_persisting_l2_cache_size = 0;
static thread_local bool __tl_prev_persisting_l2_cache_saved = false;
#endif
#if (CUDA_MAJOR_VERSION >= 12)
template <typename T> static std::string ArrayToStr(const T *ptr, size_t n) {
std::stringstream ss;
......@@ -91,19 +97,21 @@ struct TensorMapArgs {
// set device api
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed("tvm_tensormap_create_tiled", [](PackedArgs args,
Any *ret) {
TensorMapArgs T = TensorMapArgs::Extract(args);
CUresult result = cuTensorMapEncodeTiled(
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave, T.swizzle,
T.l2Promotion, T.oobFill);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to initialize the TMA descriptor " << result << '\n'
<< T.ToDebugString();
}
*ret = static_cast<int>(result);
});
// Register using the canonical names defined in runtime.h
refl::GlobalDef().def_packed(
tl::tvm_tensormap_create_tiled, [](PackedArgs args, Any *ret) {
TensorMapArgs T = TensorMapArgs::Extract(args);
CUresult result = cuTensorMapEncodeTiled(
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave,
T.swizzle, T.l2Promotion, T.oobFill);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to initialize the TMA descriptor " << result
<< '\n'
<< T.ToDebugString();
}
*ret = static_cast<int>(result);
});
}
struct TensorMapIm2ColArgs {
......@@ -183,7 +191,7 @@ struct TensorMapIm2ColArgs {
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed(
"tvm_tensormap_create_im2col", [](PackedArgs args, Any *ret) {
tl::tvm_tensormap_create_im2col, [](PackedArgs args, Any *ret) {
TensorMapIm2ColArgs T = TensorMapIm2ColArgs::Extract(args);
CUresult result = cuTensorMapEncodeIm2col(
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
......@@ -201,5 +209,141 @@ TVM_FFI_STATIC_INIT_BLOCK() {
#endif // (CUDA_MAJOR_VERSION >= 12)
//
// CUDA L2 Persisting Cache Access Policy Window helpers.
// Exposed as TVM FFI packed functions similar to TMA initialization.
//
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  // Set stream access policy window and adjust persisting L2 cache size
  // Args:
  //   [0]: void* base_ptr (required)
  //   [1]: int64 num_bytes (required)
  //   [2]: float hit_ratio (optional, default 0.8)
  //   [3]: void* stream (optional, default 0 => default stream)
  //   [4]: int64 l2_limit_bytes (optional, default = num_bytes)
  refl::GlobalDef().def_packed(
      tl::tvm_cuda_stream_set_access_policy_window,
      [](PackedArgs args, Any *ret) {
        ICHECK(args.size() >= 2) << "Expected at least base_ptr and num_bytes";
        void *base_ptr = args[0].cast<void *>();
        size_t num_bytes = static_cast<size_t>(args[1].cast<int64_t>());
        float hit_ratio = 0.8f;
        if (args.size() >= 3) {
          // Accept double/float
          hit_ratio = static_cast<float>(args[2].cast<double>());
        }
        CUstream stream = nullptr;
        if (args.size() >= 4) {
          stream = reinterpret_cast<CUstream>(args[3].cast<void *>());
        }
        // Requested persisting-cache limit defaults to the window size.
        size_t l2_limit_bytes = num_bytes;
        if (args.size() >= 5) {
          l2_limit_bytes = static_cast<size_t>(args[4].cast<int64_t>());
        }
        // Clamp requested limit to device capability
        CUdevice device;
        CUresult result = cuCtxGetDevice(&device);
        if (result != CUDA_SUCCESS) {
          LOG_FATAL << "Failed to get current CUDA device: " << result;
        }
        int max_persisting = 0;
        result = cuDeviceGetAttribute(
            &max_persisting, CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE,
            device);
        if (result != CUDA_SUCCESS) {
          LOG_FATAL << "Failed to query MAX_PERSISTING_L2_CACHE_SIZE: "
                    << result;
        }
        if (max_persisting > 0 &&
            l2_limit_bytes > static_cast<size_t>(max_persisting)) {
          l2_limit_bytes = static_cast<size_t>(max_persisting);
        }
        // Save current limit to restore later (thread-local; the reset
        // handler below restores it on the same thread).
        size_t init_persisting_l2_cache_size = 0;
        result = cuCtxGetLimit(&init_persisting_l2_cache_size,
                               CU_LIMIT_PERSISTING_L2_CACHE_SIZE);
        if (result != CUDA_SUCCESS) {
          LOG_FATAL << "Failed to get current persisting L2 cache size limit: "
                    << result;
        }
        __tl_prev_persisting_l2_cache_size = init_persisting_l2_cache_size;
        __tl_prev_persisting_l2_cache_saved = true;
        // Set new limit
        result =
            cuCtxSetLimit(CU_LIMIT_PERSISTING_L2_CACHE_SIZE, l2_limit_bytes);
        if (result != CUDA_SUCCESS) {
          LOG_FATAL << "Failed to set persisting L2 cache size limit: "
                    << result;
        }
        // Apply access policy window to stream
        CUstreamAttrValue stream_attribute;
        memset(&stream_attribute, 0, sizeof(stream_attribute));
        stream_attribute.accessPolicyWindow.base_ptr = base_ptr;
        // NOTE(review): the window size is set to the clamped
        // l2_limit_bytes rather than the raw num_bytes argument, so when
        // callers pass arg[4] the window no longer matches [0, num_bytes).
        // Confirm this is intended.
        stream_attribute.accessPolicyWindow.num_bytes = l2_limit_bytes;
        stream_attribute.accessPolicyWindow.hitRatio = hit_ratio;
        stream_attribute.accessPolicyWindow.hitProp =
            CU_ACCESS_PROPERTY_PERSISTING;
        stream_attribute.accessPolicyWindow.missProp =
            CU_ACCESS_PROPERTY_STREAMING;
        result = cuStreamSetAttribute(stream,
                                      CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW,
                                      &stream_attribute);
        if (result != CUDA_SUCCESS) {
          LOG_FATAL << "Failed to set stream access policy window: " << result;
        }
        *ret = static_cast<int>(result);
      });
  // Reset stream access policy window and restore the previous L2 cache size
  // Args:
  //   [0]: void* stream (optional, default 0)
  refl::GlobalDef().def_packed(
      tl::tvm_cuda_stream_reset_access_policy_window,
      [](PackedArgs args, Any *ret) {
        CUstream stream = nullptr;
        if (args.size() >= 1) {
          stream = reinterpret_cast<CUstream>(args[0].cast<void *>());
        }
        CUstreamAttrValue stream_attribute;
        memset(&stream_attribute, 0, sizeof(stream_attribute));
        // num_bytes = 0 disables the access policy window on the stream
        stream_attribute.accessPolicyWindow.num_bytes = 0;
        CUresult result = cuStreamSetAttribute(
            stream, CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW,
            &stream_attribute);
        if (result != CUDA_SUCCESS) {
          LOG_FATAL << "Failed to reset stream access policy window: "
                    << result;
        }
        result = cuCtxResetPersistingL2Cache();
        if (result != CUDA_SUCCESS) {
          LOG_FATAL << "Failed to reset persisting L2 cache lines: " << result;
        }
        // Restore the limit saved by the set handler, if any (thread-local
        // pairing; a reset on a different thread is a no-op here).
        if (__tl_prev_persisting_l2_cache_saved) {
          result = cuCtxSetLimit(CU_LIMIT_PERSISTING_L2_CACHE_SIZE,
                                 __tl_prev_persisting_l2_cache_size);
          if (result != CUDA_SUCCESS) {
            LOG_FATAL << "Failed to restore persisting L2 cache size limit: "
                      << result;
          }
          __tl_prev_persisting_l2_cache_saved = false;
        }
        *ret = static_cast<int>(result);
      });
}
} // namespace tl
} // namespace tvm
......@@ -16,7 +16,13 @@ constexpr const char *tvm_tensormap_create_tiled =
constexpr const char *tvm_tensormap_create_im2col =
"__tvm_tensormap_create_im2col";
#endif // (CUDA_MAJOR_VERSION >= 12)
// CUDA stream access policy window helpers
constexpr const char *tvm_cuda_stream_set_access_policy_window =
"__tvm_cuda_stream_set_access_policy_window";
constexpr const char *tvm_cuda_stream_reset_access_policy_window =
"__tvm_cuda_stream_reset_access_policy_window";
} // namespace tl
} // namespace tvm
#endif // TVM_TL_RUNTIME_RUNTIME_H_
\ No newline at end of file
#endif // TVM_TL_RUNTIME_RUNTIME_H_
......@@ -3,6 +3,7 @@
#include <tvm/ffi/cast.h>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/memory.h>
#include <tvm/ffi/optional.h>
#include <tvm/ffi/string.h>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment