Unverified Commit 2feaa41e authored by Chaofan Lin's avatar Chaofan Lin Committed by GitHub
Browse files

[Enhancement] Improve InjectAssumes logic and make assumes work after SplitHostDevice (#1405)

* [Refactor] Refactor InjectAssumes logic and make assumes work after SplitHostDevice

* address comments

* fix

* fix submodule

* fix

* fix 3rdparty
parent 0788feb8
...@@ -136,6 +136,7 @@ file(GLOB TILE_LANG_SRCS ...@@ -136,6 +136,7 @@ file(GLOB TILE_LANG_SRCS
src/*.cc src/*.cc
src/layout/*.cc src/layout/*.cc
src/transform/*.cc src/transform/*.cc
src/transform/common/*.cc
src/op/*.cc src/op/*.cc
src/target/utils.cc src/target/utils.cc
src/target/codegen_c_host.cc src/target/codegen_c_host.cc
......
/*!
* \file assume.cc
* \brief Utils on assume statements
*/
#include "assume.h"
#include "tvm/tir/builtin.h"
#include "tvm/tir/expr.h"
namespace tvm {
namespace tl {
using namespace tir;
std::optional<PrimExpr> GetAssumeExprInEvaluateForm(Stmt stmt) {
auto eval = stmt.as<EvaluateNode>();
if (!eval)
return std::nullopt;
auto call = eval->value.as<CallNode>();
if (!call)
return std::nullopt;
if (!call->op.same_as(builtin::assume()))
return std::nullopt;
return call->args[0];
}
bool IsAssumeInEvaluateForm(const Stmt &stmt) {
return GetAssumeExprInEvaluateForm(stmt).has_value();
}
} // namespace tl
} // namespace tvm
/*!
* \file assume.h
* \brief Utils on assume statements
*/
#ifndef TVM_TL_TRANSFORM_COMMON_ASSUME_H_
#define TVM_TL_TRANSFORM_COMMON_ASSUME_H_
#include "tvm/tir/stmt.h"
#include <optional>
namespace tvm {
namespace tl {
using namespace tir;
// Get the expression inside an assume statement, if any. Returns nullopt if
// the statement is not an assume statement.
std::optional<PrimExpr> GetAssumeExprInEvaluateForm(Stmt stmt);
// Check if a statement is an assume statement.
bool IsAssumeInEvaluateForm(const Stmt &stmt);
} // namespace tl
} // namespace tvm
#endif // TVM_TL_TRANSFORM_COMMON_ASSUME_H_
\ No newline at end of file
/*!
* \file inject_assumes.cc
* \brief Inject assumes on buffer's shape boundary check. Also convert
* existing assumes to AttrNodes.
*/
#include "common/assume.h"
#include "tvm/arith/analyzer.h" #include "tvm/arith/analyzer.h"
#include "tvm/ffi/optional.h" #include "tvm/ffi/optional.h"
#include "tvm/ir/expr.h" #include "tvm/ir/expr.h"
...@@ -10,6 +16,7 @@ ...@@ -10,6 +16,7 @@
#include "tvm/tir/stmt.h" #include "tvm/tir/stmt.h"
#include "tvm/tir/stmt_functor.h" #include "tvm/tir/stmt_functor.h"
#include "tvm/tir/transform.h" #include "tvm/tir/transform.h"
#include <sstream> #include <sstream>
namespace tvm::tl { namespace tvm::tl {
...@@ -27,11 +34,12 @@ public: ...@@ -27,11 +34,12 @@ public:
} }
private: private:
struct AssertCreator { struct AssumeCreator {
struct Item { struct Item {
PrimExpr expr; PrimExpr expr;
std::vector<Buffer> buffers; std::vector<Buffer> buffers;
}; };
tvm::StructuralHash sh; tvm::StructuralHash sh;
tvm::StructuralEqual se; tvm::StructuralEqual se;
// grouped by expr, since the amount of variadic shape symbols is usually // grouped by expr, since the amount of variadic shape symbols is usually
...@@ -53,6 +61,7 @@ private: ...@@ -53,6 +61,7 @@ private:
items[*it].buffers.push_back(buffer); items[*it].buffers.push_back(buffer);
} }
} }
void addBuffer(Buffer buf) { void addBuffer(Buffer buf) {
for (auto shape : buf->shape) { for (auto shape : buf->shape) {
if (shape->IsInstance<IntImmNode>()) if (shape->IsInstance<IntImmNode>())
...@@ -60,6 +69,7 @@ private: ...@@ -60,6 +69,7 @@ private:
addExpr(shape, buf); addExpr(shape, buf);
} }
} }
Stmt build(Stmt body) { Stmt build(Stmt body) {
auto analyzer = arith::Analyzer{}; auto analyzer = arith::Analyzer{};
for (const auto &e : items) { for (const auto &e : items) {
...@@ -79,32 +89,37 @@ private: ...@@ -79,32 +89,37 @@ private:
return body; return body;
} }
}; };
Stmt VisitStmt_(const DeclBufferNode *op) final { Stmt VisitStmt_(const DeclBufferNode *op) final {
auto body = VisitStmt(op->body); auto body = VisitStmt(op->body);
AssertCreator c; AssumeCreator c;
c.addBuffer(op->buffer); c.addBuffer(op->buffer);
return DeclBuffer(op->buffer, c.build(body), op->span); return DeclBuffer(op->buffer, c.build(body), op->span);
} }
std::optional<PrimExpr> getAssumeExpr(Stmt stmt) {
auto eval = stmt.as<EvaluateNode>();
if (!eval)
return std::nullopt;
auto call = eval->value.as<CallNode>();
if (!call)
return std::nullopt;
if (!call->op.same_as(builtin::assume()))
return std::nullopt;
return call->args[0];
}
Stmt VisitStmt_(const SeqStmtNode *op) final { Stmt VisitStmt_(const SeqStmtNode *op) final {
struct AssumeGroup { struct AssumeGroup {
std::optional<PrimExpr> e; std::optional<PrimExpr> e;
std::vector<Stmt> stmts; std::vector<Stmt> stmts;
}; };
std::vector<AssumeGroup> groups = {AssumeGroup{std::nullopt, {}}}; std::vector<AssumeGroup> groups = {AssumeGroup{std::nullopt, {}}};
for (auto i = 0; i < op->seq.size(); i++) { for (size_t i = 0; i < op->seq.size(); i++) {
auto stmt = VisitStmt(op->seq[i]); auto stmt = VisitStmt(op->seq[i]);
if (auto e = getAssumeExpr(stmt)) { // Convert assume in evaluate form to assume attribute.
// By default, we have the following IR:
// T.assume(cond1)
// Stmt1
// Stmt2
// T.assume(cond2)
// This SeqStmt will be converted to:
// With(attr::tilelang_assume, cond1) {
// Stmt1
// Stmt2
// }
// With(attr::tilelang_assume, cond2) {
// ...
// }
if (auto e = GetAssumeExprInEvaluateForm(stmt)) {
groups.push_back(AssumeGroup{*e, {}}); groups.push_back(AssumeGroup{*e, {}});
} else { } else {
groups.back().stmts.push_back(stmt); groups.back().stmts.push_back(stmt);
...@@ -127,10 +142,14 @@ private: ...@@ -127,10 +142,14 @@ private:
: SeqStmt(groups[0].stmts); : SeqStmt(groups[0].stmts);
// return SeqStmt(groups[0].stmts); // return SeqStmt(groups[0].stmts);
} }
Stmt VisitStmt_(const BlockNode *op) final { Stmt VisitStmt_(const BlockNode *op) final {
auto body = VisitStmt(op->body); auto body = VisitStmt(op->body);
AssertCreator c; AssumeCreator c;
if (root_node) {
// NOTE(chaofan): We only inject assumes from function arguments in the
// root block.
if (op->name_hint == "root") {
for (auto item : f->buffer_map) { for (auto item : f->buffer_map) {
c.addBuffer(item.second); c.addBuffer(item.second);
} }
...@@ -141,12 +160,13 @@ private: ...@@ -141,12 +160,13 @@ private:
for (auto item : op->match_buffers) { for (auto item : op->match_buffers) {
c.addBuffer(item->buffer); c.addBuffer(item->buffer);
} }
return Block(op->iter_vars, op->reads, op->writes, op->name_hint, return Block(op->iter_vars, op->reads, op->writes, op->name_hint,
c.build(body), op->init, op->alloc_buffers, op->match_buffers, c.build(body), op->init, op->alloc_buffers, op->match_buffers,
op->annotations, op->span); op->annotations, op->span);
} }
PrimFunc f; PrimFunc f;
bool root_node{true};
}; };
using namespace tir::transform; using namespace tir::transform;
......
...@@ -33,13 +33,24 @@ ...@@ -33,13 +33,24 @@
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include "common/assume.h"
#include "tir/analysis/var_use_def_analysis.h" #include "tir/analysis/var_use_def_analysis.h"
#include "tvm/node/cast.h"
#include "tvm/runtime/logging.h"
#include "tvm/tir/stmt.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
using namespace ffi; using namespace ffi;
namespace tir = tvm::tir; namespace tir = tvm::tir;
// This pass traverses the AST, split the target function into host part and
// device part and copies all assume attribute statements to the device side.
// 1. Traverse AST and collect all assume statements into host_assumes_.
// 2. Until the first AttrStmtNode with tvm::attr::kTarget.
// 3. Call SplitDeviceFunc, which will create a new device function and replace
// the original body with a call to that function.
class HostDeviceSplitter : public tir::StmtMutator { class HostDeviceSplitter : public tir::StmtMutator {
public: public:
explicit HostDeviceSplitter(IRModule *device_mod, explicit HostDeviceSplitter(IRModule *device_mod,
...@@ -51,10 +62,29 @@ public: ...@@ -51,10 +62,29 @@ public:
found_device_region_ = true; found_device_region_ = true;
auto device_target = op->node.as<tvm::Target>().value().WithoutHost(); auto device_target = op->node.as<tvm::Target>().value().WithoutHost();
return SplitDeviceFunc(op->body, device_target); return SplitDeviceFunc(op->body, device_target);
} else if (op->attr_key == tir::attr::tilelang_assume) {
// NOTE(chaofan): the assumes collected here must be in host-side.
// This is because when the collector reaches the split region,
// it will start to split and return. For safety, we add a check here.
ICHECK(!found_device_region_)
<< "Assumes collection should not be in device region.";
// We first push back the outside assume, then visit the child.
// So when moving assumes to device side, we need to do the building
// process in a reverse order.
host_assumes_.push_back(op);
} }
return tir::StmtMutator::VisitStmt_(op); return tir::StmtMutator::VisitStmt_(op);
} }
tir::Stmt VisitStmt_(const tir::EvaluateNode *op) final {
auto stmt = GetRef<tir::Stmt>(op);
// There should be no assume in evaluate form after InjectAssumes.
ICHECK(!IsAssumeInEvaluateForm(stmt))
<< "Unexpected assume in evaluate form. Please run InjectAssumes pass "
"first.";
return tir::StmtMutator::VisitStmt_(op);
}
tir::Stmt ForceSplit(tir::Stmt body, tvm::Target device_target) { tir::Stmt ForceSplit(tir::Stmt body, tvm::Target device_target) {
return SplitDeviceFunc(std::move(body), std::move(device_target)); return SplitDeviceFunc(std::move(body), std::move(device_target));
} }
...@@ -64,6 +94,14 @@ public: ...@@ -64,6 +94,14 @@ public:
private: private:
bool found_device_region_{false}; bool found_device_region_{false};
Stmt wrapBodyWithHostSideAssumes(Stmt body) {
for (auto it = host_assumes_.rbegin(); it != host_assumes_.rend(); ++it) {
body =
AttrStmt((*it)->node, tir::attr::tilelang_assume, (*it)->value, body);
}
return body;
}
tir::Stmt SplitDeviceFunc(tir::Stmt body, tvm::Target device_target) { tir::Stmt SplitDeviceFunc(tir::Stmt body, tvm::Target device_target) {
auto [params, buffers_to_declare] = auto [params, buffers_to_declare] =
[&]() -> std::tuple<Array<tir::Var>, Array<tir::Buffer>> { [&]() -> std::tuple<Array<tir::Var>, Array<tir::Buffer>> {
...@@ -104,9 +142,14 @@ private: ...@@ -104,9 +142,14 @@ private:
kernel_ret_type = VoidType(); kernel_ret_type = VoidType();
} }
// Declare necessary buffers for the device side.
for (tir::Buffer buf : buffers_to_declare) { for (tir::Buffer buf : buffers_to_declare) {
body = tir::DeclBuffer(buf, std::move(body)); body = tir::DeclBuffer(buf, std::move(body));
} }
// Copy assumes from host-side to device-side.
body = wrapBodyWithHostSideAssumes(body);
tir::PrimFunc device_func(params, body, kernel_ret_type); tir::PrimFunc device_func(params, body, kernel_ret_type);
device_func = device_func =
WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target}, WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target},
...@@ -138,6 +181,8 @@ private: ...@@ -138,6 +181,8 @@ private:
IRModule *device_mod_; IRModule *device_mod_;
// Generate new GlobalVar for the kernel // Generate new GlobalVar for the kernel
std::function<GlobalVar()> var_supply_; std::function<GlobalVar()> var_supply_;
// Collect assumes in host side
Array<const tir::AttrStmtNode *> host_assumes_;
}; };
tir::PrimFunc SplitHostDevice(tir::PrimFunc func, IRModule *device_mod, tir::PrimFunc SplitHostDevice(tir::PrimFunc func, IRModule *device_mod,
......
...@@ -92,7 +92,8 @@ def LegalizeNegativeIndex(): ...@@ -92,7 +92,8 @@ def LegalizeNegativeIndex():
def InjectAssumes(): def InjectAssumes():
"""Inject Assumes """Inject Assumes for natural shape boundary conditions. And convert Assumes in Evaluate(Call(...)) form
(tvm builtin assume call) to AttrNode form.
Returns: Returns:
------- -------
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment