Unverified Commit 667632cc authored by guchaoyang's avatar guchaoyang Committed by GitHub
Browse files

Merge branch 'main' into dcu

parents d6dd2ddf a874e4e8
/*!
* \file assume.h
* \brief Utils on assume statements
*/
#ifndef TVM_TL_TRANSFORM_COMMON_ASSUME_H_
#define TVM_TL_TRANSFORM_COMMON_ASSUME_H_
#include "tvm/tir/stmt.h"
#include <optional>
namespace tvm {
namespace tl {
using namespace tir;
/*!
 * \brief Get the expression inside an assume statement, if any.
 * \param stmt The statement to inspect.
 * \return The assumed expression, or std::nullopt if the statement is not an
 *         assume statement.
 * \note NOTE(review): "evaluate form" presumably means an Evaluate statement
 *       wrapping a call to the assume builtin; confirm against the definition,
 *       which is not part of this header.
 */
std::optional<PrimExpr> GetAssumeExprInEvaluateForm(Stmt stmt);
/*!
 * \brief Check if a statement is an assume statement.
 * \param stmt The statement to inspect.
 * \return Whether the statement is an assume statement.
 */
bool IsAssumeInEvaluateForm(const Stmt &stmt);
} // namespace tl
} // namespace tvm
#endif // TVM_TL_TRANSFORM_COMMON_ASSUME_H_
\ No newline at end of file
/*
* Hoist tl.non_restrict_params block annotation(s) to PrimFunc attribute.
*
* Previously, we only looked at the root block. This version recursively
* scans all blocks, unions any tl.non_restrict_params entries it finds,
* merges with any existing PrimFunc-level attribute, then writes the
* deduplicated result back to the PrimFunc attrs. This makes annotation
* placement within the function body flexible for frontends.
*/
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/transform.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
namespace tvm {
namespace tl {
using namespace tvm::tir;
/*!
 * \brief Gathers the Vars listed under the attr::kNonRestrictParams block
 *        annotation of every Block visited.
 *
 * Vars are kept in first-seen order and deduplicated twice: by VarNode
 * pointer identity, and by name_hint with a trailing "_handle" removed, so
 * that a Var that was recreated under an equivalent name is recorded once.
 */
class NonRestrictCollector : public StmtVisitor {
public:
  /*! \brief Visit \p stmt and record the annotated Vars found in it. */
  void Collect(const Stmt &stmt) { VisitStmt(stmt); }

  /*! \brief Return the Vars recorded so far, in first-seen order. */
  Array<Var> Result() const {
    Array<Var> out;
    out.reserve(collected_.size());
    for (const Var &v : collected_)
      out.push_back(v);
    return out;
  }

private:
  /*!
   * \brief Strip one trailing "_handle" from \p s.
   *
   * The suffix is removed only when \p s is strictly longer than the suffix,
   * so the bare name "_handle" is returned unchanged.
   */
  static std::string NormalizeName(const std::string &s) {
    constexpr char kSuffix[] = "_handle";
    constexpr std::size_t kSuffixLen = sizeof(kSuffix) - 1;
    if (s.size() > kSuffixLen &&
        s.compare(s.size() - kSuffixLen, kSuffixLen, kSuffix) == 0) {
      return s.substr(0, s.size() - kSuffixLen);
    }
    return s;
  }

  /*!
   * \brief Record \p v unless it is undefined or already known, either by
   *        pointer or by normalized name.
   */
  void MaybeInsert(const Var &v) {
    if (!v.defined())
      return;
    const VarNode *p = v.get();
    if (seen_ptr_.count(p))
      return;
    // Also dedup by normalized name to be robust w.r.t. recreated Vars.
    std::string norm = NormalizeName(v->name_hint);
    if (seen_name_.count(norm))
      return;
    seen_ptr_.insert(p);
    seen_name_.insert(std::move(norm));
    collected_.push_back(v);
  }

