Unverified Commit 79730b11 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Bugfix] Enhance LetStmt handling in Vectorize Loop Pass (#1159)

* [Refactor] Enhance TLVectorizer with loop vectorization convenience method and improve let variable handling

* lint fix

* let test fix

* lint fix
parent feef9ef6
......@@ -33,6 +33,7 @@
#include <tvm/tir/transform.h>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
......@@ -208,6 +209,14 @@ public:
using ExprFunctor::VisitExpr;
using StmtMutator::operator();
// Convenience entry to vectorize a loop body without exposing
// the mutator invocation pattern at call sites.
static Stmt Vectorize(const Var &var, const PrimExpr &var_lanes, Stmt body) {
TLVectorizer vec{var, var_lanes};
auto vec_stmt = vec(std::move(body));
return vec_stmt;
}
TLVectorizer(const Var &var, const PrimExpr &var_lanes)
: var_(var), var_lanes_(var_lanes) {
ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes);
......@@ -217,8 +226,9 @@ public:
ICHECK(!need_scalarize_);
Stmt ret = StmtMutator::VisitStmt(stmt);
if (need_scalarize_) {
auto scalarized_stmt = Scalarize(stmt);
need_scalarize_ = false;
return Scalarize(stmt);
return scalarized_stmt;
} else {
return ret;
}
......@@ -401,8 +411,8 @@ public:
if (var.same_as(var_)) {
return ramp_;
}
auto it = let_binding_.find(var);
if (it != let_binding_.end()) {
auto it = let_var_map_.find(var);
if (it != let_var_map_.end()) {
return it->second;
} else {
return std::move(var);
......@@ -478,7 +488,6 @@ public:
bool vectorizable = optional_op &&
op_vectorizable_.get(optional_op.value(), false) &&
!op->dtype.is_scalable_vector();
if (!vectorizable) {
// Cannot vectorize this op
Array<PrimExpr> new_args;
......@@ -518,7 +527,6 @@ public:
if (!indices.same_as(op->indices)) {
BufferLoadNode *writer = load.CopyOnWrite();
writer->indices = indices;
// writer->LegalizeDType();
LegalizeBufferLoadDType(writer);
}
......@@ -533,18 +541,20 @@ public:
// This is used to allow cases when we reuse a single let
// expression to construct a nested expr.
// (let x = 1 in x + 1) * (let x = 1 in x + 1)
auto it = let_binding_.find(op->var);
if (it != let_binding_.end()) {
auto it = let_var_map_.find(op->var);
if (it != let_var_map_.end()) {
ICHECK(deep_equal_(it->second, value))
<< "Let cannot bind the same var to two different values";
}
if (value.dtype().get_lanes_or_vscale_factor() !=
op->value.dtype().get_lanes_or_vscale_factor()) {
Var new_var(op->var->name_hint, value.dtype());
let_binding_[op->var] = new_var;
let_var_map_[op->var] = new_var;
// Record mapping from the new var to its bound value
let_value_binding_[new_var] = value;
return Let(new_var, value, this->VisitExpr(op->body));
} else {
let_binding_[op->var] = op->var;
let_var_map_[op->var] = op->var;
PrimExpr body = this->VisitExpr(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<PrimExpr>(op);
......@@ -654,17 +664,20 @@ public:
// LetStmt
Stmt VisitStmt_(const LetStmtNode *op) final {
PrimExpr value = this->VisitExpr(op->value);
ICHECK(!let_binding_.count(op->var))
ICHECK(!let_var_map_.count(op->var))
<< "SSA violation, a single var is binded twice";
let_binding_[op->var] = value;
if (value.dtype().get_lanes_or_vscale_factor() !=
op->value.dtype().get_lanes_or_vscale_factor()) {
Var new_var(op->var->name_hint, value.dtype());
let_binding_[op->var] = new_var;
let_var_map_[op->var] = new_var;
// Record mapping from the new var to its bound value
let_value_binding_[op->var] = op->value;
let_value_binding_[new_var] = value;
return LetStmt(new_var, value, this->VisitStmt(op->body));
} else {
let_binding_[op->var] = op->var;
let_var_map_[op->var] = op->var;
let_value_binding_[op->var] = value;
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
......@@ -689,8 +702,27 @@ public:
// scalarize the statement
Stmt Scalarize(Stmt stmt) {
Var idx(var_->name_hint + ".s", var_->dtype);
Var idx(var_->name_hint + "_s", var_->dtype);
// Find all Vars in stmt that are keys in let_value_binding_
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_let_bound_vars;
PostOrderVisit(stmt, [this, &used_let_bound_vars](const ObjectRef &node) {
if (const auto *v = node.as<VarNode>()) {
Var var = GetRef<Var>(v);
if (let_value_binding_.count(var)) {
used_let_bound_vars.insert(var);
}
}
});
stmt = Substitute(stmt, {{var_, idx}});
if (!used_let_bound_vars.empty()) {
for (const auto &v : used_let_bound_vars) {
// Bind the existing var v to its value around the stmt scope
auto new_value = Substitute(let_value_binding_.at(v), {{var_, idx}});
stmt = LetStmt(v, new_value, stmt);
}
}
return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial, stmt);
}
......@@ -707,8 +739,11 @@ private:
PrimExpr ramp_;
// flag to mark requirement of scalarization.
bool need_scalarize_{false};
// Let binding
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_;
// Let var mapping
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_var_map_;
// Let value binding: map new_var -> value
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>
let_value_binding_;
// vectorizable property
OpAttrMap<TVectorizable> op_vectorizable_ =
Op::GetAttrMap<TVectorizable>("TVectorizable");
......@@ -806,7 +841,7 @@ public:
<< " for target " << Target::Current();
}
ICHECK(is_zero(op->min));
return TLVectorizer(op->loop_var, op->extent)(op->body);
return TLVectorizer::Vectorize(op->loop_var, op->extent, op->body);
} else {
return StmtMutator::VisitStmt_(op);
}
......
import tilelang.testing
from tilelang import tvm as tvm
from tilelang import language as T
def test_let_vectorize_load():
@T.prim_func
def main(A_ptr: T.handle):
A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
b: T.float32x4 = A[0, 0:4]
A[0, 4:8] = b
mod = tvm.IRModule({"main": main})
mod = tvm.compile(mod, target="cuda")
assert "float4 b" in mod.mod.imported_modules[0].get_source()
if __name__ == "__main__":
tilelang.testing.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment