Commit fa511857 authored by Lei Wang, committed by GitHub

[Lint] Overall Typo and Linting Fixes (#13)

* README.md fixed

* update test ci

* Lint and Typo Fix

* Clang Format Lint Fix
parent be55163f
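
The whitespace-only changes below are mechanical clang-format output. A minimal .clang-format consistent with them (a sketch inferred from the diff, not a file taken from this repository) would be:

# Sketch of a .clang-format matching the style applied below (assumed,
# not from this repo): LLVM base style, 80-column limit, right-bound pointers.
BasedOnStyle: LLVM
IndentWidth: 2
ColumnLimit: 80
DerivePointerAlignment: false
PointerAlignment: Right

This accounts for the two recurring edits in every hunk: `Type* x` becomes `Type *x`, and long calls are re-wrapped at 80 columns.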
@@ -32,29 +32,32 @@ namespace tl {

using namespace tir;

class BufferIndiceSimplify : public StmtExprMutator {
public:
  BufferIndiceSimplify(arith::Analyzer *analyzer) : analyzer_(analyzer) {}

private:
  PrimExpr VisitExpr_(const BufferLoadNode *node) final {
    auto visited = StmtExprMutator::VisitExpr_(node);
    auto n = visited.as<BufferLoad>().value();
    auto nptr = n.CopyOnWrite();
    nptr->indices = nptr->indices.Map(
        [&](const auto &e) { return analyzer_->Simplify(e); });
    return n;
  }
  Stmt VisitStmt_(const BufferStoreNode *node) final {
    auto visited = StmtExprMutator::VisitStmt_(node);
    auto n = visited.as<BufferStore>().value();
    auto nptr = n.CopyOnWrite();
    nptr->indices = nptr->indices.Map(
        [&](const auto &e) { return analyzer_->Simplify(e); });
    return n;
  }
  arith::Analyzer *analyzer_;
};

// Rewrite the parallel loop into a common loop, which is mapped to threads
For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
                  Fragment loop_layout) {
  ICHECK(loop_layout.defined());
  ICHECK(thread_var.defined());
  int old_loop_depth = loop_layout->InputDim();
@@ -71,7 +74,8 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer* analyzer, Fragment lo
  Map<Var, PrimExpr> vmap;
  Stmt body = op;
  auto inv_loop = loop_layout->Inverse();
  auto indices =
      inv_loop->Forward(vars.Map([](const Var &v) { return PrimExpr(v); }));
  for (int i = 0; i < old_loop_depth; i++) {
    ICHECK(body.as<For>().defined());
    For loop = body.as<For>().value();
@@ -82,8 +86,8 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer* analyzer, Fragment lo
  // substitute and re-construct the serial loop
  body = Substitute(body, vmap);
  for (int i = new_loop_depth - 1; i >= 0; i--) {
    body = For(vars[i], make_zero(vars[i]->dtype), inv_loop->InputShape()[i],
               ForKind::kSerial, body);
    analyzer->Bind(vars[i], Range(0, inv_loop->InputShape()[i]));
  }
@@ -95,11 +99,11 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer* analyzer, Fragment lo
}

class LoopPramaUnroller : public StmtExprMutator {
public:
  LoopPramaUnroller() = default;

private:
  Stmt VisitStmt_(const ForNode *node) final {
    if (node->kind == ForKind::kSerial) {
      For new_for = GetRef<For>(node);
      auto for_ptr = new_for.CopyOnWrite();
@@ -112,7 +116,7 @@ class LoopPramaUnroller : public StmtExprMutator {
};

class LoopPartitioner : public StmtExprVisitor {
public:
  LoopPartitioner() = default;

  Fragment Partition(For op, int num_thread, int vectorize_size) {
@@ -129,17 +133,18 @@ class LoopPartitioner : public StmtExprVisitor {
    ICHECK(loop_size_full % vectorize_size == 0);
    PrimExpr access_idx = FloorDiv(flattened, vectorize_size);
    PrimExpr thd = FloorMod(access_idx, num_thread);
    PrimExpr idx = FloorDiv(access_idx, num_thread) * vectorize_size +
                   FloorMod(flattened, vectorize_size);
    return Fragment(loop_vars_, {idx}, {thd}, {});
  }

private:
  void VisitStmt_(const ForNode *node) final {
    if (node->kind == ForKind::kParallel) {
      body_ = node->body;
      loop_vars_.push_back(
          IterVar(Range::FromMinExtent(node->min, node->extent), node->loop_var,
                  IterVarType::kDataPar));
    }
    StmtExprVisitor::VisitStmt_(node);
  }
@@ -160,5 +165,5 @@ For LoopPragmaUnroll(For stmt) {
  return unrolled;
}

} // namespace tl
} // namespace tvm
@@ -36,13 +36,14 @@ namespace tl {

using namespace tir;

For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
                  Fragment loop_layout);

Fragment PlanLoopPartition(For op, size_t num_thread, int vectorize_size);

For LoopPragmaUnroll(For stmt);

} // namespace tl
} // namespace tvm

#endif // TVM_TL_LOOP_PARTITION_H_
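
For orientation, a hypothetical call sequence for the three declarations above (assumes the header's own includes and namespaces, a parallel tir::For named par_loop, and a thread Var built by the caller; this sketch is not part of the commit):

// Hedged usage sketch of the loop-partition API declared above.
arith::Analyzer analyzer;
Var tx("tx"); // assumed to be bound to threadIdx.x elsewhere
Fragment layout = PlanLoopPartition(par_loop, /*num_thread=*/128,
                                    /*vectorize_size=*/4);
For partitioned = PartitionLoop(par_loop, tx, &analyzer, layout);
For unrolled = LoopPragmaUnroll(partitioned);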
@@ -30,10 +30,10 @@

#include <numeric>

#include "../layout/layout.h"
#include "../layout/utils.h"
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_vectorization_utils.h"

namespace tvm {
@@ -48,10 +48,10 @@ struct VectorizePlanResult {
};

class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
public:
  VectorizePlanner() = default;

  int Plan(const For &node) {
    this->operator()(node);
    // Always Enable vectorization
    // if (!has_nonlocal_memory_access_) return 1;
@@ -62,18 +62,19 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
  PrimExpr GetCondition() { return condition_; }

private:
  void VisitStmt_(const ForNode *node) final {
    inner_for_ = node;
    iter_map_.Set(node->loop_var, Range(node->min, node->extent));
    arith::IRVisitorWithAnalyzer::VisitStmt_(node);
  }

  void VisitExpr_(const BufferLoadNode *node) final {
    if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
        node->buffer.scope() == "shared.dyn")
      has_nonlocal_memory_access_ = true;
    if (node->buffer->shape.size() == 1 &&
        node->buffer->shape[0].as<IntImmNode>()->value == 1) {
      // TODO(lei): This should be improved as
      // constant buffer that tl hack to use as local register.
      return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
@@ -82,7 +83,7 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
    return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
  }

  void VisitStmt_(const BufferStoreNode *node) final {
    if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
        node->buffer.scope() == "shared.dyn")
      has_nonlocal_memory_access_ = true;
@@ -90,12 +91,12 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
    return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
  }

  void VisitStmt_(const IfThenElseNode *node) final {
    CheckConditionVectorized(node->condition);
    return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
  }

  void VisitExpr_(const CallNode *node) final {
    if (node->op == builtin::if_then_else()) {
      CheckConditionVectorized(node->args[0]);
    } else if (node->op == builtin::call_extern()) {
@@ -105,16 +106,18 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
    return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
  }

  void CheckConditionVectorized(const PrimExpr &cond) {
    // TODO: perform some checks here
  }

  void UpdateVectorSize(const Array<PrimExpr> indices, const Buffer &buffer) {
    if (!inner_for_)
      return;
    auto extent_ptr = inner_for_->extent.as<IntImmNode>();
    if (!extent_ptr)
      return;

    const DataType &access_type = buffer->dtype;
    // i // 2, i % 8 can also be vectorized as factor 16
    int max_vector_size = 128 / access_type.bits();
    // so we should disable this GCD optimization
@@ -122,7 +125,8 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
    auto last_dim = buffer->shape.back();
    auto mod_set = analyzer_.modular_set(last_dim);
    // when dynamic shape like [m, k]: coeff=1, base=0, GCD will block
    // conditionally tail vectorize
    if (buffer->shape.back().as<IntImmNode>()) {
      max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
@@ -142,8 +146,9 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
        elem_offset = elem_offset + indices[i] * stride;
        stride = stride * buffer->shape[i];
      }
      while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
                                 inner_for_->extent, vector_size_,
                                 &analyzer_)) {
        vector_size_ /= 2;
      }
    } else if (vector_size_ <= vector_load_bits_max_ / buffer->dtype.bits()) {
@@ -156,7 +161,7 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
  static const int vector_load_bits_max_ = 128;

  const ForNode *inner_for_;
  Map<Var, Range> iter_map_;
  bool has_nonlocal_memory_access_ = false;
  int vector_size_ = 128;
@@ -166,12 +171,12 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
};

class VectorizeDynamicCallRemover : public StmtExprMutator {
public:
  VectorizeDynamicCallRemover(Var inner_var, int vector_size)
      : inner_var_(inner_var), vector_size_(vector_size) {}

private:
  PrimExpr VisitExpr_(const CallNode *op) final {
    if (op->op.same_as(builtin::if_then_else())) {
      PrimExpr cond = this->VisitExpr(op->args[0]);
      Map<Var, PrimExpr> vmap;
@@ -191,15 +196,16 @@ class VectorizeDynamicCallRemover : public StmtExprMutator {
};

class VectorizeRewriter : public StmtExprMutator {
public:
  VectorizeRewriter(VectorizePlanResult plan)
      : vector_size_(plan.vector_size), condition_(plan.condition),
        dynamic_(plan.dynamic) {}

private:
  Stmt VisitStmt_(const ForNode *node) final {
    inner_for_ = node;
    auto ret = StmtExprMutator::VisitStmt_(node);
    if (inner_for_ == node) { // rewrite the innermost loop
      For fnode = ret.as<For>().value();
      auto old_var = fnode->loop_var;
      auto extent_ptr = as_const_int(fnode->extent);
@@ -208,7 +214,7 @@ class VectorizeRewriter : public StmtExprMutator {
      ICHECK(extent % vector_size_ == 0)
          << "extent: " << extent << " vector_size_: " << vector_size_;
      ICHECK(is_zero(fnode->min));
      if (!dynamic_) { // check dynamic shape
        if (extent == vector_size_) {
          fnode.CopyOnWrite()->kind = ForKind::kVectorized;
          return fnode;
@@ -219,8 +225,8 @@ class VectorizeRewriter : public StmtExprMutator {
          vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
          Stmt body = Substitute(fnode->body, vmap);
          body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
          body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
                     fnode->thread_binding, fnode->annotations, fnode->span);
          return body;
        }
      } else {
@@ -237,11 +243,13 @@ class VectorizeRewriter : public StmtExprMutator {
        VectorizeDynamicCallRemover remover(inner_var, vector_size_);
        body = remover(body);

        For vectorize_for =
            For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
        For serial_for =
            For(inner_var, 0, vector_size_, ForKind::kSerial, body);
        body = IfThenElse(condition, vectorize_for, serial_for);
        body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
                   fnode->thread_binding, fnode->annotations, fnode->span);
        return body;
      }
    } else {
@@ -249,15 +257,15 @@ class VectorizeRewriter : public StmtExprMutator {
    }
  }

  const ForNode *inner_for_;
  const int vector_size_;
  const PrimExpr condition_;
  const bool dynamic_;
};

int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); }

VectorizePlanResult GetVectorizePlanResult(const For &loop) {
  VectorizePlanner planner;
  int vector_size = planner.Plan(loop);
  bool dynamic = planner.GetDynamic();
@@ -265,16 +273,19 @@ VectorizePlanResult GetVectorizePlanResult(const For& loop) {
  return {vector_size, dynamic, condition};
}

bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size,
                        int target_vectorized_size, arith::Analyzer *analyzer) {
  ICHECK(target_vectorized_size >= 1);
  if (target_vectorized_size == 1)
    return true;
  if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_vectorized_size),
                               0))
    return false;
  Var v0("v0"), v1("v1");
  analyzer->Bind(v0, Range(0, target_vectorized_size));
  analyzer->Bind(v1, Range(0, FloorDiv(iter_var_size, target_vectorized_size)));
  PrimExpr expr_transformed = analyzer->Simplify(
      Substitute(expr, {{var, v0 + v1 * target_vectorized_size}}));

  Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size));
  PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed);
@@ -290,16 +301,17 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size, int targ
  }
}

For VectorizeLoop(const For &loop, int vectorize_hint) {
  VectorizePlanResult res{128, false, 0};
  if (vectorize_hint <= 0) {
    res = GetVectorizePlanResult(loop);
    vectorize_hint = res.vector_size;
  }
  if (vectorize_hint == 1)
    return loop;
  auto rewriter = VectorizeRewriter(res);
  return Downcast<For>(rewriter(loop));
}

} // namespace tl
} // namespace tvm
@@ -35,13 +35,13 @@ namespace tl {

using namespace tir;

int GetVectorizeSize(const For &loop);

For VectorizeLoop(const For &loop, int vectorize_hint = -1);

bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size,
                        int target_vectorized_size, arith::Analyzer *analyzer);

} // namespace tl
} // namespace tvm

#endif // TVM_TL_LOOP_VECTORIZE_H_
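
A hypothetical driver for the vectorization API above (assumes the header's includes and namespaces, and the planner/rewriter behavior shown in loop_vectorize.cc; not code from this commit):

// Hedged sketch: plan a vector width first, then rewrite only when it helps.
For MaybeVectorize(const For &loop) {
  int width = GetVectorizeSize(loop); // runs VectorizePlanner over the nest
  if (width <= 1)
    return loop; // nothing to vectorize
  return VectorizeLoop(loop, width); // splits and vectorizes the inner loop
}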
@@ -37,15 +37,15 @@ namespace tl {

using namespace tir;

class LowerHopperIntrin : public StmtExprMutator {
public:
  static PrimFunc Substitute(PrimFunc &f) {
    PrimFuncNode *fptr = f.CopyOnWrite();
    LowerHopperIntrin substituter;
    fptr->body = substituter.VisitStmt(f->body);
    for (auto [call, var] : substituter.desc_map_) {
      // Should allocate 128 bytes for TensorMap on stack
      Call alloc_desc = Call(DataType::Handle(), builtin::tvm_stack_alloca(),
                             {StringImm("arg_value"), 16});
      Array<PrimExpr> init_desc_args;
      if (call->op.same_as(CreateTMADescriptorOp())) {
        init_desc_args.push_back(StringImm(tvm_tensormap_create_tiled));
@@ -55,15 +55,19 @@ class LowerHopperIntrin : public StmtExprMutator {
        CHECK(0) << call->op;
      }
      init_desc_args.push_back(var);
      init_desc_args.insert(init_desc_args.end(), call->args.begin(),
                            call->args.end());
      Call init_desc =
          Call(DataType::Handle(), builtin::tvm_call_packed(), init_desc_args);
      fptr->body =
          LetStmt(var, alloc_desc, SeqStmt({Evaluate(init_desc), fptr->body}));
    }
    return f;
  }

  Stmt VisitStmt_(const AttrStmtNode *op) final {
    // Insert the prefetch TMA descriptor statement TO the beginning of the
    // kernel
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag == "threadIdx.x") {
@@ -73,18 +77,22 @@ class LowerHopperIntrin : public StmtExprMutator {
        } else {
          Array<Stmt> stmt_seq;
          if (!init_mbarrier_calls_.empty()) {
            auto alloc_mbarrier =
                Evaluate(Call(DataType::Handle(), builtin::create_barriers(),
                              {static_cast<int>(init_mbarrier_calls_.size())}));
            stmt_seq.push_back(alloc_mbarrier);
          }
          auto stmts = prefetch_calls_;
          stmts.insert(stmts.end(), init_mbarrier_calls_.begin(),
                       init_mbarrier_calls_.end());
          auto init_stmt = IfThenElse(
              EQ(iv->var, 0), stmts.size() > 1 ? SeqStmt(stmts) : stmts[0]);
          stmt_seq.push_back(init_stmt);
          if (!init_mbarrier_calls_.empty()) {
            Stmt mem_sync =
                Evaluate(Call(DataType::Handle(), builtin::tvm_storage_sync(),
                              {StringImm("shared")}));
            stmt_seq.push_back(mem_sync);
          }
          stmt_seq.push_back(body);
@@ -98,7 +106,7 @@ class LowerHopperIntrin : public StmtExprMutator {
    return StmtExprMutator::VisitStmt_(op);
  }

  PrimExpr VisitExpr_(const CallNode *call) final {
    if (call->op.same_as(CreateTMADescriptorOp()) ||
        call->op.same_as(CreateTMAIm2ColDescriptorOp())) {
      Var var;
@@ -107,10 +115,12 @@ class LowerHopperIntrin : public StmtExprMutator {
        var = iter->second;
      } else {
        String name = call->args[2].as<Var>().value()->name_hint;
        var = Var(name + "_desc",
                  PointerType(PrimType(cuTensorMapType()), "grid_constant"));
        desc_map_[GetRef<Call>(call)] = var;
        prefetch_calls_.push_back(
            Evaluate(Call(DataType::Handle(), builtin::call_extern(),
                          {StringImm("tl::prefetch_tma_descriptor"), var})));
      }
      return var;
    } else if (call->op.same_as(CreateListofMBarrierOp())) {
@@ -118,24 +128,25 @@ class LowerHopperIntrin : public StmtExprMutator {
      int num_barriers = static_cast<int>(call->args.size());
      for (int i = 0; i < num_barriers; i++) {
        PrimExpr mbarrier = Call(DataType::Handle(), GetMBarrierOp(), {i});
        init_mbarrier_calls_.push_back(Evaluate(
            Call(DataType::Handle(), builtin::ptx_init_barrier_thread_count(),
                 {mbarrier, call->args[i]})));
      }
      return 0;
    } else if (call->op.same_as(SyncThreadsPartialOp())) {
      int barrier_id = init_mbarrier_calls_.size();
      PrimExpr mbarrier =
          Call(DataType::Handle(), GetMBarrierOp(), {barrier_id});
      init_mbarrier_calls_.push_back(Evaluate(
          Call(DataType::Handle(), builtin::ptx_init_barrier_thread_count(),
               {mbarrier, call->args[0]})));
      return Call(DataType::Handle(), SyncThreadsPartialOp(), {mbarrier});
    } else {
      return StmtExprMutator::VisitExpr_(call);
    }
  }

private:
  Array<Stmt> prefetch_calls_;
  Array<Stmt> init_mbarrier_calls_;
  std::unordered_map<Call, Var, StructuralHash, ExprDeepEqual> desc_map_;
@@ -154,5 +165,5 @@ tvm::transform::Pass LowerHopperIntrin() {
TVM_REGISTER_GLOBAL("tl.transform.LowerHopperIntrin")
    .set_body_typed(LowerHopperIntrin);

} // namespace tl
} // namespace tvm
@@ -27,10 +27,10 @@

#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include "../layout/layout.h"
#include "../layout/utils.h"
#include "../op/op.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "loop_partition.h"

namespace tvm {
@@ -38,8 +38,9 @@ namespace tl {

using namespace tir;

static Buffer makeBufferWithLayout(const Buffer &buffer, const Layout &layout) {
  const auto *ptr_type =
      TVM_TYPE_AS(buffer->data->type_annotation, PointerTypeNode);
  Type new_type;
  // convert fragments to normal local buffer
  if (ptr_type->storage_scope == "local.fragment") {
@@ -53,32 +54,33 @@ static Buffer makeBufferWithLayout(const Buffer& buffer, const Layout& layout) {
  } else {
    new_var = Var(buffer->data->name_hint, new_type);
  }
  return Buffer(new_var, buffer->dtype, layout->OutputShape(), {},
                buffer->elem_offset, buffer->name, buffer->data_alignment,
                buffer->offset_factor, buffer->buffer_type);
}

class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
public:
  static PrimFunc Substitute(PrimFunc f) {
    arith::Analyzer analyzer;
    LowerTileOpPass substituter(&analyzer);
    // Trace the buffer map for tvm_access_ptr
    substituter.buffer_map_.insert(f->buffer_map.begin(), f->buffer_map.end());
    for (const auto &[_, buffer] : f->buffer_map) {
      substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(target.defined()) << "LowerTileOpPass: Require the target attribute";
    substituter.target_ = target.value();
    PrimFuncNode *fptr = f.CopyOnWrite();
    fptr->body = substituter.VisitStmt(f->body);
    return f;
  }

private:
  using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer;

  Stmt VisitStmt_(const BlockNode *op) final {
    // Record the mapping from buffer data var to buffer for later lookup
    for (auto buffer : op->alloc_buffers) {
      buffer_map_.insert({buffer->data, buffer});
@@ -91,7 +93,9 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
    }
    Map<Var, Layout> vmap;
    if (op->annotations.count(attr::kLayoutMap)) {
      auto layout_map = op->annotations.at(attr::kLayoutMap)
                            .as<Map<Buffer, Layout>>()
                            .value();
      for (auto [buffer, layout] : layout_map) {
        buffer_remap_.Set(buffer, makeBufferWithLayout(buffer, layout));
        layout_map_.Set(buffer, layout);
@@ -105,7 +109,8 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
        block_ptr->alloc_buffers.Set(i, buffer_remap_[buffer]);
      }
    }
    for (const auto &buffer : workspaces_)
      block_ptr->alloc_buffers.push_back(buffer);
    workspaces_.clear();
    block_ptr->annotations.erase(attr::kLayoutMap);
    return block;
@@ -113,18 +118,19 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {

  int CheckAndGetBufferRowSize(Buffer buffer) {
    CHECK(buffer->shape.size() >= 2)
        << "The dimension of Buffer \"" << buffer->name << "\" with shape "
        << buffer->shape << " should be at least 2";

    auto dim = buffer->shape.size();
    auto buffer_row_size = buffer->shape[dim - 1].as<IntImmNode>()->value;
    return buffer_row_size;
  }

  PrimExpr HandleAccessPtrAndOffset(PrimExpr access_ptr,
                                    Optional<PrimExpr> offset = NullOpt,
                                    DataType dtype = DataType::Int(32)) {
    // The 2th arg of T.tvm_access_ptr call is offset, we set it to 0 and
    // accumulate it to smem_offset
    CHECK(access_ptr->IsInstance<CallNode>())
        << "Invalid access ptr for permuted layout: " << access_ptr;
    auto access_ptr_call = Downcast<Call>(access_ptr);
@@ -136,8 +142,10 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
    Array<PrimExpr> shape = load->buffer->shape;
    CHECK_EQ(indices.size(), shape.size())
        << "Indices size and shape size must match for general N-dimensional "
           "buffer "
        << "but got indices size: " << indices.size()
        << " and shape size: " << shape.size();

    PrimExpr elem_offset = 0;
    PrimExpr stride = 1;
@@ -147,13 +155,16 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
      stride *= shape[i];
    }

    PrimExpr smem_offset =
        elem_offset + (offset.defined() ? offset.value() : 0);

    auto new_buffer = buffer_remap_[load->buffer];
    auto buffer_map_iter =
        buffer_map_.find(Downcast<Var>(load->buffer->data));
    CHECK(buffer_map_iter != buffer_map_.end())
        << "The buffer corresponding to data Var " << access_ptr_call->args[0]
        << " is not found";

    int buffer_row_size = CheckAndGetBufferRowSize(buffer_map_iter->second);
    (void)buffer_row_size;
@@ -163,11 +174,13 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
    PrimExpr remaining_offset = smem_offset;
    for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
      multi_dim_indices.insert(multi_dim_indices.begin(),
                               floormod(remaining_offset, shape[i]));
      remaining_offset = floordiv(remaining_offset, shape[i]);
    }

    auto forward_indices =
        layout_map_[load->buffer]->Forward(multi_dim_indices);
    PrimExpr new_offset = 0;
    PrimExpr stride_offset = 1;
    for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
@@ -191,8 +204,9 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
    return access_ptr_call;
  }

  PrimExpr VisitExpr_(const tir::CallNode *op) final {
    if (!op->op.same_as(builtin::ptx_ldmatrix()) &&
        !op->op.same_as(builtin::mma_store())) {
      return Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
    } else {
      is_ptx_ = true;
@@ -212,15 +226,18 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
      BufferLoad load = Downcast<BufferLoad>(address_of_call->args[0]);

      if (buffer_remap_.count(load->buffer)) {
        auto new_access_ptr =
            HandleAccessPtrAndOffset(access_ptr, smem_offset, call->dtype);
        auto new_call = call.CopyOnWrite();
        new_call->args.Set(5, new_access_ptr);
        new_call->args.Set(6, IntImm(smem_offset->dtype, 0));
      }
    } else if (call->op.same_as(builtin::mma_store())) {
      // because we will directly store result to Buffer instead of calling
      // mma_store now
      auto access_ptr = call->args[2];
      auto new_access_ptr =
          HandleAccessPtrAndOffset(access_ptr, NullOpt, call->dtype);
      auto new_call = call.CopyOnWrite();
      new_call->args.Set(2, new_access_ptr);
    } else {
@@ -230,7 +247,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
    return call;
  }

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    auto load = Downcast<BufferLoad>(IRMutatorWithAnalyzer::VisitExpr_(op));
    if (is_ptx_) {
      return load;
@@ -243,7 +260,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
    return load;
  }

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    auto store = Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_(op));
    if (buffer_remap_.count(store->buffer)) {
      auto new_indices = layout_map_[store->buffer]->Forward(store->indices);
@@ -253,36 +270,40 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
    return store;
  }

  PrimExpr VisitExpr_(const VarNode *op) final {
    auto var = Downcast<Var>(IRMutatorWithAnalyzer::VisitExpr_(op));
    if (buffer_data_to_buffer_.count(var)) {
      auto buffer = buffer_data_to_buffer_[var];
      if (buffer_remap_.count(buffer))
        return buffer_remap_[buffer]->data;
    }
    return var;
  }

  Stmt VisitStmt_(const EvaluateNode *op) final {
    const CallNode *call = op->value.as<CallNode>();
    // Do not analysis the call node to the global function.
    if (call && call->op.as<GlobalVarNode>())
      return Downcast<Evaluate>(IRMutatorWithAnalyzer::VisitStmt_(op));

    auto tile_op = ParseOperator(GetRef<Stmt>(op), buffer_data_to_buffer_);
    if (tile_op == nullptr)
      return IRMutatorWithAnalyzer::VisitStmt_(op);
    AddWorkspaceCallback callback = [this](int num_elem, DataType dtype) {
      auto workspace =
          decl_buffer({PrimExpr(num_elem)}, dtype, "workspace", "shared.dyn");
      workspaces_.push_back(workspace);
      return workspace.access_ptr(2); // write
    };

    auto lowered =
        tile_op->Lower(LowerArgs{target_, thread_block_size_, thread_var_,
                                 callback, layout_map_, buffer_remap_},
                       analyzer_);
    return IRMutatorWithAnalyzer::VisitStmt(lowered);
  }

  Stmt VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      ICHECK_NE(iv->thread_tag.length(), 0U);
@@ -321,7 +342,7 @@ tvm::transform::Pass LowerTileOp() {
}

TVM_REGISTER_GLOBAL("tl.transform.LowerTileOp").set_body_typed(LowerTileOp);

} // namespace transform
} // namespace tl
} // namespace tvm
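
Both passes register themselves with TVM's global registry, so they compose like any other tir transform. A hedged sketch of invoking them (namespaces inferred from the registrations above; `mod` is assumed to be an IRModule built elsewhere; not code from this commit):

// Hypothetical pipeline composing the passes registered in these files.
tvm::transform::Sequential seq({tvm::tl::transform::LowerTileOp(),
                                tvm::tl::LowerHopperIntrin()});
tvm::IRModule lowered = seq(mod);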
...@@ -38,22 +38,23 @@ using namespace tir; ...@@ -38,22 +38,23 @@ using namespace tir;
enum class Role { kConsumer, kProducer, kBoth }; enum class Role { kConsumer, kProducer, kBoth };
class WarpSpecializedRoleMarker_ : public StmtVisitor { class WarpSpecializedRoleMarker_ : public StmtVisitor {
public: public:
WarpSpecializedRoleMarker_(Map<Var, Buffer> buffer_data_to_buffer) WarpSpecializedRoleMarker_(Map<Var, Buffer> buffer_data_to_buffer)
: buffer_data_to_buffer_(buffer_data_to_buffer) {} : buffer_data_to_buffer_(buffer_data_to_buffer) {}
Role GetRole(const StmtNode* stmt) const { Role GetRole(const StmtNode *stmt) const {
auto it = map_.find(stmt); auto it = map_.find(stmt);
ICHECK(it != map_.end()); ICHECK(it != map_.end());
return it->second; return it->second;
} }
Role GetRole(const Stmt& stmt) const { return GetRole(stmt.get()); } Role GetRole(const Stmt &stmt) const { return GetRole(stmt.get()); }
void VisitStmt_(const EvaluateNode* op) final { void VisitStmt_(const EvaluateNode *op) final {
Role role = Role::kConsumer; Role role = Role::kConsumer;
if (auto call = op->value.as<CallNode>()) { if (auto call = op->value.as<CallNode>()) {
if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) { if (call->op.same_as(TMALoadOp()) ||
call->op.same_as(TMALoadIm2ColOp())) {
role = Role::kProducer; role = Role::kProducer;
has_bulk_copy_ = true; has_bulk_copy_ = true;
} }
...@@ -61,8 +62,9 @@ class WarpSpecializedRoleMarker_ : public StmtVisitor { ...@@ -61,8 +62,9 @@ class WarpSpecializedRoleMarker_ : public StmtVisitor {
SetRole(op, role); SetRole(op, role);
} }
void VisitStmt_(const BufferStoreNode* op) final { void VisitStmt_(const BufferStoreNode *op) final {
bool is_shared_store = op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared"; bool is_shared_store =
op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared";
if (!is_shared_store) { if (!is_shared_store) {
SetRole(op, Role::kConsumer); SetRole(op, Role::kConsumer);
return; return;
...@@ -80,11 +82,12 @@ class WarpSpecializedRoleMarker_ : public StmtVisitor { ...@@ -80,11 +82,12 @@ class WarpSpecializedRoleMarker_ : public StmtVisitor {
break; break;
} }
} }
if (role == Role::kProducer) has_simt_copy_ = true; if (role == Role::kProducer)
has_simt_copy_ = true;
SetRole(op, role); SetRole(op, role);
} }
void VisitStmt_(const SeqStmtNode* op) final { void VisitStmt_(const SeqStmtNode *op) final {
StmtVisitor::VisitStmt_(op); StmtVisitor::VisitStmt_(op);
auto role = GetRole(op->seq[0]); auto role = GetRole(op->seq[0]);
for (auto stmt : op->seq) { for (auto stmt : op->seq) {
...@@ -96,48 +99,48 @@ class WarpSpecializedRoleMarker_ : public StmtVisitor { ...@@ -96,48 +99,48 @@ class WarpSpecializedRoleMarker_ : public StmtVisitor {
SetRole(op, role); SetRole(op, role);
} }
void VisitStmt_(const IfThenElseNode* op) final { void VisitStmt_(const IfThenElseNode *op) final {
StmtVisitor::VisitStmt_(op); StmtVisitor::VisitStmt_(op);
auto role = GetRole(op->then_case); auto role = GetRole(op->then_case);
if (op->else_case.defined()) { if (op->else_case.defined()) {
auto role_else = GetRole(op->else_case.value()); auto role_else = GetRole(op->else_case.value());
if (role != role_else) role = Role::kBoth; if (role != role_else)
role = Role::kBoth;
} }
SetRole(op, role); SetRole(op, role);
} }
void VisitStmt_(const BlockRealizeNode* op) final { void VisitStmt_(const BlockRealizeNode *op) final {
StmtVisitor::VisitStmt_(op); StmtVisitor::VisitStmt_(op);
SetRole(op, GetRole(op->block)); SetRole(op, GetRole(op->block));
} }
template <class NodeType> template <class NodeType> void HandleBodyStmt(const NodeType *op) {
void HandleBodyStmt(const NodeType* op) {
StmtVisitor::VisitStmt_(op); StmtVisitor::VisitStmt_(op);
SetRole(op, GetRole(op->body)); SetRole(op, GetRole(op->body));
} }
void VisitStmt_(const ForNode* op) final { HandleBodyStmt(op); } void VisitStmt_(const ForNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const LetStmtNode* op) final { HandleBodyStmt(op); } void VisitStmt_(const LetStmtNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const AttrStmtNode* op) final { HandleBodyStmt(op); } void VisitStmt_(const AttrStmtNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const AssertStmtNode* op) final { HandleBodyStmt(op); } void VisitStmt_(const AssertStmtNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const BlockNode* op) final { HandleBodyStmt(op); } void VisitStmt_(const BlockNode *op) final { HandleBodyStmt(op); }
bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; } bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; }
bool HasSimtCopy() { return has_simt_copy_; } bool HasSimtCopy() { return has_simt_copy_; }
private: private:
void SetRole(const StmtNode* stmt, Role role) { map_[stmt] = role; } void SetRole(const StmtNode *stmt, Role role) { map_[stmt] = role; }
Map<Var, Buffer> buffer_data_to_buffer_; Map<Var, Buffer> buffer_data_to_buffer_;
std::unordered_map<const StmtNode*, Role> map_; std::unordered_map<const StmtNode *, Role> map_;
bool has_simt_copy_ = false; bool has_simt_copy_ = false;
bool has_bulk_copy_ = false; bool has_bulk_copy_ = false;
}; };
class MultiVersionBufferRewriter : public StmtExprMutator { class MultiVersionBufferRewriter : public StmtExprMutator {
public: public:
static PrimFunc Substitute(PrimFunc& f) { static PrimFunc Substitute(PrimFunc &f) {
auto rewriter = MultiVersionBufferRewriter(); auto rewriter = MultiVersionBufferRewriter();
rewriter.buffer_lca_ = DetectBufferAccessLCA(f); rewriter.buffer_lca_ = DetectBufferAccessLCA(f);
for (auto [buffer, _] : rewriter.buffer_lca_) { for (auto [buffer, _] : rewriter.buffer_lca_) {
...@@ -148,40 +151,45 @@ class MultiVersionBufferRewriter : public StmtExprMutator { ...@@ -148,40 +151,45 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
return f; return f;
} }
private: private:
MultiVersionBufferRewriter() = default; MultiVersionBufferRewriter() = default;
Array<Buffer> GetVersionedBuffers(Array<Stmt> seq_stmt, Array<Buffer> scoped_buffers) { Array<Buffer> GetVersionedBuffers(Array<Stmt> seq_stmt,
Array<Buffer> scoped_buffers) {
std::vector<Role> roles; std::vector<Role> roles;
Array<Array<BufferRegion>> reads, writes; Array<Array<BufferRegion>> reads, writes;
auto marker = WarpSpecializedRoleMarker_(buffer_data_to_buffer_); auto marker = WarpSpecializedRoleMarker_(buffer_data_to_buffer_);
for (auto stmt : seq_stmt) { for (auto stmt : seq_stmt) {
marker(stmt); marker(stmt);
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ stmt); Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
/*name_hint=*/"", /*body*/ stmt);
auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_); auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_);
reads.push_back(std::move(access[0])); reads.push_back(std::move(access[0]));
writes.push_back(std::move(access[1])); writes.push_back(std::move(access[1]));
roles.push_back(marker.GetRole(stmt)); roles.push_back(marker.GetRole(stmt));
} }
std::unordered_set<const BufferNode*> consumer_used, producer_used; std::unordered_set<const BufferNode *> consumer_used, producer_used;
for (size_t i = 0; i < seq_stmt.size(); i++) { for (size_t i = 0; i < seq_stmt.size(); i++) {
if (roles[i] == Role::kProducer) { if (roles[i] == Role::kProducer) {
for (BufferRegion br : writes[i]) producer_used.insert(br->buffer.get()); for (BufferRegion br : writes[i])
producer_used.insert(br->buffer.get());
} else { } else {
for (BufferRegion br : reads[i]) consumer_used.insert(br->buffer.get()); for (BufferRegion br : reads[i])
consumer_used.insert(br->buffer.get());
} }
} }
Array<Buffer> versioned_buffers; Array<Buffer> versioned_buffers;
for (Buffer buffer : scoped_buffers) { for (Buffer buffer : scoped_buffers) {
if (consumer_used.count(buffer.get()) && producer_used.count(buffer.get())) { if (consumer_used.count(buffer.get()) &&
producer_used.count(buffer.get())) {
versioned_buffers.push_back(buffer); versioned_buffers.push_back(buffer);
} }
} }
return versioned_buffers; return versioned_buffers;
} }
static Buffer RewriteAllocBuffer(const Buffer& buffer, int num_versions) { static Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get())); ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions)); new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
if (new_buffer->strides.size()) { if (new_buffer->strides.size()) {
...@@ -192,8 +200,9 @@ class MultiVersionBufferRewriter : public StmtExprMutator { ...@@ -192,8 +200,9 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
return Buffer(new_buffer); return Buffer(new_buffer);
} }
  Stmt VisitStmt_(const BlockRealizeNode *op) final {
    BlockRealize block_realize =
        Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
    Block block = block_realize->block;
    Array<Buffer> alloc_buffers;
    for (auto buffer : block->alloc_buffers) {

@@ -209,24 +218,27 @@ class MultiVersionBufferRewriter : public StmtExprMutator {

    return block_realize;
  }

  Stmt VisitStmt_(const ForNode *op) final {
    auto num_stages_anno = op->annotations.Get("num_stages");
    if (!num_stages_anno.defined())
      return StmtExprMutator::VisitStmt_(op);

    ICHECK(num_stages_anno.as<IntImmNode>());
    int num_stages = static_cast<int>(num_stages_anno.as<IntImmNode>()->value);

    const SeqStmtNode *pipeline_body_seq = op->body.as<SeqStmtNode>();
    CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
                                "should be SeqStmt, got "
                             << op->body->GetTypeKey();

    Array<Buffer> scoped_buffers = {};
    for (auto [buffer, stmt] : buffer_lca_) {
      if (stmt.defined() && stmt.value().get() == op)
        scoped_buffers.push_back(buffer);
    }

    Array<Buffer> versioned_buffers =
        GetVersionedBuffers(pipeline_body_seq->seq, scoped_buffers);
    for (auto buffer : versioned_buffers) {
      Var buffer_var = buffer->data;

@@ -239,33 +251,33 @@ class MultiVersionBufferRewriter : public StmtExprMutator {

    return for_node;
  }

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    auto it = buffer_remap_.find(load->buffer);
    if (it == buffer_remap_.end()) {
      return std::move(load);
    }
    const Buffer &new_buffer = (*it).second;
    auto *n = load.CopyOnWrite();
    n->buffer = new_buffer;
    n->indices.insert(n->indices.begin(), version_index_);
    return std::move(load);
  }

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    auto it = buffer_remap_.find(store->buffer);
    if (it == buffer_remap_.end()) {
      return std::move(store);
    }
    const Buffer &new_buffer = (*it).second;
    auto *n = store.CopyOnWrite();
    n->buffer = new_buffer;
    n->indices.insert(n->indices.begin(), version_index_);
    return std::move(store);
  }

  PrimExpr VisitExpr_(const CallNode *op) final {
    Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
    if (call->op.same_as(builtin::tvm_access_ptr())) {
      return RewriteBufferAccess(call, {1});

@@ -273,20 +285,23 @@ class MultiVersionBufferRewriter : public StmtExprMutator {

    return call;
  }
  PrimExpr RewriteBufferAccess(const Call &call,
                               const std::vector<int> arg_indices) {
    auto product = [](const Array<PrimExpr> &input) {
      return foldl(
          [](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
          make_const(DataType::Int(32), 1), input);
    };
    Array<PrimExpr> new_args = call->args;
    for (int i : arg_indices) {
      auto buffer_var = Downcast<Var>(call->args[i]);
      if (!buffer_data_to_buffer_.count(buffer_var))
        continue;
      const Buffer &buffer = buffer_data_to_buffer_[buffer_var];
      auto it = buffer_remap_.find(buffer);
      if (it != buffer_remap_.end()) {
        const Buffer &new_buffer = (*it).second;
        const PrimExpr &old_index = call->args[i + 1];
        PrimExpr offset;
        if (new_buffer->strides.empty()) {
          offset = product(buffer->shape);

@@ -318,5 +333,5 @@ tvm::transform::Pass MultiVersionBuffer() {

TVM_REGISTER_GLOBAL("tl.transform.MultiVersionBuffer")
    .set_body_typed(MultiVersionBuffer);

} // namespace tl
} // namespace tvm
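// Hedged usage sketch (include paths and module setup assumed): a registered
// tl pass is an ordinary tvm::transform::Pass, so it can be applied to an
// IRModule directly or composed with tvm::transform::Sequential.
//
//   #include <tvm/ir/transform.h>
//   tvm::IRModule ApplyMultiVersioning(tvm::IRModule mod) {
//     tvm::transform::Pass pass = tvm::tl::MultiVersionBuffer();
//     return pass(mod);
//   }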
@@ -56,22 +56,23 @@ bool MayConflict(Region region1, Region region2) {

  return true;
}
} // namespace

class PipelinePlanner : public StmtExprMutator {
public:
  static Stmt Substitute(const PrimFunc &f) {
    PipelinePlanner substituter;
    for (const auto &[_, buffer] : f->buffer_map) {
      substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(target.defined())
        << "Pipeline_Planning: Require the target attribute";
    substituter.target_ = target.value();
    return substituter.VisitStmt(f->body);
  }

private:
  PipelinePlanner() = default;

  struct PipelineStageInfo {

@@ -83,8 +84,10 @@ class PipelinePlanner : public StmtExprMutator {

  };

  PipelineStageInfo MakePipelineStageInfo(Stmt stmt, int idx) {
    Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
                /*body*/ stmt);
    Array<Array<BufferRegion>> access =
        GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
    PipelineStageInfo pinfo;
    pinfo.reads = std::move(access[0]);
    pinfo.writes = std::move(access[1]);
@@ -93,22 +96,25 @@ class PipelinePlanner : public StmtExprMutator {

    // a copy stage should have exactly one read and one write
    if (pinfo.reads.size() == 1 && pinfo.writes.size() == 1) {
      for (auto region : pinfo.reads)
        if (region->buffer.scope() == "global")
          pinfo.copy_stage = true;
      for (auto region : pinfo.writes)
        if (region->buffer.scope() == "global")
          pinfo.copy_stage = true;
    }
    return std::move(pinfo);
  }
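  // Minimal standalone sketch of the classification above (plain structs stand
  // in for BufferRegion; "global" is the scope tested by the loops): a stage is
  // a copy stage iff it has exactly one read and one write and either side
  // touches global memory.
  //
  //   struct RegionInfo { std::string scope; };
  //   bool IsCopyStage(const std::vector<RegionInfo> &reads,
  //                    const std::vector<RegionInfo> &writes) {
  //     if (reads.size() != 1 || writes.size() != 1) return false;
  //     return reads[0].scope == "global" || writes[0].scope == "global";
  //   }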
  Stmt VisitStmt_(const ForNode *loop) final {
    auto num_stages_anno = loop->annotations.Get("num_stages");
    if (!num_stages_anno.defined())
      return StmtExprMutator::VisitStmt_(loop);
    int num_stages = num_stages_anno.as<IntImmNode>()->value;

    Stmt pipeline_body{nullptr};
    if (const auto *realize = loop->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
      for (const auto &buffer : block->alloc_buffers) {
        ICHECK(buffer->IsInstance<BufferNode>());
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }

@@ -116,10 +122,10 @@ class PipelinePlanner : public StmtExprMutator {

    } else {
      pipeline_body = loop->body;
    }
    const SeqStmtNode *pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
    CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
                                "should be SeqStmt, got "
                             << loop->body->GetTypeKey();
    CHECK(num_stages >= 1);
    CHECK(loop->kind == ForKind::kSerial);

@@ -130,21 +136,28 @@ class PipelinePlanner : public StmtExprMutator {
    }

    // analyze the use-def chains
    for (auto &pinfo : pipeline_stage_infos) {
      for (int i = pinfo.original_order + 1;
           i < static_cast<int>(pipeline_body_seq->size()); i++) {
        if (!pinfo.copy_stage)
          continue;
        for (const BufferRegion &read : pipeline_stage_infos[i].reads) {
          if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(),
                           [&](const BufferRegion &r) {
                             return r->buffer == read->buffer &&
                                    MayConflict(r->region, read->region);
                           }) != pinfo.writes.end()) {
            pinfo.last_use_stage = std::max(pinfo.last_use_stage, i);
          }
        }
        for (const BufferRegion &write : pipeline_stage_infos[i].writes) {
          if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(),
                           [&](const BufferRegion &r) {
                             return r->buffer == write->buffer &&
                                    MayConflict(r->region, write->region);
                           }) != pinfo.writes.end()) {
            CHECK(false) << "Can't handle multiple writes to overlapping "
                            "buffer regions in the pipeline planning pass: "
                         << pipeline_body_seq->seq[pinfo.original_order];
          }
        }
@@ -154,28 +167,32 @@ class PipelinePlanner : public StmtExprMutator {

    // Make stages and orders
    int order_idx = 0;
    for (auto &pinfo : pipeline_stage_infos) {
      if (pinfo.copy_stage && pinfo.last_use_stage != -1)
        continue;
      pinfo.order = order_idx++;
      pinfo.stage = num_stages;
      for (auto &pinfo_1 : pipeline_stage_infos) {
        if (pinfo_1.copy_stage &&
            pinfo_1.last_use_stage == pinfo.original_order) {
          pinfo_1.order = order_idx++;
          pinfo_1.stage = 0;
        }
      }
    }
    ICHECK(size_t(order_idx) == pipeline_stage_infos.size())
        << "The number of assigned orders should be equal to the number of "
           "pipeline stages. "
        << "Got " << order_idx << " orders and " << pipeline_stage_infos.size()
        << " pipeline stages.";
    // If all the copies sit at the end of the order, we can move them to the
    // beginning of the order and shrink the stage offset by 1.
    int copy_stage_at_end = [&]() {
      int copy_stage_cnt = 0;
      int copy_order_min = pipeline_stage_infos.size();
      int non_copy_order_max = 0;
      for (auto &pinfo : pipeline_stage_infos) {
        if (pinfo.copy_stage) {
          copy_stage_cnt++;
          copy_order_min = std::min(copy_order_min, pinfo.order);

@@ -183,19 +200,22 @@ class PipelinePlanner : public StmtExprMutator {

          non_copy_order_max = std::max(non_copy_order_max, pinfo.order);
        }
      }
      if (copy_order_min > non_copy_order_max)
        return copy_stage_cnt;
      return -1;
    }();

    if (copy_stage_at_end > 0 && num_stages >= 2) {
      for (auto &pinfo : pipeline_stage_infos) { // move copies to the beginning
        pinfo.order =
            (pinfo.order + copy_stage_at_end) % pipeline_stage_infos.size();
        if (!pinfo.copy_stage)
          pinfo.stage--;
      }
    }
    // Finally, make the pipeline annotation
    Map<String, ObjectRef> annotations;
    for (const auto &[key, value] : loop->annotations) {
      if (key != "num_stages") {
        annotations.Set(key, value);
      }

@@ -204,7 +224,7 @@ class PipelinePlanner : public StmtExprMutator {

    std::vector<Integer> orders, stages;
    orders.reserve(pipeline_stage_infos.size());
    stages.reserve(pipeline_stage_infos.size());
    for (auto &pinfo : pipeline_stage_infos) {
      orders.push_back(pinfo.order);
      stages.push_back(pinfo.stage);
    }

@@ -212,18 +232,19 @@ class PipelinePlanner : public StmtExprMutator {

    annotations.Set(tir::attr::software_pipeline_stage, Array<Integer>(stages));
    annotations.Set(tir::attr::software_pipeline_order, Array<Integer>(orders));
    if (TargetHasAsyncCopy(target_))
      annotations.Set(tir::attr::software_pipeline_async_stages,
                      Array<Integer>{0});
    return For(loop->loop_var, loop->min, loop->extent, loop->kind, loop->body,
               loop->thread_binding, annotations);
  }
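  // Worked trace (hypothetical loop body, assumed buffer dependencies): for a
  // loop annotated with num_stages = 2 whose body is [copy_A, copy_B, gemm],
  // the gemm is ordered first (order 0, stage 2) and the two copies it
  // consumes get orders 1 and 2 with stage 0. Both copies then sit at the end
  // of the order, so the rotation above remaps them to the front and
  // decrements the non-copy stage, yielding
  //   software_pipeline_order = [0, 1, 2]  (copy_A, copy_B, gemm)
  //   software_pipeline_stage = [0, 0, 1]
  // i.e. copies are issued one pipelined iteration ahead of the gemm.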
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const auto &buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
    for (const auto &buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.erase(buffer->data);
    }
    return std::move(block);

@@ -236,14 +257,15 @@ class PipelinePlanner : public StmtExprMutator {

tvm::transform::Pass PipelinePlanning() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    PrimFuncNode *fptr = f.CopyOnWrite();
    fptr->body = PipelinePlanner::Substitute(f);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.PipelinePlanning", {});
}

TVM_REGISTER_GLOBAL("tl.transform.PipelinePlanning")
    .set_body_typed(PipelinePlanning);

} // namespace tl
} // namespace tvm
@@ -6,15 +6,15 @@

 * \brief Remove useless parameters of TL PrimFunc.
 */

#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include "arith/ir_mutator_with_analyzer.h"
#include "tir/analysis/control_flow_graph.h"
#include "tir/analysis/var_use_def_analysis.h"

namespace tvm {
namespace tl {
@@ -31,19 +31,19 @@ struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {

  TVM_DECLARE_ATTRS(SimplifyConfigNode, "tl.transform.SimplifyConfig") {
    TVM_ATTR_FIELD(transitively_prove_inequalities)
        .describe("If true, simplify conditionals with transitive combinations "
                  "of scoped constraints")
        .set_default(false);
    TVM_ATTR_FIELD(propagate_knowns_to_prove_conditional)
        .describe("If true, known buffer values are propagated and used to "
                  "statically prove conditionals")
        .set_default(false);
    TVM_ATTR_FIELD(propagate_knowns_to_simplify_expressions)
        .describe("If true, known buffer values are propagated and used to "
                  "replace BufferLoad wherever possible")
        .set_default(false);
    TVM_ATTR_FIELD(convert_boolean_to_and_of_ors)
@@ -51,102 +51,103 @@ struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {

        .set_default(false);
    TVM_ATTR_FIELD(apply_constraints_to_boolean_branches)
        .describe("If true, simplify each branch of AND/OR "
                  "under the constraints provided by the other branch")
        .set_default(false);
  }
  RewriteSimplifier::Extension GetEnabledExtensions() const {
    RewriteSimplifier::Extension flags = RewriteSimplifier::kNone;
    if (transitively_prove_inequalities) {
      flags = RewriteSimplifier::Extension(
          flags | RewriteSimplifier::kTransitivelyProveInequalities);
    }
    if (convert_boolean_to_and_of_ors) {
      flags = RewriteSimplifier::Extension(
          flags | RewriteSimplifier::kConvertBooleanToAndOfOrs);
    }
    if (apply_constraints_to_boolean_branches) {
      flags = RewriteSimplifier::Extension(
          flags | RewriteSimplifier::kApplyConstraintsToBooleanBranches);
    }
    return flags;
  }
};
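// Standalone sketch of the bitmask pattern used by GetEnabledExtensions (the
// enumerator values here are illustrative, not RewriteSimplifier's actual
// ones): each enabled option ORs one extension bit into the returned flags.
//
//   enum Extension : unsigned {
//     kNone = 0,
//     kTransitivelyProveInequalities = 1 << 0,
//     kConvertBooleanToAndOfOrs = 1 << 1,
//     kApplyConstraintsToBooleanBranches = 1 << 2,
//   };
//   // Enabling the first and third options yields flags == 0b101.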
std::unordered_set<const BufferNode *>
CollectUsedBuffers(const PrimFunc &func) {
  struct Visitor : StmtExprVisitor {
    using StmtExprVisitor::VisitExpr_;
    using StmtExprVisitor::VisitStmt_;
    Visitor(PrimFunc func) : func(func) {}

    void VisitExpr_(const CallNode *op) override {
      for (const auto &arg : op->args) {
        for (const auto &it : func->buffer_map) {
          if (Downcast<PrimExpr>(it.second.get()->data).same_as(arg)) {
            used_in_buffer_def_.insert(it.second.get());
          }
        }
      }
      StmtExprVisitor::VisitExpr_(op);
    }

    void VisitExpr_(const BufferLoadNode *op) override {
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitExpr_(op);
    }
    void VisitStmt_(const BufferStoreNode *op) override {
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitStmt_(op);
    }
    void VisitStmt_(const BlockNode *op) override {
      for (const auto &buffer : op->alloc_buffers) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
        }
      }
      for (const auto &buffer : op->reads) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
        }
      }
      for (const auto &buffer : op->writes) {
        for (const auto &it : func->buffer_map) {
          if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
            used_in_buffer_def_.insert(it.second.get());
          }
        }
      }
      StmtExprVisitor::VisitStmt_(op);
    }

    void VisitBuffer(const Buffer &buf) {
      // Collect buffers that should remain defined
      VarUseDefAnalyzer usage(Array<Var>{});
      usage(buf->data);
      for (const auto &dim : buf->shape) {
        usage(dim);
      }
      for (const auto &dim : buf->strides) {
        usage(dim);
      }
      usage(buf->elem_offset);
      for (const auto &buffer : usage.buffer_use_count_) {
        if (buffer.second >= 1) {
          used_in_buffer_def_.insert(buffer.first);
        }
      }
      for (const auto &buffer : usage.undefined_buffers_) {
        used_in_buffer_def_.insert(buffer.get());
      }
    }

    PrimFunc func;
    std::unordered_set<const BufferNode *> used_in_buffer_def_;
  };

  Visitor visitor(func);

@@ -154,41 +155,42 @@ std::unordered_set<const BufferNode*> CollectUsedBuffers(const PrimFunc& func) {

  return visitor.used_in_buffer_def_;
}
/* \brief Utility function to collect vars that should be retained. Used in
 * LetStmt only.
 */
std::unordered_set<const VarNode *>
CollectVarsUsedInBufferDefinition(const Stmt &stmt) {
  struct Visitor : StmtExprVisitor {
    using StmtExprVisitor::VisitExpr_;
    using StmtExprVisitor::VisitStmt_;
    void VisitExpr_(const BufferLoadNode *op) override {
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitExpr_(op);
    }
    void VisitStmt_(const BufferStoreNode *op) override {
      VisitBuffer(op->buffer);
      StmtExprVisitor::VisitStmt_(op);
    }

    void VisitBuffer(const Buffer &buf) {
      // Collect variables that should remain defined
      VarUseDefAnalyzer usage(Array<Var>{});
      usage(buf->data);
      for (const auto &dim : buf->shape) {
        usage(dim);
      }
      for (const auto &dim : buf->strides) {
        usage(dim);
      }
      usage(buf->elem_offset);
      // Track for use in LetStmtNode mutator
      for (const auto &var : usage.undefined_) {
        used_in_buffer_def_.insert(var.get());
      }
    }

    std::unordered_set<const VarNode *> used_in_buffer_def_;
  };

  Visitor visitor;

@@ -197,20 +199,21 @@ std::unordered_set<const VarNode*> CollectVarsUsedInBufferDefinition(const Stmt&

}
class SimplifyConfig : public Attrs {
public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs,
                                            SimplifyConfigNode);
};

TVM_REGISTER_NODE_TYPE(SimplifyConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig);
class StmtSimplifier : public IRMutatorWithAnalyzer {
public:
  static PrimFunc Apply(PrimFunc func, Analyzer *analyzer,
                        Optional<SimplifyConfig> config_opt = NullOpt) {
    auto config = config_opt.value_or(AttrsWithDefaultValues<SimplifyConfig>());
    analyzer->rewrite_simplify.SetEnabledExtensions(
        config->GetEnabledExtensions());

    std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
    if (config->propagate_knowns_to_prove_conditional ||

@@ -218,7 +221,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {

      touch_pattern = ControlFlowGraph(func->body);
    }

    std::unordered_set<const VarNode *> used_in_buffer_def =
        CollectVarsUsedInBufferDefinition(func->body);
    StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern),
                              std::move(used_in_buffer_def));

@@ -232,41 +235,44 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {

    Array<Var> new_params;
    Map<Var, Buffer> new_buffer_map;
    // Check whether each buffer is used
    for (const auto &var : func->params) {
      if (func->buffer_map.find(var) != func->buffer_map.end()) {
        if (simplifier.used_buffers_.find(func->buffer_map[var].get()) !=
            simplifier.used_buffers_.end()) {
          new_params.push_back(var);
          new_buffer_map.Set(var, func->buffer_map[var]);
        } else {
          param_updated = true;
        }
      }
    }
    // return func;
    if (param_updated) {
      return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type,
                      new_buffer_map, func->attrs, func->span);
    } else {
      return func;
    }
  }
private:
  explicit StmtSimplifier(
      Analyzer *analyzer, SimplifyConfig config,
      std::optional<ControlFlowGraph> touch_pattern,
      std::unordered_set<const VarNode *> used_in_buffer_def)
      : IRMutatorWithAnalyzer(analyzer), config_(config),
        touch_pattern_(touch_pattern), used_in_buffer_def_(used_in_buffer_def) {
  }

  using Parent = IRMutatorWithAnalyzer;
  using Parent::VisitExpr_;
  using Parent::VisitStmt;
  using Parent::VisitStmt_;

  PrimExpr VisitExpr(const PrimExpr &expr) final {
    if (config_->propagate_knowns_to_simplify_expressions) {
      return touch_pattern_->SimplifyInContext(expr, current_stmt_.value(),
                                               analyzer_);
    } else {
      return analyzer_->Simplify(expr);
    }

@@ -274,7 +280,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {

  Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); }

  Stmt VisitStmt(const Stmt &stmt) override {
    Optional<Stmt> cache = this->current_stmt_;
    this->current_stmt_ = stmt;
    Stmt output = Parent::VisitStmt(stmt);

@@ -282,23 +288,28 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {

    return output;
  }

  Stmt VisitStmt_(const ForNode *op) final {
    analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
    With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
    With<ConstraintContext> ctx2(analyzer_,
                                 op->loop_var < op->min + op->extent);
    return Parent::VisitStmt_(op);
  }
  bool CanInlineLetStmt(const LetStmtNode *op) {
    if (is_const_number(op->value))
      return true;
    if (op->value.as<VarNode>())
      return true;
    // Won't face the deep expression explosion problem as in Let expression.
    // Attempt to inline as much as possible if the value is an integer type
    // (it can appear as an index).
    if (!op->value.dtype().is_int())
      return false;
    return SideEffect(op->value) <= CallEffectKind::kPure;
  }
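  // Illustrative outcomes for CanInlineLetStmt (hypothetical bindings):
  //   let x = 4          -> inline (constant number)
  //   let y = n          -> inline (plain variable)
  //   let i = m * 8 + k  -> inline (pure integer expression, usable as index)
  //   let f = exp(v)     -> keep   (value is not of an integer dtype)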
  Stmt VisitStmt_(const LetStmtNode *op) override {
    PrimExpr value = this->VisitExpr(op->value);
    bool can_inline = CanInlineLetStmt(op);
    if (can_inline) {

@@ -339,7 +350,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {

    }
  }

  Stmt VisitStmt_(const IfThenElseNode *op) override {
    if (Optional<Bool> cond = ProveCondition(op->condition)) {
      if (cond.value()->value) {
        return this->VisitStmt(op->then_case);

@@ -353,7 +364,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {

    }
  }

  PrimExpr VisitExpr_(const CallNode *op) override {
    if (op->op.same_as(builtin::if_then_else())) {
      if (Optional<Bool> cond = ProveCondition(op->args[0])) {
        if (cond.value()->value) {

@@ -366,26 +377,27 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {

    return Parent::VisitExpr_(op);
  }

  PrimExpr VisitExpr_(const VarNode *op) override {
    used_vars_.insert(op);
    return Parent::VisitExpr_(op);
  }

  PrimExpr VisitExpr_(const BufferLoadNode *op) override {
    auto buffer = op->buffer.get();
    if (used_buffers_.find(buffer) == used_buffers_.end()) {
      used_buffers_.insert(buffer);
    }
    return Parent::VisitExpr_(op);
  }
  // eliminate useless stores
  Stmt VisitStmt_(const BufferStoreNode *op) override {
    BufferStore store = Downcast<BufferStore>(Parent::VisitStmt_(op));
    if (const BufferLoadNode *load = store->value.as<BufferLoadNode>()) {
      if (load->buffer->data.same_as(store->buffer->data) &&
          ArrayDeepEqual(load->indices, store->indices) &&
          tir::ExprDeepEqual()(load->buffer->elem_offset,
                               store->buffer->elem_offset) &&
          ArrayDeepEqual(load->buffer->shape, store->buffer->shape) &&
          ArrayDeepEqual(load->buffer->strides, store->buffer->strides)) {
        return Evaluate(0);

@@ -393,13 +405,13 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {

    }
    auto buffer = op->buffer.get();
    if (used_buffers_.find(buffer) == used_buffers_.end()) {
      used_buffers_.insert(buffer);
    }
    return std::move(store);
  }
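  // Example of the elimination above (hypothetical store): a self-assignment
  //   A[i, j] = A[i, j]
  // whose buffer var, indices, elem_offset, shape and strides all match on
  // both sides is replaced by Evaluate(0), i.e. a no-op.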
private:
  bool ArrayDeepEqual(const Array<PrimExpr> &lhs, const Array<PrimExpr> &rhs) {
    if (lhs.size() != rhs.size()) {
      return false;
    }

@@ -420,11 +432,12 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {

    condition = Substitute(condition, non_inlined_bindings_);
    if (config_->propagate_knowns_to_prove_conditional) {
      ICHECK(touch_pattern_.has_value());
      condition = touch_pattern_->SimplifyInContext(
          condition, current_stmt_.value(), analyzer_);
    } else {
      condition = analyzer_->Simplify(condition);
    }
    if (const int64_t *as_int = as_const_int(condition)) {
      return Bool(*as_int);
    } else {
      return NullOpt;

@@ -436,21 +449,20 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {

  Map<Var, PrimExpr> non_inlined_bindings_;
  Optional<Stmt> current_stmt_{NullOpt};
  std::unordered_set<const VarNode *> used_in_buffer_def_;
  std::unordered_set<const VarNode *> used_vars_;
  std::unordered_set<const BufferNode *> used_buffers_;
};
using namespace tir::transform;

tvm::transform::Pass Simplify() {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    arith::Analyzer analyzer;
    auto cfg = ctx->GetConfig<SimplifyConfig>("tl.Simplify");
    return StmtSimplifier::Apply(f, &analyzer, cfg);
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {});
}

TVM_REGISTER_GLOBAL("tl.transform.Simplify").set_body_typed(Simplify);
@@ -25,26 +25,28 @@ namespace tl {

using namespace tir;

class ThreadPartialSyncPlanner : public StorageAccessVisitor {
public:
  explicit ThreadPartialSyncPlanner(StorageScope sync_scope)
      : sync_scope_(sync_scope) {}

  // The syncs inserted before each statement
  std::unordered_set<const Object *> syncs_inserted_;
  std::unordered_map<const Object *, int> partial_syncs_inserted_;

protected:
  bool Enabled(const VarNode *buf, const StorageScope &scope) const final {
    return in_device_env() && scope == sync_scope_;
  }

  // Plan the sync
  std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq,
                                     const ForNode *loop) final {
    // Redirect all "shared.dyn" buffer access to the same buffer var
    // so that the accesses can be planned together.
    Var shared_dyn_buf;
    for (StmtEntry &entry : seq) {
      for (AccessEntry &access : entry.access) {
        if (access.scope.rank == StorageRank::kShared &&
            access.scope.tag == ".dyn" && access.buffer.defined()) {
          if (!shared_dyn_buf.defined()) {
            shared_dyn_buf = access.buffer;
          } else {

@@ -60,7 +62,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
    // If it is a loop, rotate twice to account for loop-carried effects.
    // Simulation-based approach to find dependencies.
    for (size_t i = 0; i < seq.size(); ++i) {
      const StmtEntry &s = seq[i];
      // check if sync before statement is needed.
      bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0);
      // Apply the syncs added already.

@@ -68,7 +70,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {

        reads.clear();
        writes.clear();
      }
      for (const AccessEntry &acc : s.access) {
        if (acc.type == kRead) {
          if (FindConflict(writes, acc, false)) {
            sync_before_stmt = true;

@@ -90,7 +92,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {

        writes.clear();
      }
      // Add the read/write of current statement
      for (const AccessEntry &acc : s.access) {
        if (acc.type == kRead) {
          reads.push_back(acc);
        } else if (acc.type == kWrite) {

@@ -106,11 +108,13 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {

    }
    if (loop != nullptr) {
      for (size_t i = 0; i < seq.size(); ++i) {
        const StmtEntry &s = seq[i];
        if (syncs_inserted_.count(s.stmt) != 0)
          break;
        if (reads.empty() && writes.empty())
          break;
        bool sync_before_stmt = false;
        for (const AccessEntry &acc : s.access) {
          if (acc.type == kRead) {
            if (FindConflict(writes, acc, true)) {
              sync_before_stmt = true;

@@ -141,7 +145,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {

    esync.type = kSync;
    esync.scope = sync_scope_;

    for (const StmtEntry &s : seq) {
      if (syncs_inserted_.count(s.stmt)) {
        if (sync_count != 0) {
          tail.clear();

@@ -150,7 +154,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {

        }
        ++sync_count;
      }
      for (const AccessEntry &acc : s.access) {
        if (acc.type == kSync) {
          if (sync_count != 0) {
            tail.clear();

@@ -170,18 +174,18 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {

    head.insert(head.end(), tail.begin(), tail.end());
    if (loop != nullptr) {
      // clear double buffer flag after a loop is finished.
      for (AccessEntry &e : head) {
        e.double_buffer_write = false;
      }
    }
    return head;
  }
private:
  // find conflicting entry in vec.
  bool FindConflict(const std::vector<AccessEntry> &prev,
                    const AccessEntry &curr, bool loop_carry) {
    for (const AccessEntry &x : prev) {
      if (FindConflict(x, curr, loop_carry)) {
        return true;
      }

@@ -189,7 +193,8 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {

    return false;
  }

  bool FindConflict(const AccessEntry &prev, const AccessEntry &curr,
                    bool loop_carry) {
    // Access to different buffers does not conflict.
    if (!prev.buffer.same_as(curr.buffer)) {
      return false;

@@ -202,21 +207,21 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
    // Even if the accesses have the same index, the indices need to
    // depend on the innermost thread id to avoid a race condition.
    bool depends_on_thread_index = true;
    const VarNode *thread_index_var = nullptr;
    if (!curr.threads.empty()) {
      thread_index_var = curr.threads.back()->var.get();
    }

    for (size_t i = 0; i < prev.touched.size(); i++) {
      const auto &prev_intset = prev.touched[i];
      const auto &curr_intset = curr.touched[i];

      if (prev_intset.IsSinglePoint() && curr_intset.IsSinglePoint()) {
        PrimExpr prev_index = prev_intset.PointValue();
        PrimExpr curr_index = curr_intset.PointValue();
        has_same_index = ExprDeepEqual()(prev_index, curr_index);
        if (thread_index_var != nullptr) {
          auto f_uses_thread_index = [=](const tvm::tir::VarNode *parameter) {
            return parameter == thread_index_var;
          };
          depends_on_thread_index = depends_on_thread_index &&

@@ -246,7 +251,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {

    return true;
  }
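  // Worked example (hypothetical accesses, following the single-point test
  // above): a write to shared[tid] followed by a read of shared[tid] touches
  // one identical point whose index depends on the innermost thread id, so
  // each thread only ever sees its own element and no barrier is required; a
  // write to shared[tid] followed by a read of shared[tid ^ 1] has differing
  // indices, which is a conflict and forces a sync.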
  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == "kWarpSpecializationScope") {
      IfThenElse body = Downcast<IfThenElse>(op->body);
      auto partitions = Downcast<Array<IntImm>>(op->node);

@@ -273,27 +278,31 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {

    }
  }

  void insert_syncs(const Object *obj) {
    // ICHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition";
    if (syncs_inserted_.count(obj))
      return;
    if (num_partial_threads_.defined()) {
      syncs_inserted_.insert(obj);
      partial_syncs_inserted_[obj] =
          static_cast<int>(num_partial_threads_.value()->value);
    } else {
      syncs_inserted_.insert(obj);
    }
  }

private:
  Optional<IntImm> num_partial_threads_;
  // synchronization scope
  StorageScope sync_scope_;
};
// There are cases where a necessary syncthreads is not inserted by
// ThreadPartialSyncInserter. For example, a syncthreads is needed after
// async_wait_queue in the second loop below, but since
// ThreadPartialSyncInserter is not aware of the asynchronous semantics, it
// cannot tell that the syncthreads is needed there.
//
// // Pipeline prologue
// for i in range(125):

@@ -307,21 +316,23 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {

//     async_wait_queue(0, 2 - i):
//       local[...] = shared[(i + 125) % 4]
class ThreadPartialSyncInserter : public StmtExprMutator {
public:
  ThreadPartialSyncInserter(
      StorageScope sync_scope, const std::unordered_set<const Object *> &syncs,
      const std::unordered_map<const Object *, int> &partial_syncs)
      : sync_scope_(sync_scope), syncs_(syncs), partial_syncs_(partial_syncs) {}

  Stmt VisitStmt(const Stmt &stmt) final {
    if (syncs_.size() == 0)
      return stmt;
    if (syncs_.count(stmt.get())) {
      Stmt barrier;
      if (partial_syncs_.count(stmt.get())) {
        auto iter = partial_syncs_.find(stmt.get());
        ICHECK(sync_scope_.rank == StorageRank::kShared);
        barrier = Evaluate(Call(DataType::Int(32), tl::SyncThreadsPartialOp(),
                                {iter->second}));
      } else {
        return StmtExprMutator::VisitStmt(stmt);
      }

@@ -334,11 +345,11 @@ class ThreadPartialSyncInserter : public StmtExprMutator {

    }
  }
private: private:
// data structure. // data structure.
StorageScope sync_scope_; StorageScope sync_scope_;
const std::unordered_set<const Object*>& syncs_; const std::unordered_set<const Object *> &syncs_;
const std::unordered_map<const Object*, int>& partial_syncs_; const std::unordered_map<const Object *, int> &partial_syncs_;
}; };
Stmt ThreadPartialSync(Stmt stmt, std::string storage_scope) { Stmt ThreadPartialSync(Stmt stmt, std::string storage_scope) {
...@@ -346,7 +357,8 @@ Stmt ThreadPartialSync(Stmt stmt, std::string storage_scope) { ...@@ -346,7 +357,8 @@ Stmt ThreadPartialSync(Stmt stmt, std::string storage_scope) {
ThreadPartialSyncPlanner planner(sync_scope); ThreadPartialSyncPlanner planner(sync_scope);
planner(stmt); planner(stmt);
return ThreadPartialSyncInserter(sync_scope, planner.syncs_inserted_, return ThreadPartialSyncInserter(sync_scope, planner.syncs_inserted_,
planner.partial_syncs_inserted_)(std::move(stmt)); planner.partial_syncs_inserted_)(
std::move(stmt));
} }
using namespace tir::transform; using namespace tir::transform;
...@@ -355,15 +367,16 @@ namespace transform { ...@@ -355,15 +367,16 @@ namespace transform {
Pass ThreadPartialSync(String storage_scope) { Pass ThreadPartialSync(String storage_scope) {
auto pass_func = [storage_scope](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [storage_scope](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite(); auto *n = f.CopyOnWrite();
n->body = tl::ThreadPartialSync(std::move(n->body), storage_scope); n->body = tl::ThreadPartialSync(std::move(n->body), storage_scope);
return f; return f;
}; };
return CreatePrimFuncPass(pass_func, 0, "tl.ThreadPartialSync", {}); return CreatePrimFuncPass(pass_func, 0, "tl.ThreadPartialSync", {});
} }
TVM_REGISTER_GLOBAL("tl.transform.ThreadPartialSync").set_body_typed(ThreadPartialSync); TVM_REGISTER_GLOBAL("tl.transform.ThreadPartialSync")
.set_body_typed(ThreadPartialSync);
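// Hedged usage sketch, assuming `mod` is a lowered IRModule; the Pass is
// applied through the standard tvm::transform call operator.
inline IRModule ApplyThreadPartialSync(IRModule mod) {
  auto pass = ThreadPartialSync(String("shared.dyn"));  // scope assumed
  return pass(mod);
}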
} // namespace transform } // namespace transform
} // namespace tir } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -38,22 +38,23 @@ using namespace tir; ...@@ -38,22 +38,23 @@ using namespace tir;
enum class Role { kConsumer, kProducer, kBoth }; enum class Role { kConsumer, kProducer, kBoth };
class WarpSpecializedRoleMarker : public StmtVisitor { class WarpSpecializedRoleMarker : public StmtVisitor {
public: public:
WarpSpecializedRoleMarker(Map<Var, Buffer> buffer_data_to_buffer) WarpSpecializedRoleMarker(Map<Var, Buffer> buffer_data_to_buffer)
: buffer_data_to_buffer_(buffer_data_to_buffer) {} : buffer_data_to_buffer_(buffer_data_to_buffer) {}
Role GetRole(const StmtNode* stmt) const { Role GetRole(const StmtNode *stmt) const {
auto it = map_.find(stmt); auto it = map_.find(stmt);
ICHECK(it != map_.end()); ICHECK(it != map_.end());
return it->second; return it->second;
} }
Role GetRole(const Stmt& stmt) const { return GetRole(stmt.get()); } Role GetRole(const Stmt &stmt) const { return GetRole(stmt.get()); }
void VisitStmt_(const EvaluateNode* op) final { void VisitStmt_(const EvaluateNode *op) final {
Role role = Role::kConsumer; Role role = Role::kConsumer;
if (auto call = op->value.as<CallNode>()) { if (auto call = op->value.as<CallNode>()) {
if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) { if (call->op.same_as(TMALoadOp()) ||
call->op.same_as(TMALoadIm2ColOp())) {
role = Role::kProducer; role = Role::kProducer;
has_bulk_copy_ = true; has_bulk_copy_ = true;
} }
...@@ -61,8 +62,9 @@ class WarpSpecializedRoleMarker : public StmtVisitor { ...@@ -61,8 +62,9 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
SetRole(op, role); SetRole(op, role);
} }
void VisitStmt_(const BufferStoreNode* op) final { void VisitStmt_(const BufferStoreNode *op) final {
bool is_shared_store = op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared"; bool is_shared_store =
op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared";
if (!is_shared_store) { if (!is_shared_store) {
SetRole(op, Role::kConsumer); SetRole(op, Role::kConsumer);
return; return;
...@@ -80,11 +82,12 @@ class WarpSpecializedRoleMarker : public StmtVisitor { ...@@ -80,11 +82,12 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
break; break;
} }
} }
if (role == Role::kProducer) has_simt_copy_ = true; if (role == Role::kProducer)
has_simt_copy_ = true;
SetRole(op, role); SetRole(op, role);
} }
void VisitStmt_(const SeqStmtNode* op) final { void VisitStmt_(const SeqStmtNode *op) final {
StmtVisitor::VisitStmt_(op); StmtVisitor::VisitStmt_(op);
auto role = GetRole(op->seq[0]); auto role = GetRole(op->seq[0]);
for (auto stmt : op->seq) { for (auto stmt : op->seq) {
...@@ -96,41 +99,41 @@ class WarpSpecializedRoleMarker : public StmtVisitor { ...@@ -96,41 +99,41 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
SetRole(op, role); SetRole(op, role);
} }
void VisitStmt_(const IfThenElseNode* op) final { void VisitStmt_(const IfThenElseNode *op) final {
StmtVisitor::VisitStmt_(op); StmtVisitor::VisitStmt_(op);
auto role = GetRole(op->then_case); auto role = GetRole(op->then_case);
if (op->else_case.defined()) { if (op->else_case.defined()) {
auto role_else = GetRole(op->else_case.value()); auto role_else = GetRole(op->else_case.value());
if (role != role_else) role = Role::kBoth; if (role != role_else)
role = Role::kBoth;
} }
SetRole(op, role); SetRole(op, role);
} }
void VisitStmt_(const BlockRealizeNode* op) final { void VisitStmt_(const BlockRealizeNode *op) final {
StmtVisitor::VisitStmt_(op); StmtVisitor::VisitStmt_(op);
SetRole(op, GetRole(op->block)); SetRole(op, GetRole(op->block));
} }
template <class NodeType> template <class NodeType> void HandleBodyStmt(const NodeType *op) {
void HandleBodyStmt(const NodeType* op) {
StmtVisitor::VisitStmt_(op); StmtVisitor::VisitStmt_(op);
SetRole(op, GetRole(op->body)); SetRole(op, GetRole(op->body));
} }
void VisitStmt_(const ForNode* op) final { HandleBodyStmt(op); } void VisitStmt_(const ForNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const LetStmtNode* op) final { HandleBodyStmt(op); } void VisitStmt_(const LetStmtNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const AttrStmtNode* op) final { HandleBodyStmt(op); } void VisitStmt_(const AttrStmtNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const AssertStmtNode* op) final { HandleBodyStmt(op); } void VisitStmt_(const AssertStmtNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const BlockNode* op) final { HandleBodyStmt(op); } void VisitStmt_(const BlockNode *op) final { HandleBodyStmt(op); }
bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; } bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; }
bool HasSimtCopy() { return has_simt_copy_; } bool HasSimtCopy() { return has_simt_copy_; }
private: private:
void SetRole(const StmtNode* stmt, Role role) { map_[stmt] = role; } void SetRole(const StmtNode *stmt, Role role) { map_[stmt] = role; }
Map<Var, Buffer> buffer_data_to_buffer_; Map<Var, Buffer> buffer_data_to_buffer_;
std::unordered_map<const StmtNode*, Role> map_; std::unordered_map<const StmtNode *, Role> map_;
bool has_simt_copy_ = false; bool has_simt_copy_ = false;
bool has_bulk_copy_ = false; bool has_bulk_copy_ = false;
}; };
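// Hedged usage sketch; `buffer_map` and `body` are assumed to come from the
// enclosing PrimFunc:
//
//   WarpSpecializedRoleMarker marker(buffer_map);
//   marker(body);                        // run the visitor
//   Role r = marker.GetRole(body);       // kProducer / kConsumer / kBoth
//   bool ws = marker.HasProducer();      // true if any TMA or SIMT copy seen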
...@@ -140,23 +143,26 @@ static PrimExpr makeGetBarrier(PrimExpr barrier_id) { ...@@ -140,23 +143,26 @@ static PrimExpr makeGetBarrier(PrimExpr barrier_id) {
} }
static Stmt makeExpectTX(PrimExpr barrier_id, PrimExpr bytes) { static Stmt makeExpectTX(PrimExpr barrier_id, PrimExpr bytes) {
auto call = Call(DataType::Handle(), MBarrierExpectTX(), {makeGetBarrier(barrier_id), bytes}); auto call = Call(DataType::Handle(), MBarrierExpectTX(),
{makeGetBarrier(barrier_id), bytes});
return Evaluate(call); return Evaluate(call);
} }
static Stmt makeArriveBarrier(PrimExpr barrier_id) { static Stmt makeArriveBarrier(PrimExpr barrier_id) {
auto call = Call(DataType::Handle(), builtin::ptx_arrive_barrier(), {makeGetBarrier(barrier_id)}); auto call = Call(DataType::Handle(), builtin::ptx_arrive_barrier(),
{makeGetBarrier(barrier_id)});
return Evaluate(call); return Evaluate(call);
} }
static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) { static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) {
auto call = auto call = Call(DataType::Handle(), builtin::ptx_cp_async_barrier(),
Call(DataType::Handle(), builtin::ptx_cp_async_barrier(), {makeGetBarrier(barrier_id)}); {makeGetBarrier(barrier_id)});
return Evaluate(call); return Evaluate(call);
} }
static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) { static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) {
auto call = Call(DataType::Handle(), MBarrierWaitParity(), {makeGetBarrier(barrier_id), parity}); auto call = Call(DataType::Handle(), MBarrierWaitParity(),
{makeGetBarrier(barrier_id), parity});
return Evaluate(call); return Evaluate(call);
} }
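// Hedged sketch: the helpers above compose into the usual wait/compute/arrive
// sequence for one pipeline stage (`body` is the guarded Stmt; barrier id and
// parity are illustrative PrimExprs).
static Stmt makeGuardedStage(Stmt body, PrimExpr barrier_id, PrimExpr parity) {
  return SeqStmt({makeParityWait(barrier_id, parity), body,
                  makeArriveBarrier(barrier_id)});
}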
...@@ -177,7 +183,7 @@ static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) { ...@@ -177,7 +183,7 @@ static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) {
// } // }
class ProducerTraitsCollector : public StmtExprVisitor { class ProducerTraitsCollector : public StmtExprVisitor {
public: public:
ProducerTraitsCollector() { Clear(); } ProducerTraitsCollector() { Clear(); }
void Clear() { void Clear() {
...@@ -192,8 +198,8 @@ class ProducerTraitsCollector : public StmtExprVisitor { ...@@ -192,8 +198,8 @@ class ProducerTraitsCollector : public StmtExprVisitor {
PrimExpr BulkCopyBytes() { return bulk_copy_bytes; } PrimExpr BulkCopyBytes() { return bulk_copy_bytes; }
private: private:
void VisitExpr_(const CallNode* call) final { void VisitExpr_(const CallNode *call) final {
if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) { if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) {
Call access_ptr = Downcast<Call>(call->args[2]); Call access_ptr = Downcast<Call>(call->args[2]);
ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr())); ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
...@@ -203,14 +209,14 @@ class ProducerTraitsCollector : public StmtExprVisitor { ...@@ -203,14 +209,14 @@ class ProducerTraitsCollector : public StmtExprVisitor {
StmtExprVisitor::VisitExpr_(call); StmtExprVisitor::VisitExpr_(call);
} }
void VisitStmt_(const ForNode* op) final { void VisitStmt_(const ForNode *op) final {
PrimExpr old_loop_extents = loop_extents; PrimExpr old_loop_extents = loop_extents;
loop_extents *= op->extent; loop_extents *= op->extent;
StmtExprVisitor::VisitStmt_(op); StmtExprVisitor::VisitStmt_(op);
loop_extents = old_loop_extents; loop_extents = old_loop_extents;
} }
void VisitExpr_(const BufferLoadNode* op) final { void VisitExpr_(const BufferLoadNode *op) final {
has_simt_copy = true; has_simt_copy = true;
StmtExprVisitor::VisitExpr_(op); StmtExprVisitor::VisitExpr_(op);
} }
...@@ -222,15 +228,15 @@ class ProducerTraitsCollector : public StmtExprVisitor { ...@@ -222,15 +228,15 @@ class ProducerTraitsCollector : public StmtExprVisitor {
// Rewrite the producer Stmt to use the correct barrier index // Rewrite the producer Stmt to use the correct barrier index
class MbarrierRewriter : public StmtExprMutator { class MbarrierRewriter : public StmtExprMutator {
public: public:
static Stmt Rewrite(Stmt stmt, PrimExpr barrier_id) { static Stmt Rewrite(Stmt stmt, PrimExpr barrier_id) {
MbarrierRewriter rewriter; MbarrierRewriter rewriter;
rewriter.producer_barrier_idx_ = barrier_id; rewriter.producer_barrier_idx_ = barrier_id;
return rewriter(stmt); return rewriter(stmt);
} }
private: private:
PrimExpr VisitExpr_(const CallNode* op) final { PrimExpr VisitExpr_(const CallNode *op) final {
auto call = Downcast<Call>(StmtExprMutator::VisitExpr_(op)); auto call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) { if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) {
Call access_ptr = Downcast<Call>(call->args[2]); Call access_ptr = Downcast<Call>(call->args[2]);
...@@ -242,19 +248,18 @@ class MbarrierRewriter : public StmtExprMutator { ...@@ -242,19 +248,18 @@ class MbarrierRewriter : public StmtExprMutator {
PrimExpr producer_barrier_idx_; PrimExpr producer_barrier_idx_;
}; };
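// Hedged usage sketch: retarget the mbarrier argument of a producer copy to a
// computed barrier index (both names assumed from the emitter below):
//
//   Stmt stmt = MbarrierRewriter::Rewrite(producer_stmt, release_barrier_id);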
class ThreadIdxRewriter : public StmtExprMutator { class ThreadIdxRewriter : public StmtExprMutator {
public: public:
static Stmt Rewrite(Stmt stmt, Var thread_var, PrimExpr replaced) { static Stmt Rewrite(Stmt stmt, Var thread_var, PrimExpr replaced) {
auto rewriter = ThreadIdxRewriter(thread_var, replaced); auto rewriter = ThreadIdxRewriter(thread_var, replaced);
return rewriter(stmt); return rewriter(stmt);
} }
private: private:
ThreadIdxRewriter(Var thread_var, PrimExpr replaced) ThreadIdxRewriter(Var thread_var, PrimExpr replaced)
: thread_var_(thread_var), replaced_(replaced) {} : thread_var_(thread_var), replaced_(replaced) {}
PrimExpr VisitExpr_(const VarNode* var) final { PrimExpr VisitExpr_(const VarNode *var) final {
if (var == thread_var_.get()) { if (var == thread_var_.get()) {
return replaced_; return replaced_;
} else { } else {
...@@ -266,9 +271,12 @@ class ThreadIdxRewriter : public StmtExprMutator { ...@@ -266,9 +271,12 @@ class ThreadIdxRewriter : public StmtExprMutator {
PrimExpr replaced_; PrimExpr replaced_;
}; };
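// Hedged usage sketch: give the producer warp-group a zero-based thread index
// by shifting threadIdx.x past the consumer extent (names assumed):
//
//   Stmt shifted = ThreadIdxRewriter::Rewrite(
//       producer_code, thread_var, thread_var - consumer_thread_extent);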
Block MakeGroupBlock(const Stmt& stmt, const Map<String, ObjectRef>& annotations) { Block MakeGroupBlock(const Stmt &stmt,
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ stmt, const Map<String, ObjectRef> &annotations) {
/*init=*/{}, /*alloc_buffers=*/{}, /*match_buffers=*/{}, /*annotations=*/annotations); Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
/*body*/ stmt,
/*init=*/{}, /*alloc_buffers=*/{}, /*match_buffers=*/{},
/*annotations=*/annotations);
return block; return block;
} }
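// Hedged usage sketch: wrap two statements into a single group block tagged
// with the stmt_group annotation used throughout this pass (`s0`, `s1` are
// assumed Stmts):
//
//   Map<String, ObjectRef> ann;
//   ann.Set("stmt_group", Integer(1));
//   Block g = MakeGroupBlock(SeqStmt({s0, s1}), ann);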
...@@ -280,11 +288,8 @@ struct PipelineInfo { ...@@ -280,11 +288,8 @@ struct PipelineInfo {
std::vector<OpInfo> op_infos; std::vector<OpInfo> op_infos;
PipelineInfo() = default; PipelineInfo() = default;
PipelineInfo( PipelineInfo(Array<Array<Integer>> group_info, Array<Integer> order_info,
Array<Array<Integer>> group_info, Array<Integer> stage_info) {
Array<Integer> order_info,
Array<Integer> stage_info
) {
int n = static_cast<int>(group_info.size()); int n = static_cast<int>(group_info.size());
ICHECK(n == static_cast<int>(order_info.size())); ICHECK(n == static_cast<int>(order_info.size()));
ICHECK(n == static_cast<int>(stage_info.size())); ICHECK(n == static_cast<int>(stage_info.size()));
...@@ -301,7 +306,7 @@ struct PipelineInfo { ...@@ -301,7 +306,7 @@ struct PipelineInfo {
} }
} }
PipelineInfo(const PipelineInfo& other) { PipelineInfo(const PipelineInfo &other) {
for (auto op_info : other.op_infos) { for (auto op_info : other.op_infos) {
op_infos.push_back(op_info); op_infos.push_back(op_info);
} }
...@@ -364,18 +369,19 @@ struct PipelineInfo { ...@@ -364,18 +369,19 @@ struct PipelineInfo {
void PrintPipelineInfo() { void PrintPipelineInfo() {
std::cout << "Print op_infos:" << std::endl; std::cout << "Print op_infos:" << std::endl;
for (size_t i = 0; i < op_infos.size(); i++) { for (size_t i = 0; i < op_infos.size(); i++) {
std::cout << i << " " << op_infos[i].group_size << " " << op_infos[i].order << " " << op_infos[i].stage << std::endl; std::cout << i << " " << op_infos[i].group_size << " "
<< op_infos[i].order << " " << op_infos[i].stage << std::endl;
} }
std::cout << "End of print" << std::endl; std::cout << "End of print" << std::endl;
} }
}; };
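// Hedged worked example: with group_info = {{0, 1}, {2}}, order_info = {0, 1}
// and stage_info = {0, 1}, the constructor yields op_infos with group sizes
// {2, 1}, orders {0, 1} and stages {0, 1}: ops 0 and 1 form one group issued
// at stage 0, op 2 a second group issued at stage 1 (values illustrative).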
class GroupOpRewriter : public StmtExprMutator { class GroupOpRewriter : public StmtExprMutator {
public: public:
GroupOpRewriter(PipelineInfo pipeline_info) : pipeline_info_(pipeline_info) {} GroupOpRewriter(PipelineInfo pipeline_info) : pipeline_info_(pipeline_info) {}
private: private:
Stmt VisitStmt_(const ForNode* op) final { Stmt VisitStmt_(const ForNode *op) final {
Map<String, ObjectRef> annotations; Map<String, ObjectRef> annotations;
annotations.Set(String("stmt_group"), Integer(1)); annotations.Set(String("stmt_group"), Integer(1));
auto original_node = (op->body).as<SeqStmtNode>(); auto original_node = (op->body).as<SeqStmtNode>();
...@@ -385,19 +391,24 @@ class GroupOpRewriter : public StmtExprMutator { ...@@ -385,19 +391,24 @@ class GroupOpRewriter : public StmtExprMutator {
Array<Stmt> new_body; Array<Stmt> new_body;
int cur_id = 0; int cur_id = 0;
for (int i = 0; i < static_cast<int>(pipeline_info_.op_infos.size()); i++) { for (int i = 0; i < static_cast<int>(pipeline_info_.op_infos.size()); i++) {
if (pipeline_info_.op_infos[i].group_size == 0) continue; if (pipeline_info_.op_infos[i].group_size == 0)
continue;
Array<Stmt> block_stmt; Array<Stmt> block_stmt;
for (int j = 0; j < static_cast<int>(pipeline_info_.op_infos[i].group_size); j++) { for (int j = 0;
j < static_cast<int>(pipeline_info_.op_infos[i].group_size); j++) {
// ICHECK(group_info_[i][j].as<IntImmNode>()); // ICHECK(group_info_[i][j].as<IntImmNode>());
// int index = static_cast<int>(group_info_[i][j].as<IntImmNode>()->value); // int index =
// static_cast<int>(group_info_[i][j].as<IntImmNode>()->value);
ICHECK(original_node->seq[cur_id].as<BlockNode>()); ICHECK(original_node->seq[cur_id].as<BlockNode>());
auto block = original_node->seq[cur_id].as<BlockNode>(); auto block = original_node->seq[cur_id].as<BlockNode>();
// TODO: handle nested seqstmt // TODO: handle nested seqstmt
block_stmt.push_back(block->body); block_stmt.push_back(block->body);
cur_id++; cur_id++;
} }
new_body.push_back( new_body.push_back(MakeGroupBlock(block_stmt.size() == 1
MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations)); ? block_stmt[0]
: SeqStmt(std::move(block_stmt)),
annotations));
} }
Array<Integer> order_anno; Array<Integer> order_anno;
Array<Integer> stage_anno; Array<Integer> stage_anno;
...@@ -409,24 +420,26 @@ class GroupOpRewriter : public StmtExprMutator { ...@@ -409,24 +420,26 @@ class GroupOpRewriter : public StmtExprMutator {
for_annotations.erase("tl_pipeline_group"); for_annotations.erase("tl_pipeline_group");
for_annotations.Set("software_pipeline_order", order_anno); for_annotations.Set("software_pipeline_order", order_anno);
for_annotations.Set("software_pipeline_stage", stage_anno); for_annotations.Set("software_pipeline_stage", stage_anno);
For new_for = For(op->loop_var, op->min, op->extent, op->kind, new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body)), op->thread_binding, for_annotations); For new_for =
For(op->loop_var, op->min, op->extent, op->kind,
new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body)),
op->thread_binding, for_annotations);
return new_for; return new_for;
} }
PipelineInfo pipeline_info_; PipelineInfo pipeline_info_;
}; };
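// Hedged worked example of the rewrite above: a loop annotated with
// tl_pipeline_group = [[0, 1], [2]] has its SeqStmt body regrouped into one
// Block per group, the tl_pipeline_group annotation erased, and
// software_pipeline_order / software_pipeline_stage filled in from the
// per-group order and stage values.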
class WSCodeEmitter : public StmtMutator { class WSCodeEmitter : public StmtMutator {
public: public:
WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv, WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv,
Map<Var, Buffer> buffer_data_to_buffer, const WarpSpecializedRoleMarker& marker) Map<Var, Buffer> buffer_data_to_buffer,
const WarpSpecializedRoleMarker &marker)
: is_emitting_producer_(is_emitting_producer), : is_emitting_producer_(is_emitting_producer),
buffer_data_to_buffer_(buffer_data_to_buffer), buffer_data_to_buffer_(buffer_data_to_buffer), marker_(marker),
marker_(marker),
thread_var_(thread_iv->var) {} thread_var_(thread_iv->var) {}
private: private:
template <typename NodeType> template <typename NodeType> Stmt FilterByRole(const NodeType *op) {
Stmt FilterByRole(const NodeType* op) {
Role role = marker_.GetRole(op); Role role = marker_.GetRole(op);
if (role == Role::kBoth) if (role == Role::kBoth)
return StmtMutator::VisitStmt_(op); return StmtMutator::VisitStmt_(op);
...@@ -437,7 +450,7 @@ class WSCodeEmitter : public StmtMutator { ...@@ -437,7 +450,7 @@ class WSCodeEmitter : public StmtMutator {
} }
// TODO: only need to add block for ops in the loop // TODO: only need to add block for ops in the loop
Stmt VisitStmt_(const SeqStmtNode* op) final { Stmt VisitStmt_(const SeqStmtNode *op) final {
bool has_producer = false; bool has_producer = false;
for (auto stmt : op->seq) { for (auto stmt : op->seq) {
if (marker_.GetRole(stmt) == Role::kProducer) { if (marker_.GetRole(stmt) == Role::kProducer) {
...@@ -445,19 +458,24 @@ class WSCodeEmitter : public StmtMutator { ...@@ -445,19 +458,24 @@ class WSCodeEmitter : public StmtMutator {
break; break;
} }
} }
bool need_producer_sync = has_producer && marker_.GetRole(op) == Role::kBoth; bool need_producer_sync =
if (!need_producer_sync) return FilterByRole(op); has_producer && marker_.GetRole(op) == Role::kBoth;
if (!need_producer_sync)
return FilterByRole(op);
auto seq_transformed = op->seq.Map([&](Stmt stmt) { return VisitStmt(stmt); }); auto seq_transformed =
op->seq.Map([&](Stmt stmt) { return VisitStmt(stmt); });
auto map = ExtractSyncPattern(op->seq); auto map = ExtractSyncPattern(op->seq);
// std::cout << "Print ExtractSyncPattern" << std::endl; // std::cout << "Print ExtractSyncPattern" << std::endl;
// for (int i = 0; i < static_cast<int>(op->seq.size()); i++) { // for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
// std::cout << i << " " << map.acquire[i] << " " << map.release[i] << " " << map.release_after[i] << std::endl; // std::cout << i << " " << map.acquire[i] << " " << map.release[i] << " "
// << map.release_after[i] << std::endl;
// } // }
// std::cout << "Print sync pattern" << std::endl; // std::cout << "Print sync pattern" << std::endl;
// for (auto pattern : map.patterns) { // for (auto pattern : map.patterns) {
// std::cout << pattern.release_idx << " " << pattern.acquire_idx << std::endl; // std::cout << pattern.release_idx << " " << pattern.acquire_idx <<
// std::endl;
// } // }
// std::cout << "End of ExtractSyncPattern" << std::endl; // std::cout << "End of ExtractSyncPattern" << std::endl;
// pipeline_info_.PrintPipelineInfo(); // pipeline_info_.PrintPipelineInfo();
...@@ -465,29 +483,38 @@ class WSCodeEmitter : public StmtMutator { ...@@ -465,29 +483,38 @@ class WSCodeEmitter : public StmtMutator {
Map<String, ObjectRef> annotations; Map<String, ObjectRef> annotations;
annotations.Set(String("stmt_group"), Integer(1)); annotations.Set(String("stmt_group"), Integer(1));
if (is_emitting_producer_) { // producer case if (is_emitting_producer_) { // producer case
ProducerTraitsCollector collector; ProducerTraitsCollector collector;
for (int i = 0; i < static_cast<int>(op->seq.size()); i++) { for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
Array<Stmt> block_stmt = {}; Array<Stmt> block_stmt = {};
if (marker_.GetRole(op->seq[i]) == Role::kConsumer) continue; if (marker_.GetRole(op->seq[i]) == Role::kConsumer)
continue;
if (marker_.GetRole(op->seq[i]) == Role::kBoth) { if (marker_.GetRole(op->seq[i]) == Role::kBoth) {
block_stmt.push_back(seq_transformed[i]); block_stmt.push_back(seq_transformed[i]);
new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations)); new_body.push_back(MakeGroupBlock(
block_stmt.size() == 1 ? block_stmt[0]
: SeqStmt(std::move(block_stmt)),
annotations));
continue; continue;
} }
if (map.acquire[i] != -1) { if (map.acquire[i] != -1) {
PrimExpr acquire_barrier_id = stage_ + num_barriers_ + num_stages_ * map.acquire[i]; PrimExpr acquire_barrier_id =
PrimExpr parity = stage_ + num_barriers_ + num_stages_ * map.acquire[i];
map.is_loop_dependency(map.acquire[i]) ? bitwise_xor(parity_, 1) : parity_; PrimExpr parity = map.is_loop_dependency(map.acquire[i])
? bitwise_xor(parity_, 1)
: parity_;
block_stmt.push_back(makeParityWait(acquire_barrier_id, parity)); block_stmt.push_back(makeParityWait(acquire_barrier_id, parity));
} }
ICHECK(map.release[i] >= 0); ICHECK(map.release[i] >= 0);
PrimExpr release_barrier_id = stage_ + num_barriers_ + num_stages_ * map.release[i]; PrimExpr release_barrier_id =
auto stmt = MbarrierRewriter::Rewrite(seq_transformed[i], release_barrier_id); stage_ + num_barriers_ + num_stages_ * map.release[i];
auto stmt =
MbarrierRewriter::Rewrite(seq_transformed[i], release_barrier_id);
collector.Collect(stmt); collector.Collect(stmt);
if (!is_zero(collector.BulkCopyBytes())) { if (!is_zero(collector.BulkCopyBytes())) {
auto expect_tx = IfThenElse(EQ(thread_var_, 0), auto expect_tx = IfThenElse(
makeExpectTX(release_barrier_id, collector.BulkCopyBytes())); EQ(thread_var_, 0),
makeExpectTX(release_barrier_id, collector.BulkCopyBytes()));
block_stmt.push_back(expect_tx); block_stmt.push_back(expect_tx);
} }
block_stmt.push_back(stmt); block_stmt.push_back(stmt);
...@@ -497,39 +524,53 @@ class WSCodeEmitter : public StmtMutator { ...@@ -497,39 +524,53 @@ class WSCodeEmitter : public StmtMutator {
if (map.release_after[i]) { if (map.release_after[i]) {
block_stmt.push_back(makeArriveBarrier(release_barrier_id)); block_stmt.push_back(makeArriveBarrier(release_barrier_id));
for (int j = 0; j < num_stages_; j++) { for (int j = 0; j < num_stages_; j++) {
released_barrier_.insert(j + num_barriers_ + num_stages_ * map.release[i]); released_barrier_.insert(j + num_barriers_ +
num_stages_ * map.release[i]);
} }
} }
collector.Clear(); collector.Clear();
new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations)); new_body.push_back(MakeGroupBlock(block_stmt.size() == 1
? block_stmt[0]
: SeqStmt(std::move(block_stmt)),
annotations));
} }
} else { // consumer case } else { // consumer case
for (int i = 0; i < static_cast<int>(op->seq.size()); i++) { for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
Array<Stmt> block_stmt = {}; Array<Stmt> block_stmt = {};
if (marker_.GetRole(op->seq[i]) == Role::kProducer) continue; if (marker_.GetRole(op->seq[i]) == Role::kProducer)
continue;
if (map.acquire[i] != -1) { if (map.acquire[i] != -1) {
PrimExpr acquire_barrier_id = stage_ + num_barriers_ + num_stages_ * map.acquire[i]; PrimExpr acquire_barrier_id =
PrimExpr parity = stage_ + num_barriers_ + num_stages_ * map.acquire[i];
map.is_loop_dependency(map.acquire[i]) ? bitwise_xor(parity_, 1) : parity_; PrimExpr parity = map.is_loop_dependency(map.acquire[i])
? bitwise_xor(parity_, 1)
: parity_;
block_stmt.push_back(makeParityWait(acquire_barrier_id, parity)); block_stmt.push_back(makeParityWait(acquire_barrier_id, parity));
} }
block_stmt.push_back(seq_transformed[i]); block_stmt.push_back(seq_transformed[i]);
// new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations)); // new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ?
// block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations));
if (map.release_after[i]) { if (map.release_after[i]) {
PrimExpr release_barrier_id = stage_ + num_barriers_ + num_stages_ * map.release[i]; PrimExpr release_barrier_id =
stage_ + num_barriers_ + num_stages_ * map.release[i];
block_stmt.push_back(makeArriveBarrier(release_barrier_id)); block_stmt.push_back(makeArriveBarrier(release_barrier_id));
for (int j = 0; j < num_stages_; j++) { for (int j = 0; j < num_stages_; j++) {
released_barrier_.insert(j + num_barriers_ + num_stages_ * map.release[i]); released_barrier_.insert(j + num_barriers_ +
num_stages_ * map.release[i]);
} }
// Update the pipeline info // Update the pipeline info
// TODO: handle sync // TODO: handle sync
} }
new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations)); new_body.push_back(MakeGroupBlock(block_stmt.size() == 1
? block_stmt[0]
: SeqStmt(std::move(block_stmt)),
annotations));
} }
// Filter out the producer stmts // Filter out the producer stmts
int cur_id = 0; int cur_id = 0;
PipelineInfo new_pipeline_info; PipelineInfo new_pipeline_info;
for (int i = 0; i < static_cast<int>(pipeline_info_.op_infos.size()); i++) { for (int i = 0; i < static_cast<int>(pipeline_info_.op_infos.size());
i++) {
auto op_info = pipeline_info_.op_infos[i]; auto op_info = pipeline_info_.op_infos[i];
bool is_producer = false; bool is_producer = false;
for (int j = 0; j < op_info.group_size; j++) { for (int j = 0; j < op_info.group_size; j++) {
...@@ -553,7 +594,7 @@ class WSCodeEmitter : public StmtMutator { ...@@ -553,7 +594,7 @@ class WSCodeEmitter : public StmtMutator {
return new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body)); return new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body));
} }
Stmt VisitStmt_(const ForNode* op) final { Stmt VisitStmt_(const ForNode *op) final {
int num_stages = 1; int num_stages = 1;
auto num_stages_anno = op->annotations.Get("num_stages"); auto num_stages_anno = op->annotations.Get("num_stages");
if (num_stages_anno.defined()) { if (num_stages_anno.defined()) {
...@@ -565,7 +606,7 @@ class WSCodeEmitter : public StmtMutator { ...@@ -565,7 +606,7 @@ class WSCodeEmitter : public StmtMutator {
Array<Array<Integer>> group_info_array; Array<Array<Integer>> group_info_array;
Array<Integer> order_info_array; Array<Integer> order_info_array;
Array<Integer> stage_info_array; Array<Integer> stage_info_array;
auto group_anno = op->annotations.Get("tl_pipeline_group"); auto group_anno = op->annotations.Get("tl_pipeline_group");
if (group_anno.defined()) { if (group_anno.defined()) {
group_info_array = Downcast<Array<Array<Integer>>>(group_anno); group_info_array = Downcast<Array<Array<Integer>>>(group_anno);
...@@ -579,9 +620,11 @@ class WSCodeEmitter : public StmtMutator { ...@@ -579,9 +620,11 @@ class WSCodeEmitter : public StmtMutator {
stage_info_array = Downcast<Array<Integer>>(stage_anno); stage_info_array = Downcast<Array<Integer>>(stage_anno);
} }
PipelineInfo pipeline_info(group_info_array, order_info_array, stage_info_array); PipelineInfo pipeline_info(group_info_array, order_info_array,
stage_info_array);
if (pipeline_info.op_infos.size() > 0) { if (pipeline_info.op_infos.size() > 0) {
ICHECK(pipeline_info_.op_infos.size() == 0) << "Nested pipeline not supported."; ICHECK(pipeline_info_.op_infos.size() == 0)
<< "Nested pipeline not supported.";
} }
PrimExpr parity_before = std::move(parity_); PrimExpr parity_before = std::move(parity_);
...@@ -592,13 +635,15 @@ class WSCodeEmitter : public StmtMutator { ...@@ -592,13 +635,15 @@ class WSCodeEmitter : public StmtMutator {
num_stages_ = num_stages; num_stages_ = num_stages;
pipeline_info_ = pipeline_info; pipeline_info_ = pipeline_info;
stage_ = FloorMod(op->loop_var - op->min, num_stages); stage_ = FloorMod(op->loop_var - op->min, num_stages);
parity_ = parity_ = FloorMod(parity_before * op->extent +
FloorMod(parity_before * op->extent + FloorDiv(op->loop_var - op->min, num_stages), 2); FloorDiv(op->loop_var - op->min, num_stages),
2);
auto result = FilterByRole(op); auto result = FilterByRole(op);
Stmt grouped_for_node; Stmt grouped_for_node;
if (result.as<ForNode>() && group_anno.defined() && group_info_array.size() > 0 && !is_emitting_producer_) { if (result.as<ForNode>() && group_anno.defined() &&
group_info_array.size() > 0 && !is_emitting_producer_) {
GroupOpRewriter group_op_rewriter(pipeline_info_); GroupOpRewriter group_op_rewriter(pipeline_info_);
auto for_node = Downcast<For>(result); auto for_node = Downcast<For>(result);
grouped_for_node = group_op_rewriter(for_node); grouped_for_node = group_op_rewriter(for_node);
...@@ -618,7 +663,8 @@ class WSCodeEmitter : public StmtMutator { ...@@ -618,7 +663,8 @@ class WSCodeEmitter : public StmtMutator {
for_node.CopyOnWrite()->annotations.erase("tl_pipeline_order"); for_node.CopyOnWrite()->annotations.erase("tl_pipeline_order");
for_node.CopyOnWrite()->annotations.erase("tl_pipeline_stage"); for_node.CopyOnWrite()->annotations.erase("tl_pipeline_stage");
} }
if (is_emitting_producer_ || !group_anno.defined() ||group_info_array.size() == 0) { if (is_emitting_producer_ || !group_anno.defined() ||
group_info_array.size() == 0) {
return for_node; return for_node;
} }
return grouped_for_node; return grouped_for_node;
...@@ -626,17 +672,17 @@ class WSCodeEmitter : public StmtMutator { ...@@ -626,17 +672,17 @@ class WSCodeEmitter : public StmtMutator {
return result; return result;
} }
Stmt VisitStmt_(const IfThenElseNode* op) final { return FilterByRole(op); } Stmt VisitStmt_(const IfThenElseNode *op) final { return FilterByRole(op); }
Stmt VisitStmt_(const EvaluateNode* op) final { return FilterByRole(op); } Stmt VisitStmt_(const EvaluateNode *op) final { return FilterByRole(op); }
Stmt VisitStmt_(const AttrStmtNode* op) final { return FilterByRole(op); } Stmt VisitStmt_(const AttrStmtNode *op) final { return FilterByRole(op); }
Stmt VisitStmt_(const BufferStoreNode* op) final { return FilterByRole(op); } Stmt VisitStmt_(const BufferStoreNode *op) final { return FilterByRole(op); }
Stmt VisitStmt_(const LetStmtNode* op) final { return FilterByRole(op); } Stmt VisitStmt_(const LetStmtNode *op) final { return FilterByRole(op); }
Stmt VisitStmt_(const AssertStmtNode* op) final { return FilterByRole(op); } Stmt VisitStmt_(const AssertStmtNode *op) final { return FilterByRole(op); }
Stmt VisitStmt_(const BlockNode* op) final { Stmt VisitStmt_(const BlockNode *op) final {
ICHECK(0); ICHECK(0);
return Stmt(); return Stmt();
} }
Stmt VisitStmt_(const BlockRealizeNode* op) final { Stmt VisitStmt_(const BlockRealizeNode *op) final {
ICHECK(0); ICHECK(0);
return Stmt(); return Stmt();
} }
...@@ -656,27 +702,32 @@ class WSCodeEmitter : public StmtMutator { ...@@ -656,27 +702,32 @@ class WSCodeEmitter : public StmtMutator {
} }
}; };
std::vector<SyncPattern> CreateBaseSyncPairs(Array<Stmt> seq_stmt, std::vector<SyncPattern>
const std::vector<bool>& is_producer) { CreateBaseSyncPairs(Array<Stmt> seq_stmt,
const std::vector<bool> &is_producer) {
const int n = seq_stmt.size(); const int n = seq_stmt.size();
std::vector<std::set<const BufferNode*>> reads, writes; std::vector<std::set<const BufferNode *>> reads, writes;
reads.reserve(n); reads.reserve(n);
writes.reserve(n); writes.reserve(n);
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
/*name_hint=*/"",
/*body*/ seq_stmt[i]); /*body*/ seq_stmt[i]);
auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_); auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_);
std::set<const BufferNode*> read_set, write_set; std::set<const BufferNode *> read_set, write_set;
for (auto region : access[0]) read_set.insert(region->buffer.get()); for (auto region : access[0])
for (auto region : access[1]) write_set.insert(region->buffer.get()); read_set.insert(region->buffer.get());
for (auto region : access[1])
write_set.insert(region->buffer.get());
reads.push_back(std::move(read_set)); reads.push_back(std::move(read_set));
writes.push_back(std::move(write_set)); writes.push_back(std::move(write_set));
} }
auto intersect_fn = [](const std::set<const BufferNode*>& lhs, auto intersect_fn = [](const std::set<const BufferNode *> &lhs,
const std::set<const BufferNode*>& rhs) { const std::set<const BufferNode *> &rhs) {
for (auto ptr : lhs) for (auto ptr : lhs)
if (rhs.count(ptr)) return true; if (rhs.count(ptr))
return true;
return false; return false;
}; };
...@@ -686,7 +737,8 @@ class WSCodeEmitter : public StmtMutator { ...@@ -686,7 +737,8 @@ class WSCodeEmitter : public StmtMutator {
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
for (int j = i + 1; j < n; j++) { for (int j = i + 1; j < n; j++) {
if (is_producer[i] != is_producer[j] && if (is_producer[i] != is_producer[j] &&
(intersect_fn(writes[i], reads[j]) || intersect_fn(reads[i], writes[j]))) { (intersect_fn(writes[i], reads[j]) ||
intersect_fn(reads[i], writes[j]))) {
sync_patterns.push_back({i, j}); sync_patterns.push_back({i, j});
break; break;
} }
...@@ -701,7 +753,8 @@ class WSCodeEmitter : public StmtMutator { ...@@ -701,7 +753,8 @@ class WSCodeEmitter : public StmtMutator {
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
for (int j = 0; j < i; j++) { for (int j = 0; j < i; j++) {
if (is_producer[i] != is_producer[j] && if (is_producer[i] != is_producer[j] &&
(intersect_fn(writes[i], reads[j]) || intersect_fn(reads[i], writes[j]))) { (intersect_fn(writes[i], reads[j]) ||
intersect_fn(reads[i], writes[j]))) {
sync_patterns.push_back({i, j}); sync_patterns.push_back({i, j});
break; break;
} }
...@@ -712,8 +765,9 @@ class WSCodeEmitter : public StmtMutator { ...@@ -712,8 +765,9 @@ class WSCodeEmitter : public StmtMutator {
return sync_patterns; return sync_patterns;
} }
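// Hedged worked example: for seq = {copy, gemm} where the producer copy
// writes A_shared and the consumer gemm reads it, writes[0] intersects
// reads[1], so the base pattern {release_idx = 0, acquire_idx = 1} is
// produced; the second, loop-carried scan then adds the wrap-around
// pattern {1, 0} for the next iteration.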
static std::vector<SyncPattern> RemoveUnusedSyncPatterns( static std::vector<SyncPattern>
const std::vector<SyncPattern>& sync_patterns, const std::vector<bool>& is_producer) { RemoveUnusedSyncPatterns(const std::vector<SyncPattern> &sync_patterns,
const std::vector<bool> &is_producer) {
/* /*
Simplify multiple release-acquire pairs into one Simplify multiple release-acquire pairs into one
------------------ ------------------
...@@ -746,7 +800,8 @@ class WSCodeEmitter : public StmtMutator { ...@@ -746,7 +800,8 @@ class WSCodeEmitter : public StmtMutator {
std::vector<SyncPattern> sync_pattern_cleaned; std::vector<SyncPattern> sync_pattern_cleaned;
sync_pattern_cleaned.reserve(M); sync_pattern_cleaned.reserve(M);
for (int i = 0; i < M; i++) for (int i = 0; i < M; i++)
if (!removed[i]) sync_pattern_cleaned.push_back(sync_patterns[i]); if (!removed[i])
sync_pattern_cleaned.push_back(sync_patterns[i]);
return sync_pattern_cleaned; return sync_pattern_cleaned;
} }
...@@ -760,10 +815,12 @@ class WSCodeEmitter : public StmtMutator { ...@@ -760,10 +815,12 @@ class WSCodeEmitter : public StmtMutator {
} }
auto sync_patterns_base = CreateBaseSyncPairs(seq_stmt, is_producer); auto sync_patterns_base = CreateBaseSyncPairs(seq_stmt, is_producer);
auto sync_patterns = RemoveUnusedSyncPatterns(sync_patterns_base, is_producer); auto sync_patterns =
RemoveUnusedSyncPatterns(sync_patterns_base, is_producer);
// for (auto pattern : sync_patterns) { // for (auto pattern : sync_patterns) {
// std::cout << pattern.release_idx << " " << pattern.acquire_idx << std::endl; // std::cout << pattern.release_idx << " " << pattern.acquire_idx <<
// std::endl;
// } // }
SyncPatternMap map; SyncPatternMap map;
...@@ -799,7 +856,7 @@ class WSCodeEmitter : public StmtMutator { ...@@ -799,7 +856,7 @@ class WSCodeEmitter : public StmtMutator {
const bool is_emitting_producer_; const bool is_emitting_producer_;
Map<Var, Buffer> buffer_data_to_buffer_; Map<Var, Buffer> buffer_data_to_buffer_;
std::unordered_set<int> released_barrier_; std::unordered_set<int> released_barrier_;
const WarpSpecializedRoleMarker& marker_; const WarpSpecializedRoleMarker &marker_;
int num_barriers_ = 0; int num_barriers_ = 0;
PrimExpr parity_ = 0; PrimExpr parity_ = 0;
...@@ -811,17 +868,18 @@ class WSCodeEmitter : public StmtMutator { ...@@ -811,17 +868,18 @@ class WSCodeEmitter : public StmtMutator {
}; };
class WarpSpecializedRewriter : public StmtExprMutator { class WarpSpecializedRewriter : public StmtExprMutator {
public: public:
static PrimFunc Substitute(PrimFunc f) { static PrimFunc Substitute(PrimFunc f) {
auto T = WarpSpecializedRewriter(); auto T = WarpSpecializedRewriter();
T.buffer_lca_ = DetectBufferAccessLCA(f); T.buffer_lca_ = DetectBufferAccessLCA(f);
for (auto [buffer, _] : T.buffer_lca_) T.buffer_data_to_buffer_.Set(buffer->data, buffer); for (auto [buffer, _] : T.buffer_lca_)
T.buffer_data_to_buffer_.Set(buffer->data, buffer);
f.CopyOnWrite()->body = T(f->body); f.CopyOnWrite()->body = T(f->body);
return f; return f;
} }
private: private:
Stmt VisitStmt_(const AttrStmtNode* op) final { Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent && if (op->attr_key == tir::attr::thread_extent &&
Downcast<IterVar>(op->node)->thread_tag == "threadIdx.x") { Downcast<IterVar>(op->node)->thread_tag == "threadIdx.x") {
thread_iv_ = Downcast<IterVar>(op->node); thread_iv_ = Downcast<IterVar>(op->node);
...@@ -839,9 +897,10 @@ class WarpSpecializedRewriter : public StmtExprMutator { ...@@ -839,9 +897,10 @@ class WarpSpecializedRewriter : public StmtExprMutator {
} }
} }
// If users define a thread binding, we will replace it with threadIdx.x. // If users define a thread binding, we will replace it with
// We require that the thread binding be threadIdx.x and that its extent match the thread extent. // threadIdx.x. We require that the thread binding be threadIdx.x and that its
Stmt VisitStmt_(const ForNode* op) final { // extent match the thread extent.
Stmt VisitStmt_(const ForNode *op) final {
ICHECK(thread_iv_.defined()); ICHECK(thread_iv_.defined());
For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op)); For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
if (for_node->kind == ForKind::kThreadBinding) { if (for_node->kind == ForKind::kThreadBinding) {
...@@ -849,14 +908,16 @@ class WarpSpecializedRewriter : public StmtExprMutator { ...@@ -849,14 +908,16 @@ class WarpSpecializedRewriter : public StmtExprMutator {
String thread_tag = for_node->thread_binding.value()->thread_tag; String thread_tag = for_node->thread_binding.value()->thread_tag;
ICHECK(thread_tag == "threadIdx.x") << "Only support threadIdx.x"; ICHECK(thread_tag == "threadIdx.x") << "Only support threadIdx.x";
Var thread_iv = Downcast<Var>(for_node->loop_var); Var thread_iv = Downcast<Var>(for_node->loop_var);
Stmt new_body = ThreadIdxRewriter::Rewrite(for_node->body, thread_iv, thread_iv_); Stmt new_body =
ThreadIdxRewriter::Rewrite(for_node->body, thread_iv, thread_iv_);
return new_body; return new_body;
} }
return for_node; return for_node;
} }
Stmt VisitStmt_(const BlockRealizeNode* op) final { Stmt VisitStmt_(const BlockRealizeNode *op) final {
BlockRealize block_realize = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op)); BlockRealize block_realize =
Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
if (!thread_iv_.defined()) { if (!thread_iv_.defined()) {
return block_realize; return block_realize;
} }
...@@ -877,17 +938,21 @@ class WarpSpecializedRewriter : public StmtExprMutator { ...@@ -877,17 +938,21 @@ class WarpSpecializedRewriter : public StmtExprMutator {
PrimExpr consumer_thread_extent = thread_iv_->dom->extent; PrimExpr consumer_thread_extent = thread_iv_->dom->extent;
PrimExpr producer_thread_extent = thread_iv_->dom->extent; PrimExpr producer_thread_extent = thread_iv_->dom->extent;
// Need one warp-group for bulk-copy only case // Need one warp-group for bulk-copy only case
if (!marker.HasSimtCopy()) producer_thread_extent = 128; if (!marker.HasSimtCopy())
producer_thread_extent = 128;
// TODO: estimate the correct reg usage. // TODO: estimate the correct reg usage.
auto inc_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(), {240, 1})); auto inc_reg_stmt =
auto dec_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(), {24, 0})); Evaluate(Call(DataType::Handle(), SetMaxNReg(), {240, 1}));
auto dec_reg_stmt =
Evaluate(Call(DataType::Handle(), SetMaxNReg(), {24, 0}));
producer_code = SeqStmt({dec_reg_stmt, producer_code}); producer_code = SeqStmt({dec_reg_stmt, producer_code});
consumer_code = SeqStmt({inc_reg_stmt, consumer_code}); consumer_code = SeqStmt({inc_reg_stmt, consumer_code});
producer_code = ThreadIdxRewriter::Rewrite(producer_code, thread_iv_->var, producer_code =
thread_iv_->var - consumer_thread_extent); ThreadIdxRewriter::Rewrite(producer_code, thread_iv_->var,
thread_iv_->var - consumer_thread_extent);
updated_thread_extent_ = consumer_thread_extent + producer_thread_extent; updated_thread_extent_ = consumer_thread_extent + producer_thread_extent;
need_update_thread_extent_ = true; need_update_thread_extent_ = true;
...@@ -897,15 +962,16 @@ class WarpSpecializedRewriter : public StmtExprMutator { ...@@ -897,15 +962,16 @@ class WarpSpecializedRewriter : public StmtExprMutator {
Array<PrimExpr> barrier_num_threads; Array<PrimExpr> barrier_num_threads;
barrier_num_threads.reserve(num_barriers); barrier_num_threads.reserve(num_barriers);
for (int i = 0; i < num_barriers; i++) { for (int i = 0; i < num_barriers; i++) {
PrimExpr arrive_thread_count = PrimExpr arrive_thread_count = producer.released_barrier_.count(i)
producer.released_barrier_.count(i) ? producer_thread_extent : consumer_thread_extent; ? producer_thread_extent
: consumer_thread_extent;
barrier_num_threads.push_back(arrive_thread_count); barrier_num_threads.push_back(arrive_thread_count);
} }
Stmt init_barrier = Stmt init_barrier = Evaluate(Call(
Evaluate(Call(DataType::Handle(), CreateListofMBarrierOp(), barrier_num_threads)); DataType::Handle(), CreateListofMBarrierOp(), barrier_num_threads));
Stmt body = Stmt body = IfThenElse(GE(thread_iv_->var, consumer_thread_extent),
IfThenElse(GE(thread_iv_->var, consumer_thread_extent), producer_code, consumer_code); producer_code, consumer_code);
// Add an attr here to handle the partial thread count in the ThreadSync pass. // Add an attr here to handle the partial thread count in the ThreadSync pass.
Array<IntImm> ws_partition = {Downcast<IntImm>(producer_thread_extent), Array<IntImm> ws_partition = {Downcast<IntImm>(producer_thread_extent),
Downcast<IntImm>(consumer_thread_extent)}; Downcast<IntImm>(consumer_thread_extent)};
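// Hedged sketch of the emitted structure (register counts and extents are
// taken from above; barrier list length illustrative):
//
//   create_list_of_mbarrier(...);                 // init_barrier
//   if (threadIdx.x >= consumer_thread_extent) {
//     setmaxnreg(24, 0);  producer_code;          // thread index shifted
//   } else {
//     setmaxnreg(240, 1); consumer_code;
//   }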
...@@ -935,7 +1001,8 @@ tvm::transform::Pass WarpSpecialized() { ...@@ -935,7 +1001,8 @@ tvm::transform::Pass WarpSpecialized() {
return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {}); return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {});
} }
TVM_REGISTER_GLOBAL("tl.transform.WarpSpecialized").set_body_typed(WarpSpecialized); TVM_REGISTER_GLOBAL("tl.transform.WarpSpecialized")
.set_body_typed(WarpSpecialized);
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -5,7 +5,6 @@ import torch ...@@ -5,7 +5,6 @@ import torch
import torch.backends import torch.backends
import tilelang.testing import tilelang.testing
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tvm import DataType
import tilelang as TL import tilelang as TL
import tilelang.language as T import tilelang.language as T
from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
......
...@@ -30,13 +30,11 @@ def matmul( ...@@ -30,13 +30,11 @@ def matmul(
@T.prim_func @T.prim_func
def main( def main(
A: T.Buffer(A_shape, in_dtype), A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype), B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype), C: T.Buffer((M, N), out_dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads
) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
...@@ -169,9 +167,7 @@ def test_gemm_f32f32f32_nn(): ...@@ -169,9 +167,7 @@ def test_gemm_f32f32f32_nn():
def test_gemm_i8i8i32_nn(): def test_gemm_i8i8i32_nn():
run_gemm( run_gemm(512, 1024, 768, False, False, "int8", "int8", "int32", 128, 128, 64)
512, 1024, 768, False, False, "int8", "int8", "int32", 128, 128, 64
)
def test_gemm_f16f16f16_tn(): def test_gemm_f16f16f16_tn():
...@@ -217,9 +213,7 @@ def test_gemm_i8i8i32_tn(): ...@@ -217,9 +213,7 @@ def test_gemm_i8i8i32_tn():
def test_gemm_f64f64f64_nt(): def test_gemm_f64f64f64_nt():
run_gemm( run_gemm(512, 512, 512, False, True, "float64", "float64", "float64", 64, 32, 16)
512, 512, 512, False, True, "float64", "float64", "float64", 64, 32, 16
)
def test_gemm_f32f32f32_nt(): def test_gemm_f32f32f32_nt():
......
...@@ -10,8 +10,7 @@ import tilelang as TL ...@@ -10,8 +10,7 @@ import tilelang as TL
import tilelang.language as T import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import ( from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter, TensorCoreIntrinEmitter,)
)
from tilelang.transform import simplify_prim_func from tilelang.transform import simplify_prim_func
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -6,6 +6,7 @@ import tilelang.testing ...@@ -6,6 +6,7 @@ import tilelang.testing
import tilelang as tl import tilelang as tl
from tilelang import primitives as P from tilelang import primitives as P
def matmul_ssr( def matmul_ssr(
M, M,
N, N,
...@@ -30,13 +31,11 @@ def matmul_ssr( ...@@ -30,13 +31,11 @@ def matmul_ssr(
@T.prim_func @T.prim_func
def main( def main(
A: T.Buffer(A_shape, in_dtype), A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype), B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype), C: T.Buffer((M, N), out_dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads
) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
...@@ -145,13 +144,11 @@ def matmul_rsr( ...@@ -145,13 +144,11 @@ def matmul_rsr(
@T.prim_func @T.prim_func
def main( def main(
A: T.Buffer(A_shape, in_dtype), A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype), B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype), C: T.Buffer((M, N), out_dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads
) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
A_local = T.alloc_fragment(A_local_shape, in_dtype) A_local = T.alloc_fragment(A_local_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype)
...@@ -264,13 +261,11 @@ def matmul_rrr( ...@@ -264,13 +261,11 @@ def matmul_rrr(
@T.prim_func @T.prim_func
def main( def main(
A: T.Buffer(A_shape, in_dtype), A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype), B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype), C: T.Buffer((M, N), out_dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads
) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
A_local = T.alloc_fragment(A_local_shape, in_dtype) A_local = T.alloc_fragment(A_local_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype)
......
...@@ -11,8 +11,7 @@ from .mma_macro_generator import ( ...@@ -11,8 +11,7 @@ from .mma_macro_generator import (
TensorCoreIntrinEmitterWithLadderTransform, # noqa: F401 TensorCoreIntrinEmitterWithLadderTransform, # noqa: F401
) )
from .mma_layout import get_swizzle_layout # noqa: F401
from .mma_layout import get_swizzle_layout # noqa: F401
from .mma_layout import make_mma_swizzle_layout # noqa: F401 from .mma_layout import make_mma_swizzle_layout # noqa: F401
from .mfma_layout import make_mfma_swizzle_layout # noqa: F401 from .mfma_layout import make_mfma_swizzle_layout # noqa: F401
...@@ -14,6 +14,7 @@ from .utils import ( ...@@ -14,6 +14,7 @@ from .utils import (
lift = convert lift = convert
# TODO(lei): Add Typing for this file # TODO(lei): Add Typing for this file
class TensorCoreIntrinEmitter(object): class TensorCoreIntrinEmitter(object):
""" """
...@@ -75,9 +76,11 @@ class TensorCoreIntrinEmitter(object): ...@@ -75,9 +76,11 @@ class TensorCoreIntrinEmitter(object):
self.reduce_k = reduce_k self.reduce_k = reduce_k
self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
self.num_elems_per_byte = num_elems_per_byte self.num_elems_per_byte = num_elems_per_byte
if self.warp_rows == 0 or self.warp_cols == 0: if self.warp_rows == 0 or self.warp_cols == 0:
raise ValueError(f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}") raise ValueError(
f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}"
)
def _initialize_k_dim(self, a_dtype="float16"): def _initialize_k_dim(self, a_dtype="float16"):
if isinstance(a_dtype, str): if isinstance(a_dtype, str):
...@@ -272,12 +275,9 @@ class TensorCoreIntrinEmitter(object): ...@@ -272,12 +275,9 @@ class TensorCoreIntrinEmitter(object):
A_local_buf.data, A_local_buf.data,
k_inner * warp_rows * local_size_a + i * local_size_a, k_inner * warp_rows * local_size_a + i * local_size_a,
B_local_buf.data, B_local_buf.data,
k_inner * warp_cols * local_size_b + j * local_size_b k_inner * warp_cols * local_size_b + j * local_size_b + lift(local_size_b) // 2,
+ lift(local_size_b) // 2,
C_local_buf.data, C_local_buf.data,
i * warp_cols * local_size_out i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
+ j * local_size_out
+ lift(local_size_out) // 2,
T.bool(False), T.bool(False),
) )
...@@ -328,7 +328,9 @@ class TensorCoreIntrinEmitter(object): ...@@ -328,7 +328,9 @@ class TensorCoreIntrinEmitter(object):
return (_warp_stmatrix_global(C_local_buf, C_buf, thread_bindings) return (_warp_stmatrix_global(C_local_buf, C_buf, thread_bindings)
if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_bindings)) if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_bindings))
def make_mma_load_layout(self, local_buf: Buffer, matrix:Literal["A", "B"]="A") -> T.Fragment: def make_mma_load_layout(self,
local_buf: Buffer,
matrix: Literal["A", "B"] = "A") -> T.Fragment:
""" """
Create a layout function for storing MMA results into a fragment buffer. Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to This layout is used in conjunction with `inverse_mma_store_layout` to
...@@ -372,12 +374,14 @@ class TensorCoreIntrinEmitter(object): ...@@ -372,12 +374,14 @@ class TensorCoreIntrinEmitter(object):
elif matrix == "A" and not transposed: elif matrix == "A" and not transposed:
transform_func = ldmatrix_16x32_to_shared_16x32_layout_a transform_func = ldmatrix_16x32_to_shared_16x32_layout_a
else: else:
raise ValueError("ldmatrix only supports B transposed and A non-transposed for int8") raise ValueError(
"ldmatrix only supports B transposed and A non-transposed for int8")
else: else:
raise ValueError(f"Unsupported dtype {dtype}") raise ValueError(f"Unsupported dtype {dtype}")
shape = local_buf.shape shape = local_buf.shape
assert is_fragment(local_buf), "local_buf must be a fragment, but got {}".format(local_buf.scope()) assert is_fragment(local_buf), "local_buf must be a fragment, but got {}".format(
local_buf.scope())
if matrix == "A": if matrix == "A":
micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_k micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_k
...@@ -397,7 +401,8 @@ class TensorCoreIntrinEmitter(object): ...@@ -397,7 +401,8 @@ class TensorCoreIntrinEmitter(object):
transform_func = transform_func if not transposed else transform_func_trans transform_func = transform_func if not transposed else transform_func_trans
warp_size, local_size_a, local_size_b = self.WARP_SIZE, self.local_size_a, self.local_size_b warp_size, local_size_a, local_size_b = self.WARP_SIZE, self.local_size_a, self.local_size_b
local_size = local_size_a if matrix == "A" else local_size_b local_size = local_size_a if matrix == "A" else local_size_b
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32").inverse([warp_size, local_size]) inverse_mma_load_layout = IndexMap.from_func(
transform_func, index_dtype="int32").inverse([warp_size, local_size])
def forward_thread(i: int, j: int) -> int: def forward_thread(i: int, j: int) -> int:
""" """
...@@ -406,29 +411,19 @@ class TensorCoreIntrinEmitter(object): ...@@ -406,29 +411,19 @@ class TensorCoreIntrinEmitter(object):
""" """
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y # the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols # the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols
block_i, block_j = (i // micro_size_x) // warp_rows, ( block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols
j // micro_size_y
) // warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols # the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i, warp_j = (i // micro_size_x) % warp_rows, ( warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols
j // micro_size_y
) % warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y # upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i, mma_j = i % micro_size_x, j % micro_size_y mma_i, mma_j = i % micro_size_x, j % micro_size_y
lane_id, _ = inverse_mma_load_layout.map_indices([mma_i, mma_j]) lane_id, _ = inverse_mma_load_layout.map_indices([mma_i, mma_j])
if is_m_first: if is_m_first:
thread_id = ( thread_id = (
block_i * (block_col_warps * warp_cols) block_i * (block_col_warps * warp_cols) + block_j * warp_rows +
+ block_j * warp_rows warp_i * warp_cols + warp_j)
+ warp_i * warp_cols
+ warp_j
)
else: else:
thread_id = ( thread_id = (
block_j * (block_row_warps * warp_size) block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id)
+ block_i * warp_size
+ lane_id
)
return thread_id return thread_id
def forward_index(i: int, j: int) -> int: def forward_index(i: int, j: int) -> int:
...@@ -439,21 +434,13 @@ class TensorCoreIntrinEmitter(object): ...@@ -439,21 +434,13 @@ class TensorCoreIntrinEmitter(object):
""" """
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y # the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of block_i and block_j are block_row_warps and block_col_warps # the upper bounds of block_i and block_j are block_row_warps and block_col_warps
block_i, block_j = (i // micro_size_x) // warp_rows, ( block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols
j // micro_size_y
) // warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols # the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i, warp_j = (i // micro_size_x) % warp_rows, ( warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols
j // micro_size_y
) % warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y # upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i, mma_j = i % micro_size_x, j % micro_size_y mma_i, mma_j = i % micro_size_x, j % micro_size_y
_, local_id = inverse_mma_load_layout.map_indices([mma_i, mma_j]) _, local_id = inverse_mma_load_layout.map_indices([mma_i, mma_j])
return ( return (warp_i * (warp_cols * local_size_out) + warp_j * local_size_out + local_id)
warp_i * (warp_cols * local_size_out)
+ warp_j * local_size_out
+ local_id
)
fragment = T.Fragment( fragment = T.Fragment(
shape, shape,
...@@ -465,9 +452,7 @@ class TensorCoreIntrinEmitter(object): ...@@ -465,9 +452,7 @@ class TensorCoreIntrinEmitter(object):
print(f"fragment.index: {fragment.index}") print(f"fragment.index: {fragment.index}")
return fragment return fragment
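For intuition, the inverse map built above can be exercised on its own: `IndexMap.from_func(...)` captures a (thread, local) -> (row, col) transform, and `inverse([warp_size, local_size])` flips it into the (row, col) -> (thread, local) direction that `forward_thread` and `forward_index` query. A minimal sketch, assuming `tvm.tir.IndexMap`; the simple row-major split below is illustrative, not the real ldmatrix layout:

# Minimal sketch; the 32x2 -> 8x8 mapping is illustrative, not ldmatrix.
from tvm.tir import IndexMap

def thread_local_to_coords(thread_id, local_id):
    # Each thread owns two adjacent columns of an 8x8 tile.
    return thread_id // 4, (thread_id % 4) * 2 + local_id

index_map = IndexMap.from_func(thread_local_to_coords, index_dtype="int32")
inverse = index_map.inverse([32, 2])  # domain: warp_size=32, local_size=2
# Recover which thread/register pair holds element (3, 5) of the tile.
thread_id, local_id = inverse.map_indices([3, 5])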
def make_mma_store_layout( def make_mma_store_layout(self, local_buf: Buffer) -> T.Fragment:
self, local_buf: Buffer
) -> T.Fragment:
""" """
Create a layout function for storing MMA results into a fragment buffer. Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to This layout is used in conjunction with `inverse_mma_store_layout` to
...@@ -500,6 +485,7 @@ class TensorCoreIntrinEmitter(object): ...@@ -500,6 +485,7 @@ class TensorCoreIntrinEmitter(object):
warp_rows, warp_cols = self.warp_rows, self.warp_cols warp_rows, warp_cols = self.warp_rows, self.warp_cols
warp_size = self.WARP_SIZE warp_size = self.WARP_SIZE
is_m_first = self.is_m_first is_m_first = self.is_m_first
def forward_thread(i: int, j: int) -> int: def forward_thread(i: int, j: int) -> int:
""" """
Given the row index `i` and column index `j` in the fragment, Given the row index `i` and column index `j` in the fragment,
...@@ -514,7 +500,8 @@ class TensorCoreIntrinEmitter(object): ...@@ -514,7 +500,8 @@ class TensorCoreIntrinEmitter(object):
mma_i, mma_j = i % micro_size_x, j % micro_size_y mma_i, mma_j = i % micro_size_x, j % micro_size_y
lane_id, _ = inverse_mma_store_layout.map_indices([mma_i, mma_j]) lane_id, _ = inverse_mma_store_layout.map_indices([mma_i, mma_j])
if is_m_first: if is_m_first:
thread_id = block_i * (block_col_warps * warp_cols) + block_j * warp_rows + warp_i * warp_cols + warp_j thread_id = block_i * (
block_col_warps * warp_cols) + block_j * warp_rows + warp_i * warp_cols + warp_j
else: else:
thread_id = block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id thread_id = block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id
return thread_id return thread_id
...@@ -527,13 +514,9 @@ class TensorCoreIntrinEmitter(object): ...@@ -527,13 +514,9 @@ class TensorCoreIntrinEmitter(object):
""" """
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y # the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of block_i and block_j are block_row_warps and block_col_warps # the upper bounds of block_i and block_j are block_row_warps and block_col_warps
block_i, block_j = (i // micro_size_x) // warp_rows, ( block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols
j // micro_size_y
) // warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols # the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i, warp_j = (i // micro_size_x) % warp_rows, ( warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols
j // micro_size_y
) % warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y # upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i, mma_j = i % micro_size_x, j % micro_size_y mma_i, mma_j = i % micro_size_x, j % micro_size_y
_, local_id = inverse_mma_store_layout.map_indices([mma_i, mma_j]) _, local_id = inverse_mma_store_layout.map_indices([mma_i, mma_j])
...@@ -545,6 +528,7 @@ class TensorCoreIntrinEmitter(object): ...@@ -545,6 +528,7 @@ class TensorCoreIntrinEmitter(object):
forward_index_fn=forward_index, forward_index_fn=forward_index,
) )
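As a sanity check of the decomposition shared by both layouts, here is a small pure-Python sketch with illustrative sizes (16x16 MMA tiles, 2x2 warps per block); it mirrors the block/warp/intra-MMA splits in `forward_thread` and `forward_index`:

# Pure-Python sketch of the (i, j) split; all sizes are illustrative.
micro_size_x = micro_size_y = 16
warp_rows = warp_cols = 2

def decompose(i, j):
    block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols
    warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols
    mma_i, mma_j = i % micro_size_x, j % micro_size_y
    return (block_i, block_j), (warp_i, warp_j), (mma_i, mma_j)

# Element (40, 19) lives in block (1, 0), warp (0, 1), at tile offset (8, 3).
assert decompose(40, 19) == ((1, 0), (0, 1), (8, 3))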
class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter): class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
""" """
To eliminate Python syntax within TIR Macro. To eliminate Python syntax within TIR Macro.
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
from tvm.script import tir as T from tvm.script import tir as T
def alloc_shared(shape, dtype, scope="shared.dyn"): def alloc_shared(shape, dtype, scope="shared.dyn"):
return T.alloc_buffer(shape, dtype, scope=scope) return T.alloc_buffer(shape, dtype, scope=scope)
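A minimal usage sketch, assuming the helper is visible inside a TVMScript `T.prim_func`; the kernel name and shapes are made up for illustration:

# Illustrative only: alloc_shared is a thin wrapper over T.alloc_buffer.
@T.prim_func
def kernel(A: T.Buffer((128, 128), "float16")):
    # Dynamic shared-memory tile; scope defaults to "shared.dyn".
    A_shared = alloc_shared((128, 32), "float16")
    T.evaluate(0)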
......
...@@ -6,11 +6,10 @@ from typing import Union, List, Optional ...@@ -6,11 +6,10 @@ from typing import Union, List, Optional
from tvm import tir from tvm import tir
from tvm.script import tir as T from tvm.script import tir as T
def region(buffer: tir.BufferLoad, access_type: str, *args: tir.PrimExpr): def region(buffer: tir.BufferLoad, access_type: str, *args: tir.PrimExpr):
access_type = {"r": 1, "w": 2, "rw": 3}[access_type] access_type = {"r": 1, "w": 2, "rw": 3}[access_type]
return tir.call_intrin( return tir.call_intrin("handle", tir.op.Op.get("tl.region"), buffer, access_type, *args)
"handle", tir.op.Op.get("tl.region"), buffer, access_type, *args
)
def buffer_to_tile_region(buffer: tir.Buffer, access_type: str): def buffer_to_tile_region(buffer: tir.Buffer, access_type: str):
...@@ -19,20 +18,14 @@ def buffer_to_tile_region(buffer: tir.Buffer, access_type: str): ...@@ -19,20 +18,14 @@ def buffer_to_tile_region(buffer: tir.Buffer, access_type: str):
return region(T.BufferLoad(buffer, mins), access_type, *extents) return region(T.BufferLoad(buffer, mins), access_type, *extents)
def buffer_load_to_tile_region( def buffer_load_to_tile_region(load: tir.BufferLoad, access_type: str, extents: List[tir.PrimExpr]):
load: tir.BufferLoad, access_type: str, extents: List[tir.PrimExpr]
):
return region(load, access_type, *extents) return region(load, access_type, *extents)
def buffer_region_to_tile_region( def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str):
buffer_region: tir.BufferRegion, access_type: str
):
mins = [x.min for x in buffer_region.region] mins = [x.min for x in buffer_region.region]
extents = [x.extent for x in buffer_region.region] extents = [x.extent for x in buffer_region.region]
return region( return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *extents)
T.BufferLoad(buffer_region.buffer, mins), access_type, *extents
)
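All three helpers funnel into `region`, which encodes the access mode as a bitmask (read = 1, write = 2, read-write = 3) before emitting the `tl.region` intrinsic. A hedged usage sketch; the buffer and the 32x64 tile are assumptions for illustration:

# Illustrative: build tile regions the way copy() below normalizes its arguments.
from tvm import tir

A = tir.decl_buffer((128, 64), "float16", name="A")
src = buffer_to_tile_region(A, "r")  # whole buffer, read-only (mask 1)
# A 32x64 write region anchored at element (0, 0) of A (mask 2).
dst = buffer_load_to_tile_region(tir.BufferLoad(A, [0, 0]), "w", [32, 64])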
def copy( def copy(
...@@ -71,9 +64,7 @@ def copy( ...@@ -71,9 +64,7 @@ def copy(
src = _to_region(src, "r") src = _to_region(src, "r")
dst = _to_region(dst, "w") dst = _to_region(dst, "w")
if coalesced_width is not None: if coalesced_width is not None:
return tir.call_intrin( return tir.call_intrin("handle", tir.op.Op.get("tl.copy"), src, dst, coalesced_width)
"handle", tir.op.Op.get("tl.copy"), src, dst, coalesced_width
)
else: else:
return tir.call_intrin("handle", tir.op.Op.get("tl.copy"), src, dst) return tir.call_intrin("handle", tir.op.Op.get("tl.copy"), src, dst)
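A usage sketch inside a TileLang-style kernel body; the buffer names and the coalescing width of 8 are illustrative, not prescribed values:

# Illustrative: stage a global tile into shared memory within a kernel body.
A_shared = alloc_shared((128, 32), "float16")
copy(A[0:128, 0:32], A_shared)                     # default lowering
copy(A[0:128, 0:32], A_shared, coalesced_width=8)  # hint the vector width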
......
...@@ -10,12 +10,8 @@ def atomic_add(dst, value): ...@@ -10,12 +10,8 @@ def atomic_add(dst, value):
def atomic_addx2(dst, value): def atomic_addx2(dst, value):
return T.call_extern( return T.call_extern("handle", "atomicAddx2", T.address_of(dst), T.address_of(value))
"handle", "atomicAddx2", T.address_of(dst), T.address_of(value)
)
def dp4a(A, B, C): def dp4a(A, B, C):
return T.call_extern( return T.call_extern("handle", "DP4A", T.address_of(A), T.address_of(B), T.address_of(C))
"handle", "DP4A", T.address_of(A), T.address_of(B), T.address_of(C)
)
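For intuition, DP4A accumulates a four-lane int8 dot product into a 32-bit accumulator (on CUDA hardware this is presumably backed by the __dp4a instruction). A pure-Python model of the intended semantics, with made-up values:

# Pure-Python model of DP4A: C += dot(A[0:4], B[0:4]); values illustrative.
def dp4a_model(a4, b4, c):
    return c + sum(x * y for x, y in zip(a4, b4))

assert dp4a_model([1, 2, 3, 4], [5, 6, 7, 8], 10) == 80  # 70 + 10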