  /*!
   * \brief Harvest the annotation of \p op (if present and array-valued),
   *        then continue into the block's children.
   */
  void VisitStmt_(const BlockNode *op) final {
    auto it = op->annotations.find(attr::kNonRestrictParams);
    // Only an array-valued annotation is consumed; any other payload is
    // ignored. The type test needs no named result, so none is bound.
    if (it != op->annotations.end() &&
        (*it).second.as<ffi::ArrayObj>() != nullptr) {
      const Array<Var> vars = tvm::Downcast<Array<Var>>((*it).second);
      for (const Var &v : vars) {
        MaybeInsert(v);
      }
    }
    // Recurse into child statements so nested blocks are scanned as well.
    StmtVisitor::VisitStmt_(op);
  }

  std::vector<Var> collected_;                   // first-seen order
  std::unordered_set<const VarNode *> seen_ptr_; // identity dedup
  std::unordered_set<std::string> seen_name_;    // normalized-name dedup
};
/*!
 * \brief Hoist attr::kNonRestrictParams block annotations to the PrimFunc.
 *
 * Collects the annotated Vars from every block in the function body, appends
 * the entries of an already-present PrimFunc-level attribute that are not
 * duplicates, and writes the merged list back as the PrimFunc attribute.
 *
 * \param f The function to process. An undefined function is returned as is.
 * \return \p f unchanged when there is nothing to record, otherwise a copy of
 *         \p f carrying the merged attr::kNonRestrictParams attribute.
 */
static PrimFunc HoistNonRestrictParams(PrimFunc f) {
  if (!f.defined())
    return f;

  NonRestrictCollector collector;
  collector.Collect(f->body);
  Array<Var> merged = collector.Result();

  // Merge with any existing PrimFunc-level attribute if present.
  if (auto opt_existing = f->GetAttr<Array<Var>>(attr::kNonRestrictParams)) {
    for (const Var &v : opt_existing.value()) {
      // Skip undefined entries, as NonRestrictCollector::MaybeInsert does;
      // without this guard the name comparison below would dereference a
      // null Var.
      if (!v.defined())
        continue;
      // Inline dedup: pointer identity first, then a name-based fallback.
      // Unlike the collector, the names are compared as-is (not normalized).
      bool exists = false;
      for (const Var &cur : merged) {
        if (cur.get() == v.get() || cur->name_hint == v->name_hint) {
          exists = true;
          break;
        }
      }
      if (!exists)
        merged.push_back(v);
    }
  }

  if (merged.empty())
    return f;
  return WithAttr(std::move(f), attr::kNonRestrictParams, std::move(merged));
}
namespace transform {
/*!
 * \brief Build the "tl.HoistNonRestrictParams" PrimFunc pass.
 * \return A pass created through CreatePrimFuncPass with opt level 0 and no
 *         required passes; it applies tvm::tl::HoistNonRestrictParams to each
 *         PrimFunc it receives.
 */
tvm::transform::Pass HoistNonRestrictParams() {
  // The callee is fully qualified because, inside this namespace, the
  // unqualified name resolves to this pass factory itself.
  return tvm::tir::transform::CreatePrimFuncPass(
      [](PrimFunc func, const IRModule & /*mod*/,
         const tvm::transform::PassContext & /*ctx*/) {
        return tvm::tl::HoistNonRestrictParams(std::move(func));
      },
      /*opt_level=*/0, "tl.HoistNonRestrictParams", {});
}
} // namespace transform
} // namespace tl
} // namespace tvm
TVM_FFI_STATIC_INIT_BLOCK() {
  // Register the pass factory under the global name
  // "tl.transform.HoistNonRestrictParams".
  tvm::ffi::reflection::GlobalDef().def(
      "tl.transform.HoistNonRestrictParams",
      tvm::tl::transform::HoistNonRestrictParams);
}
/*!
* \file inject_assumes.cc
* \brief Inject assumes on buffer's shape boundary check. Also convert
* existing assumes to AttrNodes.
*/
#include "common/assume.h"
#include "tvm/arith/analyzer.h"
#include "tvm/ffi/optional.h"
#include "tvm/ir/expr.h"
......@@ -6,9 +12,11 @@
#include "tvm/node/structural_hash.h"
#include "tvm/tir/builtin.h"
#include "tvm/tir/expr.h"
#include "tvm/tir/op.h"
#include "tvm/tir/stmt.h"
#include "tvm/tir/stmt_functor.h"
#include "tvm/tir/transform.h"
#include <sstream>
namespace tvm::tl {
......@@ -26,11 +34,12 @@ public:
}
private:
struct AssertCreator {
struct AssumeCreator {
struct Item {
PrimExpr expr;
std::vector<Buffer> buffers;
};
tvm::StructuralHash sh;
tvm::StructuralEqual se;
// grouped by expr, since the amount of variadic shape symbols is usually
......@@ -52,6 +61,7 @@ private:
items[*it].buffers.push_back(buffer);
}
}
void addBuffer(Buffer buf) {
for (auto shape : buf->shape) {
if (shape->IsInstance<IntImmNode>())
......@@ -59,10 +69,12 @@ private:
addExpr(shape, buf);
}
}
Stmt build(Stmt body) {
auto analyzer = arith::Analyzer{};
for (const auto &e : items) {
auto simplified = analyzer.Simplify(GT(e.expr, 0));
auto simplified =
analyzer.Simplify(GT(e.expr, make_zero(e.expr->dtype)));
std::stringstream ss;
ss << "Buffer shape should be greater than 0: shape `" << e.expr
<< "` from buffer ";
......@@ -77,32 +89,37 @@ private:
return body;
}
};
Stmt VisitStmt_(const DeclBufferNode *op) final {
auto body = VisitStmt(op->body);
AssertCreator c;
AssumeCreator c;
c.addBuffer(op->buffer);
return DeclBuffer(op->buffer, c.build(body), op->span);
}
/*!
 * \brief Extract the condition of an assume written as an Evaluate statement.
 * \param stmt The statement to inspect.
 * \return The first argument of the call when \p stmt is an EvaluateNode
 *         whose value is a CallNode with op builtin::assume(); std::nullopt
 *         otherwise.
 */
std::optional<PrimExpr> getAssumeExpr(Stmt stmt) {
// Not an Evaluate statement.
auto eval = stmt.as<EvaluateNode>();
if (!eval)
return std::nullopt;
// The evaluated value is not a call.
auto call = eval->value.as<CallNode>();
if (!call)
return std::nullopt;
// A call, but to something other than the assume builtin.
if (!call->op.same_as(builtin::assume()))
return std::nullopt;
// NOTE(review): args[0] is read without checking the argument count;
// presumably assume always carries its condition -- confirm.
return call->args[0];
}
Stmt VisitStmt_(const SeqStmtNode *op) final {
struct AssumeGroup {
std::optional<PrimExpr> e;
std::vector<Stmt> stmts;
};
std::vector<AssumeGroup> groups = {AssumeGroup{std::nullopt, {}}};
for (auto i = 0; i < op->seq.size(); i++) {
for (size_t i = 0; i < op->seq.size(); i++) {
auto stmt = VisitStmt(op->seq[i]);
if (auto e = getAssumeExpr(stmt)) {
// Convert assume in evaluate form to assume attribute.
// By default, we have the following IR:
// T.assume(cond1)
// Stmt1
// Stmt2
// T.assume(cond2)
// This SeqStmt will be converted to:
// With(attr::tilelang_assume, cond1) {
// Stmt1
// Stmt2
// }
// With(attr::tilelang_assume, cond2) {
// ...
// }
if (auto e = GetAssumeExprInEvaluateForm(stmt)) {
groups.push_back(AssumeGroup{*e, {}});
} else {
groups.back().stmts.push_back(stmt);
......@@ -125,10 +142,14 @@ private:
: SeqStmt(groups[0].stmts);
// return SeqStmt(groups[0].stmts);
}
Stmt VisitStmt_(const BlockNode *op) final {
auto body = VisitStmt(op->body);
AssertCreator c;
if (root_node) {
AssumeCreator c;
// NOTE(chaofan): We only inject assumes from function arguments in the
// root block.
if (op->name_hint == "root") {
for (auto item : f->buffer_map) {
c.addBuffer(item.second);
}
......@@ -139,12 +160,13 @@ private:
for (auto item : op->match_buffers) {
c.addBuffer(item->buffer);
}
return Block(op->iter_vars, op->reads, op->writes, op->name_hint,
c.build(body), op->init, op->alloc_buffers, op->match_buffers,
op->annotations, op->span);
}
PrimFunc f;
bool root_node{true};
};
using namespace tir::transform;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment