Commit fa511857 authored by Lei Wang, committed by GitHub

[Lint] Overall Typo and Linting Fixes (#13)

* README.md fixed

* update test ci

* Lint and Typo Fix

* Clang Format Lint Fix
parent be55163f
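
Most hunks below come from running clang-format over the tl sources: the pointer or reference qualifier is re-bound to the declared name rather than the type, and declarations or calls that exceed the column limit are re-wrapped. A minimal sketch of that pointer-alignment change, using hypothetical identifiers rather than code from this repository:

#include <string>

struct Analyzer; // hypothetical stand-in for an analyzer type

// Before formatting, '*' and '&' sat next to the type:
//   int Lookup(const std::string& name, Analyzer* analyzer);
// After the clang-format pass applied in this commit they bind to the name,
// and long argument lists are wrapped to the column limit:
int Lookup(const std::string &name, Analyzer *analyzer);
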
......@@ -32,29 +32,32 @@ namespace tl {
using namespace tir;
class BufferIndiceSimplify : public StmtExprMutator {
public:
BufferIndiceSimplify(arith::Analyzer* analyzer) : analyzer_(analyzer) {}
public:
BufferIndiceSimplify(arith::Analyzer *analyzer) : analyzer_(analyzer) {}
private:
PrimExpr VisitExpr_(const BufferLoadNode* node) final {
private:
PrimExpr VisitExpr_(const BufferLoadNode *node) final {
auto visited = StmtExprMutator::VisitExpr_(node);
auto n = visited.as<BufferLoad>().value();
auto nptr = n.CopyOnWrite();
nptr->indices = nptr->indices.Map([&](const auto& e) { return analyzer_->Simplify(e); });
nptr->indices = nptr->indices.Map(
[&](const auto &e) { return analyzer_->Simplify(e); });
return n;
}
Stmt VisitStmt_(const BufferStoreNode* node) final {
Stmt VisitStmt_(const BufferStoreNode *node) final {
auto visited = StmtExprMutator::VisitStmt_(node);
auto n = visited.as<BufferStore>().value();
auto nptr = n.CopyOnWrite();
nptr->indices = nptr->indices.Map([&](const auto& e) { return analyzer_->Simplify(e); });
nptr->indices = nptr->indices.Map(
[&](const auto &e) { return analyzer_->Simplify(e); });
return n;
}
arith::Analyzer* analyzer_;
arith::Analyzer *analyzer_;
};
// Rewrite the parallel loop into a common loop, which is mapped to threads
For PartitionLoop(For op, Var thread_var, arith::Analyzer* analyzer, Fragment loop_layout) {
For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
Fragment loop_layout) {
ICHECK(loop_layout.defined());
ICHECK(thread_var.defined());
int old_loop_depth = loop_layout->InputDim();
......@@ -71,7 +74,8 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer* analyzer, Fragment lo
Map<Var, PrimExpr> vmap;
Stmt body = op;
auto inv_loop = loop_layout->Inverse();
auto indices = inv_loop->Forward(vars.Map([](const Var& v) { return PrimExpr(v); }));
auto indices =
inv_loop->Forward(vars.Map([](const Var &v) { return PrimExpr(v); }));
for (int i = 0; i < old_loop_depth; i++) {
ICHECK(body.as<For>().defined());
For loop = body.as<For>().value();
......@@ -82,8 +86,8 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer* analyzer, Fragment lo
// substitute and re-construct the serial loop
body = Substitute(body, vmap);
for (int i = new_loop_depth - 1; i >= 0; i--) {
body =
For(vars[i], make_zero(vars[i]->dtype), inv_loop->InputShape()[i], ForKind::kSerial, body);
body = For(vars[i], make_zero(vars[i]->dtype), inv_loop->InputShape()[i],
ForKind::kSerial, body);
analyzer->Bind(vars[i], Range(0, inv_loop->InputShape()[i]));
}
......@@ -95,11 +99,11 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer* analyzer, Fragment lo
}
class LoopPramaUnroller : public StmtExprMutator {
public:
public:
LoopPramaUnroller() = default;
private:
Stmt VisitStmt_(const ForNode* node) final {
private:
Stmt VisitStmt_(const ForNode *node) final {
if (node->kind == ForKind::kSerial) {
For new_for = GetRef<For>(node);
auto for_ptr = new_for.CopyOnWrite();
......@@ -112,7 +116,7 @@ class LoopPramaUnroller : public StmtExprMutator {
};
class LoopPartitioner : public StmtExprVisitor {
public:
public:
LoopPartitioner() = default;
Fragment Partition(For op, int num_thread, int vectorize_size) {
......@@ -129,17 +133,18 @@ class LoopPartitioner : public StmtExprVisitor {
ICHECK(loop_size_full % vectorize_size == 0);
PrimExpr access_idx = FloorDiv(flattened, vectorize_size);
PrimExpr thd = FloorMod(access_idx, num_thread);
PrimExpr idx =
FloorDiv(access_idx, num_thread) * vectorize_size + FloorMod(flattened, vectorize_size);
PrimExpr idx = FloorDiv(access_idx, num_thread) * vectorize_size +
FloorMod(flattened, vectorize_size);
return Fragment(loop_vars_, {idx}, {thd}, {});
}
private:
void VisitStmt_(const ForNode* node) final {
private:
void VisitStmt_(const ForNode *node) final {
if (node->kind == ForKind::kParallel) {
body_ = node->body;
loop_vars_.push_back(IterVar(Range::FromMinExtent(node->min, node->extent), node->loop_var,
IterVarType::kDataPar));
loop_vars_.push_back(
IterVar(Range::FromMinExtent(node->min, node->extent), node->loop_var,
IterVarType::kDataPar));
}
StmtExprVisitor::VisitStmt_(node);
}
......@@ -160,5 +165,5 @@ For LoopPragmaUnroll(For stmt) {
return unrolled;
}
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
......@@ -36,13 +36,14 @@ namespace tl {
using namespace tir;
For PartitionLoop(For op, Var thread_var, arith::Analyzer* analyzer, Fragment loop_layout);
For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
Fragment loop_layout);
Fragment PlanLoopPartition(For op, size_t num_thread, int vectorize_size);
For LoopPragmaUnroll(For stmt);
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
#endif // TVM_TL_LOOP_PARTITION_H_
#endif // TVM_TL_LOOP_PARTITION_H_
......@@ -30,10 +30,10 @@
#include <numeric>
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "../layout/layout.h"
#include "../layout/utils.h"
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_vectorization_utils.h"
namespace tvm {
......@@ -48,10 +48,10 @@ struct VectorizePlanResult {
};
class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
public:
public:
VectorizePlanner() = default;
int Plan(const For& node) {
int Plan(const For &node) {
this->operator()(node);
// Always Enable vectorization
// if (!has_nonlocal_memory_access_) return 1;
......@@ -62,18 +62,19 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
PrimExpr GetCondition() { return condition_; }
private:
void VisitStmt_(const ForNode* node) final {
private:
void VisitStmt_(const ForNode *node) final {
inner_for_ = node;
iter_map_.Set(node->loop_var, Range(node->min, node->extent));
arith::IRVisitorWithAnalyzer::VisitStmt_(node);
}
void VisitExpr_(const BufferLoadNode* node) final {
void VisitExpr_(const BufferLoadNode *node) final {
if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
node->buffer.scope() == "shared.dyn")
has_nonlocal_memory_access_ = true;
if (node->buffer->shape.size() == 1 && node->buffer->shape[0].as<IntImmNode>()->value == 1) {
if (node->buffer->shape.size() == 1 &&
node->buffer->shape[0].as<IntImmNode>()->value == 1) {
// TODO(lei): This should be improved: this is a constant
// buffer that tl hacks to use as a local register.
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
......@@ -82,7 +83,7 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
}
void VisitStmt_(const BufferStoreNode* node) final {
void VisitStmt_(const BufferStoreNode *node) final {
if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
node->buffer.scope() == "shared.dyn")
has_nonlocal_memory_access_ = true;
......@@ -90,12 +91,12 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
}
void VisitStmt_(const IfThenElseNode* node) final {
void VisitStmt_(const IfThenElseNode *node) final {
CheckConditionVectorized(node->condition);
return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
}
void VisitExpr_(const CallNode* node) final {
void VisitExpr_(const CallNode *node) final {
if (node->op == builtin::if_then_else()) {
CheckConditionVectorized(node->args[0]);
} else if (node->op == builtin::call_extern()) {
......@@ -105,16 +106,18 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
}
void CheckConditionVectorized(const PrimExpr& cond) {
void CheckConditionVectorized(const PrimExpr &cond) {
// TODO: perform some checks here
}
void UpdateVectorSize(const Array<PrimExpr> indices, const Buffer& buffer) {
if (!inner_for_) return;
void UpdateVectorSize(const Array<PrimExpr> indices, const Buffer &buffer) {
if (!inner_for_)
return;
auto extent_ptr = inner_for_->extent.as<IntImmNode>();
if (!extent_ptr) return;
if (!extent_ptr)
return;
const DataType& access_type = buffer->dtype;
const DataType &access_type = buffer->dtype;
// i // 2, i % 8 can also be vectorized as factor 16
int max_vector_size = 128 / access_type.bits();
// so we should disable this GCD optimization
......@@ -122,7 +125,8 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
auto last_dim = buffer->shape.back();
auto mod_set = analyzer_.modular_set(last_dim);
// when dynamic shape like [m, k]: coeff=1, base=0, GCD will block conditionally tail vectorize
// when dynamic shape like [m, k]: coeff=1, base=0, GCD will block
// conditionally tail vectorize
if (buffer->shape.back().as<IntImmNode>()) {
max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
......@@ -142,8 +146,9 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
elem_offset = elem_offset + indices[i] * stride;
stride = stride * buffer->shape[i];
}
while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var, inner_for_->extent,
vector_size_, &analyzer_)) {
while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
inner_for_->extent, vector_size_,
&analyzer_)) {
vector_size_ /= 2;
}
} else if (vector_size_ <= vector_load_bits_max_ / buffer->dtype.bits()) {
......@@ -156,7 +161,7 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
static const int vector_load_bits_max_ = 128;
const ForNode* inner_for_;
const ForNode *inner_for_;
Map<Var, Range> iter_map_;
bool has_nonlocal_memory_access_ = false;
int vector_size_ = 128;
......@@ -166,12 +171,12 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
};
class VectorizeDynamicCallRemover : public StmtExprMutator {
public:
public:
VectorizeDynamicCallRemover(Var inner_var, int vector_size)
: inner_var_(inner_var), vector_size_(vector_size) {}
private:
PrimExpr VisitExpr_(const CallNode* op) final {
private:
PrimExpr VisitExpr_(const CallNode *op) final {
if (op->op.same_as(builtin::if_then_else())) {
PrimExpr cond = this->VisitExpr(op->args[0]);
Map<Var, PrimExpr> vmap;
......@@ -191,15 +196,16 @@ class VectorizeDynamicCallRemover : public StmtExprMutator {
};
class VectorizeRewriter : public StmtExprMutator {
public:
public:
VectorizeRewriter(VectorizePlanResult plan)
: vector_size_(plan.vector_size), condition_(plan.condition), dynamic_(plan.dynamic) {}
: vector_size_(plan.vector_size), condition_(plan.condition),
dynamic_(plan.dynamic) {}
private:
Stmt VisitStmt_(const ForNode* node) final {
private:
Stmt VisitStmt_(const ForNode *node) final {
inner_for_ = node;
auto ret = StmtExprMutator::VisitStmt_(node);
if (inner_for_ == node) { // rewrite the innermost loop
if (inner_for_ == node) { // rewrite the innermost loop
For fnode = ret.as<For>().value();
auto old_var = fnode->loop_var;
auto extent_ptr = as_const_int(fnode->extent);
......@@ -208,7 +214,7 @@ class VectorizeRewriter : public StmtExprMutator {
ICHECK(extent % vector_size_ == 0)
<< "extent: " << extent << " vector_size_: " << vector_size_;
ICHECK(is_zero(fnode->min));
if (!dynamic_) { // check dynamic shape
if (!dynamic_) { // check dynamic shape
if (extent == vector_size_) {
fnode.CopyOnWrite()->kind = ForKind::kVectorized;
return fnode;
......@@ -219,8 +225,8 @@ class VectorizeRewriter : public StmtExprMutator {
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body, fnode->thread_binding,
fnode->annotations, fnode->span);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
return body;
}
} else {
......@@ -237,11 +243,13 @@ class VectorizeRewriter : public StmtExprMutator {
VectorizeDynamicCallRemover remover(inner_var, vector_size_);
body = remover(body);
For vectorize_for = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
For serial_for = For(inner_var, 0, vector_size_, ForKind::kSerial, body);
For vectorize_for =
For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
For serial_for =
For(inner_var, 0, vector_size_, ForKind::kSerial, body);
body = IfThenElse(condition, vectorize_for, serial_for);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body, fnode->thread_binding,
fnode->annotations, fnode->span);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
return body;
}
} else {
......@@ -249,15 +257,15 @@ class VectorizeRewriter : public StmtExprMutator {
}
}
const ForNode* inner_for_;
const ForNode *inner_for_;
const int vector_size_;
const PrimExpr condition_;
const bool dynamic_;
};
int GetVectorizeSize(const For& loop) { return VectorizePlanner().Plan(loop); }
int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); }
VectorizePlanResult GetVectorizePlanResult(const For& loop) {
VectorizePlanResult GetVectorizePlanResult(const For &loop) {
VectorizePlanner planner;
int vector_size = planner.Plan(loop);
bool dynamic = planner.GetDynamic();
......@@ -265,16 +273,19 @@ VectorizePlanResult GetVectorizePlanResult(const For& loop) {
return {vector_size, dynamic, condition};
}
bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size, int target_vectorized_size,
arith::Analyzer* analyzer) {
bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size,
int target_vectorized_size, arith::Analyzer *analyzer) {
ICHECK(target_vectorized_size >= 1);
if (target_vectorized_size == 1) return true;
if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_vectorized_size), 0)) return false;
if (target_vectorized_size == 1)
return true;
if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_vectorized_size),
0))
return false;
Var v0("v0"), v1("v1");
analyzer->Bind(v0, Range(0, target_vectorized_size));
analyzer->Bind(v1, Range(0, FloorDiv(iter_var_size, target_vectorized_size)));
PrimExpr expr_transformed =
analyzer->Simplify(Substitute(expr, {{var, v0 + v1 * target_vectorized_size}}));
PrimExpr expr_transformed = analyzer->Simplify(
Substitute(expr, {{var, v0 + v1 * target_vectorized_size}}));
Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size));
PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed);
......@@ -290,16 +301,17 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size, int targ
}
}
For VectorizeLoop(const For& loop, int vectorize_hint) {
For VectorizeLoop(const For &loop, int vectorize_hint) {
VectorizePlanResult res{128, false, 0};
if (vectorize_hint <= 0) {
res = GetVectorizePlanResult(loop);
vectorize_hint = res.vector_size;
}
if (vectorize_hint == 1) return loop;
if (vectorize_hint == 1)
return loop;
auto rewriter = VectorizeRewriter(res);
return Downcast<For>(rewriter(loop));
}
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
......@@ -35,13 +35,13 @@ namespace tl {
using namespace tir;
int GetVectorizeSize(const For& loop);
For VectorizeLoop(const For& loop, int vectorize_hint = -1);
int GetVectorizeSize(const For &loop);
For VectorizeLoop(const For &loop, int vectorize_hint = -1);
bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size, int target_vectorized_size,
arith::Analyzer* analyzer);
bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size,
int target_vectorized_size, arith::Analyzer *analyzer);
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
#endif // TVM_TL_LOOP_VECTORIZE_H_
#endif // TVM_TL_LOOP_VECTORIZE_H_
......@@ -37,15 +37,15 @@ namespace tl {
using namespace tir;
class LowerHopperIntrin : public StmtExprMutator {
public:
static PrimFunc Substitute(PrimFunc& f) {
PrimFuncNode* fptr = f.CopyOnWrite();
public:
static PrimFunc Substitute(PrimFunc &f) {
PrimFuncNode *fptr = f.CopyOnWrite();
LowerHopperIntrin substituter;
fptr->body = substituter.VisitStmt(f->body);
for (auto [call, var] : substituter.desc_map_) {
// Should allocate 128 bytes for TensorMap on stack
Call alloc_desc =
Call(DataType::Handle(), builtin::tvm_stack_alloca(), {StringImm("arg_value"), 16});
Call alloc_desc = Call(DataType::Handle(), builtin::tvm_stack_alloca(),
{StringImm("arg_value"), 16});
Array<PrimExpr> init_desc_args;
if (call->op.same_as(CreateTMADescriptorOp())) {
init_desc_args.push_back(StringImm(tvm_tensormap_create_tiled));
......@@ -55,15 +55,19 @@ class LowerHopperIntrin : public StmtExprMutator {
CHECK(0) << call->op;
}
init_desc_args.push_back(var);
init_desc_args.insert(init_desc_args.end(), call->args.begin(), call->args.end());
Call init_desc = Call(DataType::Handle(), builtin::tvm_call_packed(), init_desc_args);
fptr->body = LetStmt(var, alloc_desc, SeqStmt({Evaluate(init_desc), fptr->body}));
init_desc_args.insert(init_desc_args.end(), call->args.begin(),
call->args.end());
Call init_desc =
Call(DataType::Handle(), builtin::tvm_call_packed(), init_desc_args);
fptr->body =
LetStmt(var, alloc_desc, SeqStmt({Evaluate(init_desc), fptr->body}));
}
return f;
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
// Insert the prefetch TMA descriptor statement TO the beginning of the kernel
Stmt VisitStmt_(const AttrStmtNode *op) final {
// Insert the prefetch TMA descriptor statement TO the beginning of the
// kernel
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
......@@ -73,18 +77,22 @@ class LowerHopperIntrin : public StmtExprMutator {
} else {
Array<Stmt> stmt_seq;
if (!init_mbarrier_calls_.empty()) {
auto alloc_mbarrier = Evaluate(Call(DataType::Handle(), builtin::create_barriers(),
{static_cast<int>(init_mbarrier_calls_.size())}));
auto alloc_mbarrier =
Evaluate(Call(DataType::Handle(), builtin::create_barriers(),
{static_cast<int>(init_mbarrier_calls_.size())}));
stmt_seq.push_back(alloc_mbarrier);
}
auto stmts = prefetch_calls_;
stmts.insert(stmts.end(), init_mbarrier_calls_.begin(), init_mbarrier_calls_.end());
auto init_stmt = IfThenElse(EQ(iv->var, 0), stmts.size() > 1 ? SeqStmt(stmts) : stmts[0]);
stmts.insert(stmts.end(), init_mbarrier_calls_.begin(),
init_mbarrier_calls_.end());
auto init_stmt = IfThenElse(
EQ(iv->var, 0), stmts.size() > 1 ? SeqStmt(stmts) : stmts[0]);
stmt_seq.push_back(init_stmt);
if (!init_mbarrier_calls_.empty()) {
Stmt mem_sync = Evaluate(
Call(DataType::Handle(), builtin::tvm_storage_sync(), {StringImm("shared")}));
Stmt mem_sync =
Evaluate(Call(DataType::Handle(), builtin::tvm_storage_sync(),
{StringImm("shared")}));
stmt_seq.push_back(mem_sync);
}
stmt_seq.push_back(body);
......@@ -98,7 +106,7 @@ class LowerHopperIntrin : public StmtExprMutator {
return StmtExprMutator::VisitStmt_(op);
}
PrimExpr VisitExpr_(const CallNode* call) final {
PrimExpr VisitExpr_(const CallNode *call) final {
if (call->op.same_as(CreateTMADescriptorOp()) ||
call->op.same_as(CreateTMAIm2ColDescriptorOp())) {
Var var;
......@@ -107,10 +115,12 @@ class LowerHopperIntrin : public StmtExprMutator {
var = iter->second;
} else {
String name = call->args[2].as<Var>().value()->name_hint;
var = Var(name + "_desc", PointerType(PrimType(cuTensorMapType()), "grid_constant"));
var = Var(name + "_desc",
PointerType(PrimType(cuTensorMapType()), "grid_constant"));
desc_map_[GetRef<Call>(call)] = var;
prefetch_calls_.push_back(Evaluate(Call(DataType::Handle(), builtin::call_extern(),
{StringImm("tl::prefetch_tma_descriptor"), var})));
prefetch_calls_.push_back(
Evaluate(Call(DataType::Handle(), builtin::call_extern(),
{StringImm("tl::prefetch_tma_descriptor"), var})));
}
return var;
} else if (call->op.same_as(CreateListofMBarrierOp())) {
......@@ -118,24 +128,25 @@ class LowerHopperIntrin : public StmtExprMutator {
int num_barriers = static_cast<int>(call->args.size());
for (int i = 0; i < num_barriers; i++) {
PrimExpr mbarrier = Call(DataType::Handle(), GetMBarrierOp(), {i});
init_mbarrier_calls_.push_back(
Evaluate(Call(DataType::Handle(), builtin::ptx_init_barrier_thread_count(),
{mbarrier, call->args[i]})));
init_mbarrier_calls_.push_back(Evaluate(
Call(DataType::Handle(), builtin::ptx_init_barrier_thread_count(),
{mbarrier, call->args[i]})));
}
return 0;
} else if (call->op.same_as(SyncThreadsPartialOp())) {
int barrier_id = init_mbarrier_calls_.size();
PrimExpr mbarrier = Call(DataType::Handle(), GetMBarrierOp(), {barrier_id});
init_mbarrier_calls_.push_back(
Evaluate(Call(DataType::Handle(), builtin::ptx_init_barrier_thread_count(),
{mbarrier, call->args[0]})));
PrimExpr mbarrier =
Call(DataType::Handle(), GetMBarrierOp(), {barrier_id});
init_mbarrier_calls_.push_back(Evaluate(
Call(DataType::Handle(), builtin::ptx_init_barrier_thread_count(),
{mbarrier, call->args[0]})));
return Call(DataType::Handle(), SyncThreadsPartialOp(), {mbarrier});
} else {
return StmtExprMutator::VisitExpr_(call);
}
}
private:
private:
Array<Stmt> prefetch_calls_;
Array<Stmt> init_mbarrier_calls_;
std::unordered_map<Call, Var, StructuralHash, ExprDeepEqual> desc_map_;
......@@ -154,5 +165,5 @@ tvm::transform::Pass LowerHopperIntrin() {
TVM_REGISTER_GLOBAL("tl.transform.LowerHopperIntrin")
.set_body_typed(LowerHopperIntrin);
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
......@@ -27,10 +27,10 @@
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>
#include "arith/ir_mutator_with_analyzer.h"
#include "../layout/layout.h"
#include "../layout/utils.h"
#include "../op/op.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "loop_partition.h"
namespace tvm {
......@@ -38,8 +38,9 @@ namespace tl {
using namespace tir;
static Buffer makeBufferWithLayout(const Buffer& buffer, const Layout& layout) {
const auto* ptr_type = TVM_TYPE_AS(buffer->data->type_annotation, PointerTypeNode);
static Buffer makeBufferWithLayout(const Buffer &buffer, const Layout &layout) {
const auto *ptr_type =
TVM_TYPE_AS(buffer->data->type_annotation, PointerTypeNode);
Type new_type;
// convert fragments to normal local buffer
if (ptr_type->storage_scope == "local.fragment") {
......@@ -53,32 +54,33 @@ static Buffer makeBufferWithLayout(const Buffer& buffer, const Layout& layout) {
} else {
new_var = Var(buffer->data->name_hint, new_type);
}
return Buffer(new_var, buffer->dtype, layout->OutputShape(), {}, buffer->elem_offset,
buffer->name, buffer->data_alignment, buffer->offset_factor, buffer->buffer_type);
return Buffer(new_var, buffer->dtype, layout->OutputShape(), {},
buffer->elem_offset, buffer->name, buffer->data_alignment,
buffer->offset_factor, buffer->buffer_type);
}
class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
public:
public:
static PrimFunc Substitute(PrimFunc f) {
arith::Analyzer analyzer;
LowerTileOpPass substituter(&analyzer);
// Trace the buffer map for tvm_access_ptr
substituter.buffer_map_.insert(f->buffer_map.begin(), f->buffer_map.end());
for (const auto& [_, buffer] : f->buffer_map) {
for (const auto &[_, buffer] : f->buffer_map) {
substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
}
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target.defined()) << "LowerTileOpPass: Require the target attribute";
substituter.target_ = target.value();
PrimFuncNode* fptr = f.CopyOnWrite();
PrimFuncNode *fptr = f.CopyOnWrite();
fptr->body = substituter.VisitStmt(f->body);
return f;
}
private:
private:
using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer;
Stmt VisitStmt_(const BlockNode* op) final {
Stmt VisitStmt_(const BlockNode *op) final {
// Record the mapping from buffer data var to buffer for later lookup
for (auto buffer : op->alloc_buffers) {
buffer_map_.insert({buffer->data, buffer});
......@@ -91,7 +93,9 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
}
Map<Var, Layout> vmap;
if (op->annotations.count(attr::kLayoutMap)) {
auto layout_map = op->annotations.at(attr::kLayoutMap).as<Map<Buffer, Layout>>().value();
auto layout_map = op->annotations.at(attr::kLayoutMap)
.as<Map<Buffer, Layout>>()
.value();
for (auto [buffer, layout] : layout_map) {
buffer_remap_.Set(buffer, makeBufferWithLayout(buffer, layout));
layout_map_.Set(buffer, layout);
......@@ -105,7 +109,8 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
block_ptr->alloc_buffers.Set(i, buffer_remap_[buffer]);
}
}
for (const auto& buffer : workspaces_) block_ptr->alloc_buffers.push_back(buffer);
for (const auto &buffer : workspaces_)
block_ptr->alloc_buffers.push_back(buffer);
workspaces_.clear();
block_ptr->annotations.erase(attr::kLayoutMap);
return block;
......@@ -113,18 +118,19 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
int CheckAndGetBufferRowSize(Buffer buffer) {
CHECK(buffer->shape.size() >= 2)
<< "The dimension of Buffer \"" << buffer->name << "\" with shape " << buffer->shape
<< " should be at least 2";
<< "The dimension of Buffer \"" << buffer->name << "\" with shape "
<< buffer->shape << " should be at least 2";
auto dim = buffer->shape.size();
auto buffer_row_size = buffer->shape[dim - 1].as<IntImmNode>()->value;
return buffer_row_size;
}
PrimExpr HandleAccessPtrAndOffset(PrimExpr access_ptr, Optional<PrimExpr> offset = NullOpt,
PrimExpr HandleAccessPtrAndOffset(PrimExpr access_ptr,
Optional<PrimExpr> offset = NullOpt,
DataType dtype = DataType::Int(32)) {
// The 2th arg of T.tvm_access_ptr call is offset, we set it to 0 and accumulate it to
// smem_offset
// The 2th arg of T.tvm_access_ptr call is offset, we set it to 0 and
// accumulate it to smem_offset
CHECK(access_ptr->IsInstance<CallNode>())
<< "Invalid access ptr for permuted layout: " << access_ptr;
auto access_ptr_call = Downcast<Call>(access_ptr);
......@@ -136,8 +142,10 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
Array<PrimExpr> shape = load->buffer->shape;
CHECK_EQ(indices.size(), shape.size())
<< "Indices size and shape size must match for general N-dimensional buffer "
<< "but got indices size: " << indices.size() << " and shape size: " << shape.size();
<< "Indices size and shape size must match for general N-dimensional "
"buffer "
<< "but got indices size: " << indices.size()
<< " and shape size: " << shape.size();
PrimExpr elem_offset = 0;
PrimExpr stride = 1;
......@@ -147,13 +155,16 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
stride *= shape[i];
}
PrimExpr smem_offset = elem_offset + (offset.defined() ? offset.value() : 0);
PrimExpr smem_offset =
elem_offset + (offset.defined() ? offset.value() : 0);
auto new_buffer = buffer_remap_[load->buffer];
auto buffer_map_iter = buffer_map_.find(Downcast<Var>(load->buffer->data));
auto buffer_map_iter =
buffer_map_.find(Downcast<Var>(load->buffer->data));
CHECK(buffer_map_iter != buffer_map_.end())
<< "The buffer corresponding to data Var " << access_ptr_call->args[0] << " is not found";
<< "The buffer corresponding to data Var " << access_ptr_call->args[0]
<< " is not found";
int buffer_row_size = CheckAndGetBufferRowSize(buffer_map_iter->second);
(void)buffer_row_size;
......@@ -163,11 +174,13 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
PrimExpr remaining_offset = smem_offset;
for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
multi_dim_indices.insert(multi_dim_indices.begin(), floormod(remaining_offset, shape[i]));
multi_dim_indices.insert(multi_dim_indices.begin(),
floormod(remaining_offset, shape[i]));
remaining_offset = floordiv(remaining_offset, shape[i]);
}
auto forward_indices = layout_map_[load->buffer]->Forward(multi_dim_indices);
auto forward_indices =
layout_map_[load->buffer]->Forward(multi_dim_indices);
PrimExpr new_offset = 0;
PrimExpr stride_offset = 1;
for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
......@@ -191,8 +204,9 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
return access_ptr_call;
}
PrimExpr VisitExpr_(const tir::CallNode* op) final {
if (!op->op.same_as(builtin::ptx_ldmatrix()) && !op->op.same_as(builtin::mma_store())) {
PrimExpr VisitExpr_(const tir::CallNode *op) final {
if (!op->op.same_as(builtin::ptx_ldmatrix()) &&
!op->op.same_as(builtin::mma_store())) {
return Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
} else {
is_ptx_ = true;
......@@ -212,15 +226,18 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
BufferLoad load = Downcast<BufferLoad>(address_of_call->args[0]);
if (buffer_remap_.count(load->buffer)) {
auto new_access_ptr = HandleAccessPtrAndOffset(access_ptr, smem_offset, call->dtype);
auto new_access_ptr =
HandleAccessPtrAndOffset(access_ptr, smem_offset, call->dtype);
auto new_call = call.CopyOnWrite();
new_call->args.Set(5, new_access_ptr);
new_call->args.Set(6, IntImm(smem_offset->dtype, 0));
}
} else if (call->op.same_as(builtin::mma_store())) {
// because we will directly store result to Buffer instead of calling mma_store now
// because we will directly store result to Buffer instead of calling
// mma_store now
auto access_ptr = call->args[2];
auto new_access_ptr = HandleAccessPtrAndOffset(access_ptr, NullOpt, call->dtype);
auto new_access_ptr =
HandleAccessPtrAndOffset(access_ptr, NullOpt, call->dtype);
auto new_call = call.CopyOnWrite();
new_call->args.Set(2, new_access_ptr);
} else {
......@@ -230,7 +247,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
return call;
}
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load = Downcast<BufferLoad>(IRMutatorWithAnalyzer::VisitExpr_(op));
if (is_ptx_) {
return load;
......@@ -243,7 +260,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
return load;
}
Stmt VisitStmt_(const BufferStoreNode* op) final {
Stmt VisitStmt_(const BufferStoreNode *op) final {
auto store = Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_(op));
if (buffer_remap_.count(store->buffer)) {
auto new_indices = layout_map_[store->buffer]->Forward(store->indices);
......@@ -253,36 +270,40 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
return store;
}
PrimExpr VisitExpr_(const VarNode* op) final {
PrimExpr VisitExpr_(const VarNode *op) final {
auto var = Downcast<Var>(IRMutatorWithAnalyzer::VisitExpr_(op));
if (buffer_data_to_buffer_.count(var)) {
auto buffer = buffer_data_to_buffer_[var];
if (buffer_remap_.count(buffer)) return buffer_remap_[buffer]->data;
if (buffer_remap_.count(buffer))
return buffer_remap_[buffer]->data;
}
return var;
}
Stmt VisitStmt_(const EvaluateNode* op) final {
const CallNode* call = op->value.as<CallNode>();
Stmt VisitStmt_(const EvaluateNode *op) final {
const CallNode *call = op->value.as<CallNode>();
// Do not analyze the call node to the global function.
if (call && call->op.as<GlobalVarNode>())
return Downcast<Evaluate>(IRMutatorWithAnalyzer::VisitStmt_(op));
auto tile_op = ParseOperator(GetRef<Stmt>(op), buffer_data_to_buffer_);
if (tile_op == nullptr) return IRMutatorWithAnalyzer::VisitStmt_(op);
if (tile_op == nullptr)
return IRMutatorWithAnalyzer::VisitStmt_(op);
AddWorkspaceCallback callback = [this](int num_elem, DataType dtype) {
auto workspace = decl_buffer({PrimExpr(num_elem)}, dtype, "workspace", "shared.dyn");
auto workspace =
decl_buffer({PrimExpr(num_elem)}, dtype, "workspace", "shared.dyn");
workspaces_.push_back(workspace);
return workspace.access_ptr(2); // write
return workspace.access_ptr(2); // write
};
auto lowered = tile_op->Lower(
LowerArgs{target_, thread_block_size_, thread_var_, callback, layout_map_, buffer_remap_},
analyzer_);
auto lowered =
tile_op->Lower(LowerArgs{target_, thread_block_size_, thread_var_,
callback, layout_map_, buffer_remap_},
analyzer_);
return IRMutatorWithAnalyzer::VisitStmt(lowered);
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
ICHECK_NE(iv->thread_tag.length(), 0U);
......@@ -321,7 +342,7 @@ tvm::transform::Pass LowerTileOp() {
}
TVM_REGISTER_GLOBAL("tl.transform.LowerTileOp").set_body_typed(LowerTileOp);
} // namespace transform
} // namespace transform
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
......@@ -38,22 +38,23 @@ using namespace tir;
enum class Role { kConsumer, kProducer, kBoth };
class WarpSpecializedRoleMarker_ : public StmtVisitor {
public:
public:
WarpSpecializedRoleMarker_(Map<Var, Buffer> buffer_data_to_buffer)
: buffer_data_to_buffer_(buffer_data_to_buffer) {}
Role GetRole(const StmtNode* stmt) const {
Role GetRole(const StmtNode *stmt) const {
auto it = map_.find(stmt);
ICHECK(it != map_.end());
return it->second;
}
Role GetRole(const Stmt& stmt) const { return GetRole(stmt.get()); }
Role GetRole(const Stmt &stmt) const { return GetRole(stmt.get()); }
void VisitStmt_(const EvaluateNode* op) final {
void VisitStmt_(const EvaluateNode *op) final {
Role role = Role::kConsumer;
if (auto call = op->value.as<CallNode>()) {
if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) {
if (call->op.same_as(TMALoadOp()) ||
call->op.same_as(TMALoadIm2ColOp())) {
role = Role::kProducer;
has_bulk_copy_ = true;
}
......@@ -61,8 +62,9 @@ class WarpSpecializedRoleMarker_ : public StmtVisitor {
SetRole(op, role);
}
void VisitStmt_(const BufferStoreNode* op) final {
bool is_shared_store = op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared";
void VisitStmt_(const BufferStoreNode *op) final {
bool is_shared_store =
op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared";
if (!is_shared_store) {
SetRole(op, Role::kConsumer);
return;
......@@ -80,11 +82,12 @@ class WarpSpecializedRoleMarker_ : public StmtVisitor {
break;
}
}
if (role == Role::kProducer) has_simt_copy_ = true;
if (role == Role::kProducer)
has_simt_copy_ = true;
SetRole(op, role);
}
void VisitStmt_(const SeqStmtNode* op) final {
void VisitStmt_(const SeqStmtNode *op) final {
StmtVisitor::VisitStmt_(op);
auto role = GetRole(op->seq[0]);
for (auto stmt : op->seq) {
......@@ -96,48 +99,48 @@ class WarpSpecializedRoleMarker_ : public StmtVisitor {
SetRole(op, role);
}
void VisitStmt_(const IfThenElseNode* op) final {
void VisitStmt_(const IfThenElseNode *op) final {
StmtVisitor::VisitStmt_(op);
auto role = GetRole(op->then_case);
if (op->else_case.defined()) {
auto role_else = GetRole(op->else_case.value());
if (role != role_else) role = Role::kBoth;
if (role != role_else)
role = Role::kBoth;
}
SetRole(op, role);
}
void VisitStmt_(const BlockRealizeNode* op) final {
void VisitStmt_(const BlockRealizeNode *op) final {
StmtVisitor::VisitStmt_(op);
SetRole(op, GetRole(op->block));
}
template <class NodeType>
void HandleBodyStmt(const NodeType* op) {
template <class NodeType> void HandleBodyStmt(const NodeType *op) {
StmtVisitor::VisitStmt_(op);
SetRole(op, GetRole(op->body));
}
void VisitStmt_(const ForNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const LetStmtNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const AttrStmtNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const AssertStmtNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const BlockNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const ForNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const LetStmtNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const AttrStmtNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const AssertStmtNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const BlockNode *op) final { HandleBodyStmt(op); }
bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; }
bool HasSimtCopy() { return has_simt_copy_; }
private:
void SetRole(const StmtNode* stmt, Role role) { map_[stmt] = role; }
private:
void SetRole(const StmtNode *stmt, Role role) { map_[stmt] = role; }
Map<Var, Buffer> buffer_data_to_buffer_;
std::unordered_map<const StmtNode*, Role> map_;
std::unordered_map<const StmtNode *, Role> map_;
bool has_simt_copy_ = false;
bool has_bulk_copy_ = false;
};
class MultiVersionBufferRewriter : public StmtExprMutator {
public:
static PrimFunc Substitute(PrimFunc& f) {
public:
static PrimFunc Substitute(PrimFunc &f) {
auto rewriter = MultiVersionBufferRewriter();
rewriter.buffer_lca_ = DetectBufferAccessLCA(f);
for (auto [buffer, _] : rewriter.buffer_lca_) {
......@@ -148,40 +151,45 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
return f;
}
private:
private:
MultiVersionBufferRewriter() = default;
Array<Buffer> GetVersionedBuffers(Array<Stmt> seq_stmt, Array<Buffer> scoped_buffers) {
Array<Buffer> GetVersionedBuffers(Array<Stmt> seq_stmt,
Array<Buffer> scoped_buffers) {
std::vector<Role> roles;
Array<Array<BufferRegion>> reads, writes;
auto marker = WarpSpecializedRoleMarker_(buffer_data_to_buffer_);
for (auto stmt : seq_stmt) {
marker(stmt);
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ stmt);
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
/*name_hint=*/"", /*body*/ stmt);
auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_);
reads.push_back(std::move(access[0]));
writes.push_back(std::move(access[1]));
roles.push_back(marker.GetRole(stmt));
}
std::unordered_set<const BufferNode*> consumer_used, producer_used;
std::unordered_set<const BufferNode *> consumer_used, producer_used;
for (size_t i = 0; i < seq_stmt.size(); i++) {
if (roles[i] == Role::kProducer) {
for (BufferRegion br : writes[i]) producer_used.insert(br->buffer.get());
for (BufferRegion br : writes[i])
producer_used.insert(br->buffer.get());
} else {
for (BufferRegion br : reads[i]) consumer_used.insert(br->buffer.get());
for (BufferRegion br : reads[i])
consumer_used.insert(br->buffer.get());
}
}
Array<Buffer> versioned_buffers;
for (Buffer buffer : scoped_buffers) {
if (consumer_used.count(buffer.get()) && producer_used.count(buffer.get())) {
if (consumer_used.count(buffer.get()) &&
producer_used.count(buffer.get())) {
versioned_buffers.push_back(buffer);
}
}
return versioned_buffers;
}
static Buffer RewriteAllocBuffer(const Buffer& buffer, int num_versions) {
static Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
if (new_buffer->strides.size()) {
......@@ -192,8 +200,9 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
return Buffer(new_buffer);
}
Stmt VisitStmt_(const BlockRealizeNode* op) final {
BlockRealize block_realize = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
Stmt VisitStmt_(const BlockRealizeNode *op) final {
BlockRealize block_realize =
Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
Block block = block_realize->block;
Array<Buffer> alloc_buffers;
for (auto buffer : block->alloc_buffers) {
......@@ -209,24 +218,27 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
return block_realize;
}
Stmt VisitStmt_(const ForNode* op) final {
Stmt VisitStmt_(const ForNode *op) final {
auto num_stages_anno = op->annotations.Get("num_stages");
if (!num_stages_anno.defined()) return StmtExprMutator::VisitStmt_(op);
if (!num_stages_anno.defined())
return StmtExprMutator::VisitStmt_(op);
ICHECK(num_stages_anno.as<IntImmNode>());
int num_stages = static_cast<int>(num_stages_anno.as<IntImmNode>()->value);
const SeqStmtNode* pipeline_body_seq = op->body.as<SeqStmtNode>();
CHECK(pipeline_body_seq)
<< "ValueError: The body of the software pipeline should be SeqStmt, got "
<< op->body->GetTypeKey();
const SeqStmtNode *pipeline_body_seq = op->body.as<SeqStmtNode>();
CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
"should be SeqStmt, got "
<< op->body->GetTypeKey();
Array<Buffer> scoped_buffers = {};
for (auto [buffer, stmt] : buffer_lca_) {
if (stmt.defined() && stmt.value().get() == op) scoped_buffers.push_back(buffer);
if (stmt.defined() && stmt.value().get() == op)
scoped_buffers.push_back(buffer);
}
Array<Buffer> versioned_buffers = GetVersionedBuffers(pipeline_body_seq->seq, scoped_buffers);
Array<Buffer> versioned_buffers =
GetVersionedBuffers(pipeline_body_seq->seq, scoped_buffers);
for (auto buffer : versioned_buffers) {
Var buffer_var = buffer->data;
......@@ -239,33 +251,33 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
return for_node;
}
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
auto it = buffer_remap_.find(load->buffer);
if (it == buffer_remap_.end()) {
return std::move(load);
}
const Buffer& new_buffer = (*it).second;
auto* n = load.CopyOnWrite();
const Buffer &new_buffer = (*it).second;
auto *n = load.CopyOnWrite();
n->buffer = new_buffer;
n->indices.insert(n->indices.begin(), version_index_);
return std::move(load);
}
Stmt VisitStmt_(const BufferStoreNode* op) final {
Stmt VisitStmt_(const BufferStoreNode *op) final {
BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
auto it = buffer_remap_.find(store->buffer);
if (it == buffer_remap_.end()) {
return std::move(store);
}
const Buffer& new_buffer = (*it).second;
auto* n = store.CopyOnWrite();
const Buffer &new_buffer = (*it).second;
auto *n = store.CopyOnWrite();
n->buffer = new_buffer;
n->indices.insert(n->indices.begin(), version_index_);
return std::move(store);
}
PrimExpr VisitExpr_(const CallNode* op) final {
PrimExpr VisitExpr_(const CallNode *op) final {
Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
if (call->op.same_as(builtin::tvm_access_ptr())) {
return RewriteBufferAccess(call, {1});
......@@ -273,20 +285,23 @@ class MultiVersionBufferRewriter : public StmtExprMutator {
return call;
}
PrimExpr RewriteBufferAccess(const Call& call, const std::vector<int> arg_indices) {
auto product = [](const Array<PrimExpr>& input) {
return foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
make_const(DataType::Int(32), 1), input);
PrimExpr RewriteBufferAccess(const Call &call,
const std::vector<int> arg_indices) {
auto product = [](const Array<PrimExpr> &input) {
return foldl(
[](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
make_const(DataType::Int(32), 1), input);
};
Array<PrimExpr> new_args = call->args;
for (int i : arg_indices) {
auto buffer_var = Downcast<Var>(call->args[i]);
if (!buffer_data_to_buffer_.count(buffer_var)) continue;
const Buffer& buffer = buffer_data_to_buffer_[buffer_var];
if (!buffer_data_to_buffer_.count(buffer_var))
continue;
const Buffer &buffer = buffer_data_to_buffer_[buffer_var];
auto it = buffer_remap_.find(buffer);
if (it != buffer_remap_.end()) {
const Buffer& new_buffer = (*it).second;
const PrimExpr& old_index = call->args[i + 1];
const Buffer &new_buffer = (*it).second;
const PrimExpr &old_index = call->args[i + 1];
PrimExpr offset;
if (new_buffer->strides.empty()) {
offset = product(buffer->shape);
......@@ -318,5 +333,5 @@ tvm::transform::Pass MultiVersionBuffer() {
TVM_REGISTER_GLOBAL("tl.transform.MultiVersionBuffer")
.set_body_typed(MultiVersionBuffer);
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
......@@ -56,22 +56,23 @@ bool MayConflict(Region region1, Region region2) {
return true;
}
} // namespace
} // namespace
class PipelinePlanner : public StmtExprMutator {
public:
static Stmt Substitute(const PrimFunc& f) {
public:
static Stmt Substitute(const PrimFunc &f) {
PipelinePlanner substituter;
for (const auto& [_, buffer] : f->buffer_map) {
for (const auto &[_, buffer] : f->buffer_map) {
substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
}
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target.defined()) << "Pipeline_Planning: Require the target attribute";
ICHECK(target.defined())
<< "Pipeline_Planning: Require the target attribute";
substituter.target_ = target.value();
return substituter.VisitStmt(f->body);
}
private:
private:
PipelinePlanner() = default;
struct PipelineStageInfo {
......@@ -83,8 +84,10 @@ class PipelinePlanner : public StmtExprMutator {
};
PipelineStageInfo MakePipelineStageInfo(Stmt stmt, int idx) {
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ stmt);
Array<Array<BufferRegion>> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
/*body*/ stmt);
Array<Array<BufferRegion>> access =
GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
PipelineStageInfo pinfo;
pinfo.reads = std::move(access[0]);
pinfo.writes = std::move(access[1]);
......@@ -93,22 +96,25 @@ class PipelinePlanner : public StmtExprMutator {
// a copy stage should have exactly one read and one write
if (pinfo.reads.size() == 1 && pinfo.writes.size() == 1) {
for (auto region : pinfo.reads)
if (region->buffer.scope() == "global") pinfo.copy_stage = true;
if (region->buffer.scope() == "global")
pinfo.copy_stage = true;
for (auto region : pinfo.writes)
if (region->buffer.scope() == "global") pinfo.copy_stage = true;
if (region->buffer.scope() == "global")
pinfo.copy_stage = true;
}
return std::move(pinfo);
}
Stmt VisitStmt_(const ForNode* loop) final {
Stmt VisitStmt_(const ForNode *loop) final {
auto num_stages_anno = loop->annotations.Get("num_stages");
if (!num_stages_anno.defined()) return StmtExprMutator::VisitStmt_(loop);
if (!num_stages_anno.defined())
return StmtExprMutator::VisitStmt_(loop);
int num_stages = num_stages_anno.as<IntImmNode>()->value;
Stmt pipeline_body{nullptr};
if (const auto* realize = loop->body.as<BlockRealizeNode>()) {
const auto& block = realize->block;
for (const auto& buffer : block->alloc_buffers) {
if (const auto *realize = loop->body.as<BlockRealizeNode>()) {
const auto &block = realize->block;
for (const auto &buffer : block->alloc_buffers) {
ICHECK(buffer->IsInstance<BufferNode>());
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
......@@ -116,10 +122,10 @@ class PipelinePlanner : public StmtExprMutator {
} else {
pipeline_body = loop->body;
}
const SeqStmtNode* pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
CHECK(pipeline_body_seq)
<< "ValueError: The body of the software pipeline should be SeqStmt, got "
<< loop->body->GetTypeKey();
const SeqStmtNode *pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
"should be SeqStmt, got "
<< loop->body->GetTypeKey();
CHECK(num_stages >= 1);
CHECK(loop->kind == ForKind::kSerial);
......@@ -130,21 +136,28 @@ class PipelinePlanner : public StmtExprMutator {
}
// analyze the use-def chain
for (auto& pinfo : pipeline_stage_infos) {
for (int i = pinfo.original_order + 1; i < static_cast<int>(pipeline_body_seq->size()); i++) {
if (!pinfo.copy_stage) continue;
for (const BufferRegion& read : pipeline_stage_infos[i].reads) {
if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(), [&](const BufferRegion& r) {
return r->buffer == read->buffer && MayConflict(r->region, read->region);
}) != pinfo.writes.end()) {
for (auto &pinfo : pipeline_stage_infos) {
for (int i = pinfo.original_order + 1;
i < static_cast<int>(pipeline_body_seq->size()); i++) {
if (!pinfo.copy_stage)
continue;
for (const BufferRegion &read : pipeline_stage_infos[i].reads) {
if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(),
[&](const BufferRegion &r) {
return r->buffer == read->buffer &&
MayConflict(r->region, read->region);
}) != pinfo.writes.end()) {
pinfo.last_use_stage = std::max(pinfo.last_use_stage, i);
}
}
for (const BufferRegion& write : pipeline_stage_infos[i].writes) {
if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(), [&](const BufferRegion& r) {
return r->buffer == write->buffer && MayConflict(r->region, write->region);
}) != pinfo.writes.end()) {
CHECK(false) << "Can't handle multiple write on overlap buffer region in the pipeline "
for (const BufferRegion &write : pipeline_stage_infos[i].writes) {
if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(),
[&](const BufferRegion &r) {
return r->buffer == write->buffer &&
MayConflict(r->region, write->region);
}) != pinfo.writes.end()) {
CHECK(false) << "Can't handle multiple write on overlap buffer "
"region in the pipeline "
"planning pass: "
<< pipeline_body_seq->seq[pinfo.original_order];
}
......@@ -154,28 +167,32 @@ class PipelinePlanner : public StmtExprMutator {
// Making stages and orders
int order_idx = 0;
for (auto& pinfo : pipeline_stage_infos) {
if (pinfo.copy_stage && pinfo.last_use_stage != -1) continue;
for (auto &pinfo : pipeline_stage_infos) {
if (pinfo.copy_stage && pinfo.last_use_stage != -1)
continue;
pinfo.order = order_idx++;
pinfo.stage = num_stages;
for (auto& pinfo_1 : pipeline_stage_infos) {
if (pinfo_1.copy_stage && pinfo_1.last_use_stage == pinfo.original_order) {
for (auto &pinfo_1 : pipeline_stage_infos) {
if (pinfo_1.copy_stage &&
pinfo_1.last_use_stage == pinfo.original_order) {
pinfo_1.order = order_idx++;
pinfo_1.stage = 0;
}
}
}
ICHECK(size_t(order_idx) == pipeline_stage_infos.size()) <<
"The number of stages should be equal to the number of pipeline stages. " <<
"Got " << order_idx << " stages and " << pipeline_stage_infos.size() << " pipeline stages.";
ICHECK(size_t(order_idx) == pipeline_stage_infos.size())
<< "The number of stages should be equal to the number of pipeline "
"stages. "
<< "Got " << order_idx << " stages and " << pipeline_stage_infos.size()
<< " pipeline stages.";
// if all the copy is at the end of the order, we can move these copy to the beginning of the
// order and shrink the stage offset by 1.
// if all the copy is at the end of the order, we can move these copy to the
// beginning of the order and shrink the stage offset by 1.
int copy_stage_at_end = [&]() {
int copy_stage_cnt = 0;
int copy_order_min = pipeline_stage_infos.size();
int non_copy_order_max = 0;
for (auto& pinfo : pipeline_stage_infos) {
for (auto &pinfo : pipeline_stage_infos) {
if (pinfo.copy_stage) {
copy_stage_cnt++;
copy_order_min = std::min(copy_order_min, pinfo.order);
......@@ -183,19 +200,22 @@ class PipelinePlanner : public StmtExprMutator {
non_copy_order_max = std::max(non_copy_order_max, pinfo.order);
}
}
if (copy_order_min > non_copy_order_max) return copy_stage_cnt;
if (copy_order_min > non_copy_order_max)
return copy_stage_cnt;
return -1;
}();
if (copy_stage_at_end > 0 && num_stages >= 2) {
for (auto& pinfo : pipeline_stage_infos) { // move copy to the beginning
pinfo.order = (pinfo.order + copy_stage_at_end) % pipeline_stage_infos.size();
if (!pinfo.copy_stage) pinfo.stage--;
for (auto &pinfo : pipeline_stage_infos) { // move copy to the beginning
pinfo.order =
(pinfo.order + copy_stage_at_end) % pipeline_stage_infos.size();
if (!pinfo.copy_stage)
pinfo.stage--;
}
}
// Finally, make the pipeline annotation
Map<String, ObjectRef> annotations;
for (const auto& [key, value] : loop->annotations) {
for (const auto &[key, value] : loop->annotations) {
if (key != "num_stages") {
annotations.Set(key, value);
}
......@@ -204,7 +224,7 @@ class PipelinePlanner : public StmtExprMutator {
std::vector<Integer> orders, stages;
orders.reserve(pipeline_stage_infos.size());
stages.reserve(pipeline_stage_infos.size());
for (auto& pinfo : pipeline_stage_infos) {
for (auto &pinfo : pipeline_stage_infos) {
orders.push_back(pinfo.order);
stages.push_back(pinfo.stage);
}
......@@ -212,18 +232,19 @@ class PipelinePlanner : public StmtExprMutator {
annotations.Set(tir::attr::software_pipeline_stage, Array<Integer>(stages));
annotations.Set(tir::attr::software_pipeline_order, Array<Integer>(orders));
if (TargetHasAsyncCopy(target_))
annotations.Set(tir::attr::software_pipeline_async_stages, Array<Integer>{0});
annotations.Set(tir::attr::software_pipeline_async_stages,
Array<Integer>{0});
return For(loop->loop_var, loop->min, loop->extent, loop->kind, loop->body,
loop->thread_binding, annotations);
}
Stmt VisitStmt_(const BlockNode* op) final {
for (const auto& buffer : op->alloc_buffers) {
Stmt VisitStmt_(const BlockNode *op) final {
for (const auto &buffer : op->alloc_buffers) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
for (const auto& buffer : op->alloc_buffers) {
for (const auto &buffer : op->alloc_buffers) {
buffer_data_to_buffer_.erase(buffer->data);
}
return std::move(block);
......@@ -236,14 +257,15 @@ class PipelinePlanner : public StmtExprMutator {
tvm::transform::Pass PipelinePlanning() {
using namespace tir::transform;
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
PrimFuncNode* fptr = f.CopyOnWrite();
PrimFuncNode *fptr = f.CopyOnWrite();
fptr->body = PipelinePlanner::Substitute(f);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.PipelinePlanning", {});
}
TVM_REGISTER_GLOBAL("tl.transform.PipelinePlanning").set_body_typed(PipelinePlanning);
TVM_REGISTER_GLOBAL("tl.transform.PipelinePlanning")
.set_body_typed(PipelinePlanning);
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
......@@ -6,15 +6,15 @@
* \brief Remove useless parameters of TL PrimFunc.
*/
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/utils.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>
#include "arith/ir_mutator_with_analyzer.h"
#include "tir/analysis/var_use_def_analysis.h"
#include "tir/analysis/control_flow_graph.h"
#include "tir/analysis/var_use_def_analysis.h"
namespace tvm {
namespace tl {
......@@ -31,19 +31,19 @@ struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {
TVM_DECLARE_ATTRS(SimplifyConfigNode, "tl.transform.SimplifyConfig") {
TVM_ATTR_FIELD(transitively_prove_inequalities)
.describe(
"If true, simplify conditionals with transitive combinations of scoped constraints")
.describe("If true, simplify conditionals with transitive combinations "
"of scoped constraints")
.set_default(false);
TVM_ATTR_FIELD(propagate_knowns_to_prove_conditional)
.describe(
"If true, known buffer values are propagated and used to statically prove conditionals")
.describe("If true, known buffer values are propagated and used to "
"statically prove conditionals")
.set_default(false);
TVM_ATTR_FIELD(propagate_knowns_to_simplify_expressions)
.describe(
"If true, known buffer values are propagated and used to replace BufferLoad wherever "
"possible")
.describe("If true, known buffer values are propagated and used to "
"replace BufferLoad wherever "
"possible")
.set_default(false);
TVM_ATTR_FIELD(convert_boolean_to_and_of_ors)
......@@ -51,102 +51,103 @@ struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {
.set_default(false);
TVM_ATTR_FIELD(apply_constraints_to_boolean_branches)
.describe(
"If true, simplify each branch of AND/OR "
"under a constraints provided by the other branch")
.describe("If true, simplify each branch of AND/OR "
"under a constraints provided by the other branch")
.set_default(false);
}
RewriteSimplifier::Extension GetEnabledExtensions() const {
RewriteSimplifier::Extension GetEnabledExtensions() const {
RewriteSimplifier::Extension flags = RewriteSimplifier::kNone;
if (transitively_prove_inequalities) {
flags =
RewriteSimplifier::Extension(flags | RewriteSimplifier::kTransitivelyProveInequalities);
flags = RewriteSimplifier::Extension(
flags | RewriteSimplifier::kTransitivelyProveInequalities);
}
if (convert_boolean_to_and_of_ors) {
flags = RewriteSimplifier::Extension(flags | RewriteSimplifier::kConvertBooleanToAndOfOrs);
flags = RewriteSimplifier::Extension(
flags | RewriteSimplifier::kConvertBooleanToAndOfOrs);
}
if (apply_constraints_to_boolean_branches) {
flags = RewriteSimplifier::Extension(flags |
RewriteSimplifier::kApplyConstraintsToBooleanBranches);
flags = RewriteSimplifier::Extension(
flags | RewriteSimplifier::kApplyConstraintsToBooleanBranches);
}
return flags;
}
};
std::unordered_set<const BufferNode*> CollectUsedBuffers(const PrimFunc& func) {
std::unordered_set<const BufferNode *>
CollectUsedBuffers(const PrimFunc &func) {
struct Visitor : StmtExprVisitor {
using StmtExprVisitor::VisitExpr_;
using StmtExprVisitor::VisitStmt_;
Visitor(PrimFunc func) : func(func) {}
void VisitExpr_(const CallNode* op) override {
for (const auto& arg: op->args) {
for (const auto& it: func->buffer_map) {
if (Downcast<PrimExpr>(it.second.get()->data).same_as(arg)) {
used_in_buffer_def_.insert(it.second.get());
}
}
void VisitExpr_(const CallNode *op) override {
for (const auto &arg : op->args) {
for (const auto &it : func->buffer_map) {
if (Downcast<PrimExpr>(it.second.get()->data).same_as(arg)) {
used_in_buffer_def_.insert(it.second.get());
}
}
StmtExprVisitor::VisitExpr_(op);
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitExpr_(const BufferLoadNode* op) override {
void VisitExpr_(const BufferLoadNode *op) override {
VisitBuffer(op->buffer);
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode* op) override {
void VisitStmt_(const BufferStoreNode *op) override {
VisitBuffer(op->buffer);
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const BlockNode* op) override {
for (const auto& buffer: op->alloc_buffers) {
for (const auto& it: func->buffer_map) {
if (it.second.get()->data.same_as(buffer.get()->data)) {
used_in_buffer_def_.insert(it.second.get());
}
}
void VisitStmt_(const BlockNode *op) override {
for (const auto &buffer : op->alloc_buffers) {
for (const auto &it : func->buffer_map) {
if (it.second.get()->data.same_as(buffer.get()->data)) {
used_in_buffer_def_.insert(it.second.get());
}
}
for (const auto& buffer: op->reads) {
for (const auto& it: func->buffer_map) {
if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
used_in_buffer_def_.insert(it.second.get());
}
}
}
for (const auto &buffer : op->reads) {
for (const auto &it : func->buffer_map) {
if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
used_in_buffer_def_.insert(it.second.get());
}
}
for (const auto& buffer: op->writes) {
for (const auto& it: func->buffer_map) {
if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
used_in_buffer_def_.insert(it.second.get());
}
}
}
for (const auto &buffer : op->writes) {
for (const auto &it : func->buffer_map) {
if (it.second.get()->data.same_as(buffer->buffer.get()->data)) {
used_in_buffer_def_.insert(it.second.get());
}
}
StmtExprVisitor::VisitStmt_(op);
}
StmtExprVisitor::VisitStmt_(op);
}
void VisitBuffer(const Buffer& buf) {
void VisitBuffer(const Buffer &buf) {
// Collect buffers that should remain defined
VarUseDefAnalyzer usage(Array<Var>{});
usage(buf->data);
for (const auto& dim : buf->shape) {
for (const auto &dim : buf->shape) {
usage(dim);
}
for (const auto& dim : buf->strides) {
for (const auto &dim : buf->strides) {
usage(dim);
}
usage(buf->elem_offset);
for (const auto& buffer : usage.buffer_use_count_) {
for (const auto &buffer : usage.buffer_use_count_) {
if (buffer.second >= 1) {
used_in_buffer_def_.insert(buffer.first);
used_in_buffer_def_.insert(buffer.first);
}
}
for (const auto& buffer : usage.undefined_buffers_) {
for (const auto &buffer : usage.undefined_buffers_) {
used_in_buffer_def_.insert(buffer.get());
}
}
PrimFunc func;
std::unordered_set<const BufferNode*> used_in_buffer_def_;
std::unordered_set<const BufferNode *> used_in_buffer_def_;
};
Visitor visitor(func);
......@@ -154,41 +155,42 @@ std::unordered_set<const BufferNode*> CollectUsedBuffers(const PrimFunc& func) {
return visitor.used_in_buffer_def_;
}
/* \brief Utility function to collect vars that should be retained. Used in LetStmt only.
*/
std::unordered_set<const VarNode*> CollectVarsUsedInBufferDefinition(const Stmt& stmt) {
/* \brief Utility function to collect vars that should be retained. Used in
 * LetStmt only.
 */
std::unordered_set<const VarNode *>
CollectVarsUsedInBufferDefinition(const Stmt &stmt) {
struct Visitor : StmtExprVisitor {
using StmtExprVisitor::VisitExpr_;
using StmtExprVisitor::VisitStmt_;
void VisitExpr_(const BufferLoadNode* op) override {
void VisitExpr_(const BufferLoadNode *op) override {
VisitBuffer(op->buffer);
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode* op) override {
void VisitStmt_(const BufferStoreNode *op) override {
VisitBuffer(op->buffer);
StmtExprVisitor::VisitStmt_(op);
}
void VisitBuffer(const Buffer& buf) {
void VisitBuffer(const Buffer &buf) {
// Collect variables that should remain defined
VarUseDefAnalyzer usage(Array<Var>{});
usage(buf->data);
for (const auto& dim : buf->shape) {
for (const auto &dim : buf->shape) {
usage(dim);
}
for (const auto& dim : buf->strides) {
for (const auto &dim : buf->strides) {
usage(dim);
}
usage(buf->elem_offset);
// Track for use in LetStmtNode mutator
for (const auto& var : usage.undefined_) {
for (const auto &var : usage.undefined_) {
used_in_buffer_def_.insert(var.get());
}
}
std::unordered_set<const VarNode*> used_in_buffer_def_;
std::unordered_set<const VarNode *> used_in_buffer_def_;
};
Visitor visitor;
......@@ -197,20 +199,21 @@ std::unordered_set<const VarNode*> CollectVarsUsedInBufferDefinition(const Stmt&
}
class SimplifyConfig : public Attrs {
public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs, SimplifyConfigNode);
public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs,
SimplifyConfigNode);
};
TVM_REGISTER_NODE_TYPE(SimplifyConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig);
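// Illustrative sketch (assumption: a default-constructed config): the
// registered config is consumed by handing its extension flags to an
// arith::Analyzer, mirroring what StmtSimplifier::Apply does below.
//
//   arith::Analyzer analyzer;
//   auto cfg = AttrsWithDefaultValues<SimplifyConfig>();
//   analyzer.rewrite_simplify.SetEnabledExtensions(cfg->GetEnabledExtensions());
//
// With the defaults every extension flag is off, so the analyzer behaves like
// the stock rewrite simplifier.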
class StmtSimplifier : public IRMutatorWithAnalyzer {
public:
static PrimFunc Apply(PrimFunc func, Analyzer* analyzer,
public:
static PrimFunc Apply(PrimFunc func, Analyzer *analyzer,
Optional<SimplifyConfig> config_opt = NullOpt) {
auto config = config_opt.value_or(AttrsWithDefaultValues<SimplifyConfig>());
analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions());
analyzer->rewrite_simplify.SetEnabledExtensions(
config->GetEnabledExtensions());
std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
if (config->propagate_knowns_to_prove_conditional ||
......@@ -218,7 +221,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
touch_pattern = ControlFlowGraph(func->body);
}
std::unordered_set<const VarNode*> used_in_buffer_def =
std::unordered_set<const VarNode *> used_in_buffer_def =
CollectVarsUsedInBufferDefinition(func->body);
StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern),
std::move(used_in_buffer_def));
......@@ -232,41 +235,44 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
Array<Var> new_params;
Map<Var, Buffer> new_buffer_map;
// Check whether each buffer is used
for (const auto& var: func->params) {
if (func->buffer_map.find(var) != func->buffer_map.end()) {
if (simplifier.used_buffers_.find(func->buffer_map[var].get()) != simplifier.used_buffers_.end()) {
new_params.push_back(var);
new_buffer_map.Set(var, func->buffer_map[var]);
} else {
param_updated = true;
}
for (const auto &var : func->params) {
if (func->buffer_map.find(var) != func->buffer_map.end()) {
if (simplifier.used_buffers_.find(func->buffer_map[var].get()) !=
simplifier.used_buffers_.end()) {
new_params.push_back(var);
new_buffer_map.Set(var, func->buffer_map[var]);
} else {
param_updated = true;
}
}
}
// return func;
if (param_updated) {
return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type, new_buffer_map, func->attrs, func->span);
return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type,
new_buffer_map, func->attrs, func->span);
} else {
return func;
return func;
}
}
private:
explicit StmtSimplifier(Analyzer* analyzer, SimplifyConfig config,
std::optional<ControlFlowGraph> touch_pattern,
std::unordered_set<const VarNode*> used_in_buffer_def)
: IRMutatorWithAnalyzer(analyzer),
config_(config),
touch_pattern_(touch_pattern),
used_in_buffer_def_(used_in_buffer_def) {}
private:
explicit StmtSimplifier(
Analyzer *analyzer, SimplifyConfig config,
std::optional<ControlFlowGraph> touch_pattern,
std::unordered_set<const VarNode *> used_in_buffer_def)
: IRMutatorWithAnalyzer(analyzer), config_(config),
touch_pattern_(touch_pattern), used_in_buffer_def_(used_in_buffer_def) {
}
using Parent = IRMutatorWithAnalyzer;
using Parent::VisitExpr_;
using Parent::VisitStmt;
using Parent::VisitStmt_;
PrimExpr VisitExpr(const PrimExpr& expr) final {
PrimExpr VisitExpr(const PrimExpr &expr) final {
if (config_->propagate_knowns_to_simplify_expressions) {
return touch_pattern_->SimplifyInContext(expr, current_stmt_.value(), analyzer_);
return touch_pattern_->SimplifyInContext(expr, current_stmt_.value(),
analyzer_);
} else {
return analyzer_->Simplify(expr);
}
......@@ -274,7 +280,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); }
Stmt VisitStmt(const Stmt& stmt) override {
Stmt VisitStmt(const Stmt &stmt) override {
Optional<Stmt> cache = this->current_stmt_;
this->current_stmt_ = stmt;
Stmt output = Parent::VisitStmt(stmt);
......@@ -282,23 +288,28 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
return output;
}
Stmt VisitStmt_(const ForNode* op) final {
Stmt VisitStmt_(const ForNode *op) final {
analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
With<ConstraintContext> ctx2(analyzer_, op->loop_var < op->min + op->extent);
With<ConstraintContext> ctx2(analyzer_,
op->loop_var < op->min + op->extent);
return Parent::VisitStmt_(op);
}
bool CanInlineLetStmt(const LetStmtNode* op) {
if (is_const_number(op->value)) return true;
if (op->value.as<VarNode>()) return true;
bool CanInlineLetStmt(const LetStmtNode *op) {
if (is_const_number(op->value))
return true;
if (op->value.as<VarNode>())
return true;
// This won't face the deep expression explosion problem as in Let expressions.
// Attempt to inline as much as possible if the value is an integer type (it can be an index).
if (!op->value.dtype().is_int()) return false;
// Attempt to inline as much as possible if the value is an integer type (it
// can be an index).
if (!op->value.dtype().is_int())
return false;
return SideEffect(op->value) <= CallEffectKind::kPure;
}
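// Illustrative examples of the rule above (a sketch; x, y, i, f are
// hypothetical bindings):
//   LetStmt(x, 4, body)          -> inlined (constant number)
//   LetStmt(x, y, body)          -> inlined (plain variable)
//   LetStmt(x, i * 8 + 1, body)  -> inlined (pure integer/index expression)
//   LetStmt(x, exp(f), body)     -> kept    (non-integer value)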
Stmt VisitStmt_(const LetStmtNode* op) override {
Stmt VisitStmt_(const LetStmtNode *op) override {
PrimExpr value = this->VisitExpr(op->value);
bool can_inline = CanInlineLetStmt(op);
if (can_inline) {
......@@ -339,7 +350,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
}
}
Stmt VisitStmt_(const IfThenElseNode* op) override {
Stmt VisitStmt_(const IfThenElseNode *op) override {
if (Optional<Bool> cond = ProveCondition(op->condition)) {
if (cond.value()->value) {
return this->VisitStmt(op->then_case);
......@@ -353,7 +364,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
}
}
PrimExpr VisitExpr_(const CallNode* op) override {
PrimExpr VisitExpr_(const CallNode *op) override {
if (op->op.same_as(builtin::if_then_else())) {
if (Optional<Bool> cond = ProveCondition(op->args[0])) {
if (cond.value()->value) {
......@@ -366,26 +377,27 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
return Parent::VisitExpr_(op);
}
PrimExpr VisitExpr_(const VarNode* op) override {
PrimExpr VisitExpr_(const VarNode *op) override {
used_vars_.insert(op);
return Parent::VisitExpr_(op);
}
PrimExpr VisitExpr_(const BufferLoadNode* op) override {
PrimExpr VisitExpr_(const BufferLoadNode *op) override {
auto buffer = op->buffer.get();
if (used_buffers_.find(buffer) == used_buffers_.end()) {
used_buffers_.insert(buffer);
used_buffers_.insert(buffer);
}
return Parent::VisitExpr_(op);
}
// eliminate useless stores
Stmt VisitStmt_(const BufferStoreNode* op) override {
Stmt VisitStmt_(const BufferStoreNode *op) override {
BufferStore store = Downcast<BufferStore>(Parent::VisitStmt_(op));
if (const BufferLoadNode* load = store->value.as<BufferLoadNode>()) {
if (const BufferLoadNode *load = store->value.as<BufferLoadNode>()) {
if (load->buffer->data.same_as(store->buffer->data) &&
ArrayDeepEqual(load->indices, store->indices) &&
tir::ExprDeepEqual()(load->buffer->elem_offset, store->buffer->elem_offset) &&
tir::ExprDeepEqual()(load->buffer->elem_offset,
store->buffer->elem_offset) &&
ArrayDeepEqual(load->buffer->shape, store->buffer->shape) &&
ArrayDeepEqual(load->buffer->strides, store->buffer->strides)) {
return Evaluate(0);
......@@ -393,13 +405,13 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
}
auto buffer = op->buffer.get();
if (used_buffers_.find(buffer) == used_buffers_.end()) {
used_buffers_.insert(buffer);
used_buffers_.insert(buffer);
}
return std::move(store);
}
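// Sketch of the dead-store case eliminated above: a self-copy such as
//   B[i] = B[i];
// where the data var, indices, elem_offset, shape and strides all match, is
// reduced to a no-op Evaluate(0) that later passes can drop.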
private:
bool ArrayDeepEqual(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
private:
bool ArrayDeepEqual(const Array<PrimExpr> &lhs, const Array<PrimExpr> &rhs) {
if (lhs.size() != rhs.size()) {
return false;
}
......@@ -420,11 +432,12 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
condition = Substitute(condition, non_inlined_bindings_);
if (config_->propagate_knowns_to_prove_conditional) {
ICHECK(touch_pattern_.has_value());
condition = touch_pattern_->SimplifyInContext(condition, current_stmt_.value(), analyzer_);
condition = touch_pattern_->SimplifyInContext(
condition, current_stmt_.value(), analyzer_);
} else {
condition = analyzer_->Simplify(condition);
}
if (const int64_t* as_int = as_const_int(condition)) {
if (const int64_t *as_int = as_const_int(condition)) {
return Bool(*as_int);
} else {
return NullOpt;
......@@ -436,21 +449,20 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
Map<Var, PrimExpr> non_inlined_bindings_;
Optional<Stmt> current_stmt_{NullOpt};
std::unordered_set<const VarNode*> used_in_buffer_def_;
std::unordered_set<const VarNode*> used_vars_;
std::unordered_set<const BufferNode*> used_buffers_;
std::unordered_set<const VarNode *> used_in_buffer_def_;
std::unordered_set<const VarNode *> used_vars_;
std::unordered_set<const BufferNode *> used_buffers_;
};
using namespace tir::transform;
tvm::transform::Pass Simplify() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
arith::Analyzer analyzer;
auto cfg = ctx->GetConfig<SimplifyConfig>("tl.Simplify");
return StmtSimplifier::Apply(f, &analyzer, cfg);
};
return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {});
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
arith::Analyzer analyzer;
auto cfg = ctx->GetConfig<SimplifyConfig>("tl.Simplify");
return StmtSimplifier::Apply(f, &analyzer, cfg);
};
return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {});
}
TVM_REGISTER_GLOBAL("tl.transform.Simplify").set_body_typed(Simplify);
......
......@@ -25,26 +25,28 @@ namespace tl {
using namespace tir;
class ThreadPartialSyncPlanner : public StorageAccessVisitor {
public:
explicit ThreadPartialSyncPlanner(StorageScope sync_scope) : sync_scope_(sync_scope) {}
public:
explicit ThreadPartialSyncPlanner(StorageScope sync_scope)
: sync_scope_(sync_scope) {}
// The syncs inserted before each statement
std::unordered_set<const Object*> syncs_inserted_;
std::unordered_map<const Object*, int> partial_syncs_inserted_;
std::unordered_set<const Object *> syncs_inserted_;
std::unordered_map<const Object *, int> partial_syncs_inserted_;
protected:
bool Enabled(const VarNode* buf, const StorageScope& scope) const final {
protected:
bool Enabled(const VarNode *buf, const StorageScope &scope) const final {
return in_device_env() && scope == sync_scope_;
}
// Plan the sync
std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq, const ForNode* loop) final {
std::vector<AccessEntry> Summarize(std::vector<StmtEntry> seq,
const ForNode *loop) final {
// Redirect all "shared.dyn" buffer access to the same buffer var
// so that the accesses can be planned together.
Var shared_dyn_buf;
for (StmtEntry& entry : seq) {
for (AccessEntry& access : entry.access) {
if (access.scope.rank == StorageRank::kShared && access.scope.tag == ".dyn" &&
access.buffer.defined()) {
for (StmtEntry &entry : seq) {
for (AccessEntry &access : entry.access) {
if (access.scope.rank == StorageRank::kShared &&
access.scope.tag == ".dyn" && access.buffer.defined()) {
if (!shared_dyn_buf.defined()) {
shared_dyn_buf = access.buffer;
} else {
......@@ -60,7 +62,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
// If it is a loop, rotate two times to consider the effect of the loop.
// Simulation-based approach to find dependencies.
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
const StmtEntry &s = seq[i];
// check if sync before statement is needed.
bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0);
// Apply the syncs added already.
......@@ -68,7 +70,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
reads.clear();
writes.clear();
}
for (const AccessEntry& acc : s.access) {
for (const AccessEntry &acc : s.access) {
if (acc.type == kRead) {
if (FindConflict(writes, acc, false)) {
sync_before_stmt = true;
......@@ -90,7 +92,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
writes.clear();
}
// Add the read/write of current statement
for (const AccessEntry& acc : s.access) {
for (const AccessEntry &acc : s.access) {
if (acc.type == kRead) {
reads.push_back(acc);
} else if (acc.type == kWrite) {
......@@ -106,11 +108,13 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
}
if (loop != nullptr) {
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry& s = seq[i];
if (syncs_inserted_.count(s.stmt) != 0) break;
if (reads.empty() && writes.empty()) break;
const StmtEntry &s = seq[i];
if (syncs_inserted_.count(s.stmt) != 0)
break;
if (reads.empty() && writes.empty())
break;
bool sync_before_stmt = false;
for (const AccessEntry& acc : s.access) {
for (const AccessEntry &acc : s.access) {
if (acc.type == kRead) {
if (FindConflict(writes, acc, true)) {
sync_before_stmt = true;
......@@ -141,7 +145,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
esync.type = kSync;
esync.scope = sync_scope_;
for (const StmtEntry& s : seq) {
for (const StmtEntry &s : seq) {
if (syncs_inserted_.count(s.stmt)) {
if (sync_count != 0) {
tail.clear();
......@@ -150,7 +154,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
}
++sync_count;
}
for (const AccessEntry& acc : s.access) {
for (const AccessEntry &acc : s.access) {
if (acc.type == kSync) {
if (sync_count != 0) {
tail.clear();
......@@ -170,18 +174,18 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
head.insert(head.end(), tail.begin(), tail.end());
if (loop != nullptr) {
// clear double buffer flag after a loop is finished.
for (AccessEntry& e : head) {
for (AccessEntry &e : head) {
e.double_buffer_write = false;
}
}
return head;
}
private:
private:
// find conflicting entry in vec.
bool FindConflict(const std::vector<AccessEntry>& prev, const AccessEntry& curr,
bool loop_carry) {
for (const AccessEntry& x : prev) {
bool FindConflict(const std::vector<AccessEntry> &prev,
const AccessEntry &curr, bool loop_carry) {
for (const AccessEntry &x : prev) {
if (FindConflict(x, curr, loop_carry)) {
return true;
}
......@@ -189,7 +193,8 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
return false;
}
bool FindConflict(const AccessEntry& prev, const AccessEntry& curr, bool loop_carry) {
bool FindConflict(const AccessEntry &prev, const AccessEntry &curr,
bool loop_carry) {
// Access to different buffers does not conflict.
if (!prev.buffer.same_as(curr.buffer)) {
return false;
......@@ -202,21 +207,21 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
// Even if the accesses have the same index, those indices need to
// depend on the innermost thread id to avoid a race condition.
bool depends_on_thread_index = true;
const VarNode* thread_index_var = nullptr;
const VarNode *thread_index_var = nullptr;
if (!curr.threads.empty()) {
thread_index_var = curr.threads.back()->var.get();
}
for (size_t i = 0; i < prev.touched.size(); i++) {
const auto& prev_intset = prev.touched[i];
const auto& curr_intset = curr.touched[i];
const auto &prev_intset = prev.touched[i];
const auto &curr_intset = curr.touched[i];
if (prev_intset.IsSinglePoint() && curr_intset.IsSinglePoint()) {
PrimExpr prev_index = prev_intset.PointValue();
PrimExpr curr_index = curr_intset.PointValue();
has_same_index = ExprDeepEqual()(prev_index, curr_index);
if (thread_index_var != nullptr) {
auto f_uses_thread_index = [=](const tvm::tir::VarNode* parameter) {
auto f_uses_thread_index = [=](const tvm::tir::VarNode *parameter) {
return parameter == thread_index_var;
};
depends_on_thread_index = depends_on_thread_index &&
......@@ -246,7 +251,7 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
return true;
}
void VisitStmt_(const AttrStmtNode* op) final {
void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == "kWarpSpecializationScope") {
IfThenElse body = Downcast<IfThenElse>(op->body);
auto partitions = Downcast<Array<IntImm>>(op->node);
......@@ -273,27 +278,31 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
}
}
void insert_syncs(const Object* obj) {
// ICHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition";
if (syncs_inserted_.count(obj)) return;
void insert_syncs(const Object *obj) {
// ICHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside
// condition";
if (syncs_inserted_.count(obj))
return;
if (num_partial_threads_.defined()) {
syncs_inserted_.insert(obj);
partial_syncs_inserted_[obj] = static_cast<int>(num_partial_threads_.value()->value);
partial_syncs_inserted_[obj] =
static_cast<int>(num_partial_threads_.value()->value);
} else {
syncs_inserted_.insert(obj);
}
}
private:
private:
Optional<IntImm> num_partial_threads_;
// synchronization scope
StorageScope sync_scope_;
};
// There are cases where a necessary syncthreads is not inserted by ThreadPartialSyncInserter.
// For example, a syncthreads is needed after async_wait_queue in the second loop below,
// but since ThreadPartialSyncInserter is not aware of the asynchronous semantics, it cannot tell
// that the syncthreads is needed there.
// There are cases where a necessary syncthreads is not inserted by
// ThreadPartialSyncInserter. For example, a syncthreads is needed after
// async_wait_queue in the second loop below, but since
// ThreadPartialSyncInserter is not aware of the asynchronous semantics, it
// cannot tell that the syncthreads is needed there.
//
// // Pipeline prologue
// for i in range(125):
......@@ -307,21 +316,23 @@ class ThreadPartialSyncPlanner : public StorageAccessVisitor {
// async_wait_queue(0, 2 - i):
// local[...] = shared[(i + 125) % 4]
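//
// In other words, the desired lowering of the second loop would carry an
// explicit barrier between the wait and the read (a sketch, not emitted here):
//
// for i in range(3):
//   async_wait_queue(0, 2 - i):
//     syncthreads() // or a partial sync covering the consumer threads
//     local[...] = shared[(i + 125) % 4]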
class ThreadPartialSyncInserter : public StmtExprMutator {
public:
ThreadPartialSyncInserter(StorageScope sync_scope, const std::unordered_set<const Object*>& syncs,
std::unordered_map<const Object*, int> partial_syncs)
public:
ThreadPartialSyncInserter(
StorageScope sync_scope, const std::unordered_set<const Object *> &syncs,
std::unordered_map<const Object *, int> partial_syncs)
: sync_scope_(sync_scope), syncs_(syncs), partial_syncs_(partial_syncs) {}
Stmt VisitStmt(const Stmt& stmt) final {
if (syncs_.size() == 0) return stmt;
Stmt VisitStmt(const Stmt &stmt) final {
if (syncs_.size() == 0)
return stmt;
if (syncs_.count(stmt.get())) {
Stmt barrier;
if (partial_syncs_.count(stmt.get())) {
auto iter = partial_syncs_.find(stmt.get());
ICHECK(sync_scope_.rank == StorageRank::kShared);
barrier = Evaluate(Call(DataType::Int(32), tl::SyncThreadsPartialOp(), {iter->second}));
barrier = Evaluate(Call(DataType::Int(32), tl::SyncThreadsPartialOp(),
{iter->second}));
} else {
return StmtExprMutator::VisitStmt(stmt);
}
......@@ -334,11 +345,11 @@ class ThreadPartialSyncInserter : public StmtExprMutator {
}
}
private:
private:
// data structure.
StorageScope sync_scope_;
const std::unordered_set<const Object*>& syncs_;
const std::unordered_map<const Object*, int>& partial_syncs_;
const std::unordered_set<const Object *> &syncs_;
const std::unordered_map<const Object *, int> &partial_syncs_;
};
Stmt ThreadPartialSync(Stmt stmt, std::string storage_scope) {
......@@ -346,7 +357,8 @@ Stmt ThreadPartialSync(Stmt stmt, std::string storage_scope) {
ThreadPartialSyncPlanner planner(sync_scope);
planner(stmt);
return ThreadPartialSyncInserter(sync_scope, planner.syncs_inserted_,
planner.partial_syncs_inserted_)(std::move(stmt));
planner.partial_syncs_inserted_)(
std::move(stmt));
}
using namespace tir::transform;
......@@ -355,15 +367,16 @@ namespace transform {
Pass ThreadPartialSync(String storage_scope) {
auto pass_func = [storage_scope](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
auto *n = f.CopyOnWrite();
n->body = tl::ThreadPartialSync(std::move(n->body), storage_scope);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.ThreadPartialSync", {});
}
TVM_REGISTER_GLOBAL("tl.transform.ThreadPartialSync").set_body_typed(ThreadPartialSync);
TVM_REGISTER_GLOBAL("tl.transform.ThreadPartialSync")
.set_body_typed(ThreadPartialSync);
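// Usage sketch (illustrative; `mod` is an assumed IRModule): the pass takes
// the storage scope whose accesses it should guard, e.g.
//
//   tvm::transform::Pass sync = ThreadPartialSync("shared.dyn");
//   IRModule synced = sync(mod);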
} // namespace transform
} // namespace tir
} // namespace tvm
} // namespace transform
} // namespace tl
} // namespace tvm
......@@ -38,22 +38,23 @@ using namespace tir;
enum class Role { kConsumer, kProducer, kBoth };
class WarpSpecializedRoleMarker : public StmtVisitor {
public:
public:
WarpSpecializedRoleMarker(Map<Var, Buffer> buffer_data_to_buffer)
: buffer_data_to_buffer_(buffer_data_to_buffer) {}
Role GetRole(const StmtNode* stmt) const {
Role GetRole(const StmtNode *stmt) const {
auto it = map_.find(stmt);
ICHECK(it != map_.end());
return it->second;
}
Role GetRole(const Stmt& stmt) const { return GetRole(stmt.get()); }
Role GetRole(const Stmt &stmt) const { return GetRole(stmt.get()); }
void VisitStmt_(const EvaluateNode* op) final {
void VisitStmt_(const EvaluateNode *op) final {
Role role = Role::kConsumer;
if (auto call = op->value.as<CallNode>()) {
if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) {
if (call->op.same_as(TMALoadOp()) ||
call->op.same_as(TMALoadIm2ColOp())) {
role = Role::kProducer;
has_bulk_copy_ = true;
}
......@@ -61,8 +62,9 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
SetRole(op, role);
}
void VisitStmt_(const BufferStoreNode* op) final {
bool is_shared_store = op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared";
void VisitStmt_(const BufferStoreNode *op) final {
bool is_shared_store =
op->buffer.scope() == "shared.dyn" || op->buffer.scope() == "shared";
if (!is_shared_store) {
SetRole(op, Role::kConsumer);
return;
......@@ -80,11 +82,12 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
break;
}
}
if (role == Role::kProducer) has_simt_copy_ = true;
if (role == Role::kProducer)
has_simt_copy_ = true;
SetRole(op, role);
}
void VisitStmt_(const SeqStmtNode* op) final {
void VisitStmt_(const SeqStmtNode *op) final {
StmtVisitor::VisitStmt_(op);
auto role = GetRole(op->seq[0]);
for (auto stmt : op->seq) {
......@@ -96,41 +99,41 @@ class WarpSpecializedRoleMarker : public StmtVisitor {
SetRole(op, role);
}
void VisitStmt_(const IfThenElseNode* op) final {
void VisitStmt_(const IfThenElseNode *op) final {
StmtVisitor::VisitStmt_(op);
auto role = GetRole(op->then_case);
if (op->else_case.defined()) {
auto role_else = GetRole(op->else_case.value());
if (role != role_else) role = Role::kBoth;
if (role != role_else)
role = Role::kBoth;
}
SetRole(op, role);
}
void VisitStmt_(const BlockRealizeNode* op) final {
void VisitStmt_(const BlockRealizeNode *op) final {
StmtVisitor::VisitStmt_(op);
SetRole(op, GetRole(op->block));
}
template <class NodeType>
void HandleBodyStmt(const NodeType* op) {
template <class NodeType> void HandleBodyStmt(const NodeType *op) {
StmtVisitor::VisitStmt_(op);
SetRole(op, GetRole(op->body));
}
void VisitStmt_(const ForNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const LetStmtNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const AttrStmtNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const AssertStmtNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const BlockNode* op) final { HandleBodyStmt(op); }
void VisitStmt_(const ForNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const LetStmtNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const AttrStmtNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const AssertStmtNode *op) final { HandleBodyStmt(op); }
void VisitStmt_(const BlockNode *op) final { HandleBodyStmt(op); }
bool HasProducer() { return has_simt_copy_ || has_bulk_copy_; }
bool HasSimtCopy() { return has_simt_copy_; }
private:
void SetRole(const StmtNode* stmt, Role role) { map_[stmt] = role; }
private:
void SetRole(const StmtNode *stmt, Role role) { map_[stmt] = role; }
Map<Var, Buffer> buffer_data_to_buffer_;
std::unordered_map<const StmtNode*, Role> map_;
std::unordered_map<const StmtNode *, Role> map_;
bool has_simt_copy_ = false;
bool has_bulk_copy_ = false;
};
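// Illustrative role assignment (a sketch; the buffers are hypothetical):
//   Evaluate(tma_load(...))                   -> kProducer (bulk copy)
//   A_shared[...] = A_global[...]             -> kProducer (SIMT copy)
//   C_local[...] = A_shared[...] * 2          -> kConsumer
//   SeqStmt of a producer and a consumer stmt -> kBoth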
......@@ -140,23 +143,26 @@ static PrimExpr makeGetBarrier(PrimExpr barrier_id) {
}
static Stmt makeExpectTX(PrimExpr barrier_id, PrimExpr bytes) {
auto call = Call(DataType::Handle(), MBarrierExpectTX(), {makeGetBarrier(barrier_id), bytes});
auto call = Call(DataType::Handle(), MBarrierExpectTX(),
{makeGetBarrier(barrier_id), bytes});
return Evaluate(call);
}
static Stmt makeArriveBarrier(PrimExpr barrier_id) {
auto call = Call(DataType::Handle(), builtin::ptx_arrive_barrier(), {makeGetBarrier(barrier_id)});
auto call = Call(DataType::Handle(), builtin::ptx_arrive_barrier(),
{makeGetBarrier(barrier_id)});
return Evaluate(call);
}
static Stmt makeCpAsyncBarrier(PrimExpr barrier_id) {
auto call =
Call(DataType::Handle(), builtin::ptx_cp_async_barrier(), {makeGetBarrier(barrier_id)});
auto call = Call(DataType::Handle(), builtin::ptx_cp_async_barrier(),
{makeGetBarrier(barrier_id)});
return Evaluate(call);
}
static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) {
auto call = Call(DataType::Handle(), MBarrierWaitParity(), {makeGetBarrier(barrier_id), parity});
auto call = Call(DataType::Handle(), MBarrierWaitParity(),
{makeGetBarrier(barrier_id), parity});
return Evaluate(call);
}
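// Taken together, these helpers spell out the mbarrier protocol emitted for a
// producer stage below (a sketch; acquire_id/release_id are symbolic):
//   makeParityWait(acquire_id, parity)   // wait until the slot was released
//   makeExpectTX(release_id, bytes)      // thread 0 only, for TMA bulk copies
//   ...the TMA load, rewritten to arrive on release_id...
//   makeArriveBarrier(release_id)        // signal the consumer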
......@@ -177,7 +183,7 @@ static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) {
// }
class ProducerTraitsCollector : public StmtExprVisitor {
public:
public:
ProducerTraitsCollector() { Clear(); }
void Clear() {
......@@ -192,8 +198,8 @@ class ProducerTraitsCollector : public StmtExprVisitor {
PrimExpr BulkCopyBytes() { return bulk_copy_bytes; }
private:
void VisitExpr_(const CallNode* call) final {
private:
void VisitExpr_(const CallNode *call) final {
if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) {
Call access_ptr = Downcast<Call>(call->args[2]);
ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
......@@ -203,14 +209,14 @@ class ProducerTraitsCollector : public StmtExprVisitor {
StmtExprVisitor::VisitExpr_(call);
}
void VisitStmt_(const ForNode* op) final {
void VisitStmt_(const ForNode *op) final {
PrimExpr old_loop_extents = loop_extents;
loop_extents *= op->extent;
StmtExprVisitor::VisitStmt_(op);
loop_extents = old_loop_extents;
}
void VisitExpr_(const BufferLoadNode* op) final {
void VisitExpr_(const BufferLoadNode *op) final {
has_simt_copy = true;
StmtExprVisitor::VisitExpr_(op);
}
......@@ -222,15 +228,15 @@ class ProducerTraitsCollector : public StmtExprVisitor {
// Rewrite the producer Stmt to use the correct barrier index
class MbarrierRewriter : public StmtExprMutator {
public:
public:
static Stmt Rewrite(Stmt stmt, PrimExpr barrier_id) {
MbarrierRewriter rewriter;
rewriter.producer_barrier_idx_ = barrier_id;
return rewriter(stmt);
}
private:
PrimExpr VisitExpr_(const CallNode* op) final {
private:
PrimExpr VisitExpr_(const CallNode *op) final {
auto call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
if (call->op.same_as(TMALoadOp()) || call->op.same_as(TMALoadIm2ColOp())) {
Call access_ptr = Downcast<Call>(call->args[2]);
......@@ -242,19 +248,18 @@ class MbarrierRewriter : public StmtExprMutator {
PrimExpr producer_barrier_idx_;
};
class ThreadIdxRewriter : public StmtExprMutator {
public:
public:
static Stmt Rewrite(Stmt stmt, Var thread_var, PrimExpr replaced) {
auto rewriter = ThreadIdxRewriter(thread_var, replaced);
return rewriter(stmt);
}
private:
private:
ThreadIdxRewriter(Var thread_var, PrimExpr replaced)
: thread_var_(thread_var), replaced_(replaced) {}
PrimExpr VisitExpr_(const VarNode* var) final {
PrimExpr VisitExpr_(const VarNode *var) final {
if (var == thread_var_.get()) {
return replaced_;
} else {
......@@ -266,9 +271,12 @@ class ThreadIdxRewriter : public StmtExprMutator {
PrimExpr replaced_;
};
Block MakeGroupBlock(const Stmt& stmt, const Map<String, ObjectRef>& annotations) {
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ stmt,
/*init=*/{}, /*alloc_buffers=*/{}, /*match_buffers=*/{}, /*annotations=*/annotations);
Block MakeGroupBlock(const Stmt &stmt,
const Map<String, ObjectRef> &annotations) {
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
/*body*/ stmt,
/*init=*/{}, /*alloc_buffers=*/{}, /*match_buffers=*/{},
/*annotations=*/annotations);
return block;
}
......@@ -280,11 +288,8 @@ struct PipelineInfo {
std::vector<OpInfo> op_infos;
PipelineInfo() = default;
PipelineInfo(
Array<Array<Integer>> group_info,
Array<Integer> order_info,
Array<Integer> stage_info
) {
PipelineInfo(Array<Array<Integer>> group_info, Array<Integer> order_info,
Array<Integer> stage_info) {
int n = static_cast<int>(group_info.size());
ICHECK(n == static_cast<int>(order_info.size()));
ICHECK(n == static_cast<int>(stage_info.size()));
......@@ -301,7 +306,7 @@ struct PipelineInfo {
}
}
PipelineInfo(const PipelineInfo& other) {
PipelineInfo(const PipelineInfo &other) {
for (auto op_info : other.op_infos) {
op_infos.push_back(op_info);
}
......@@ -364,18 +369,19 @@ struct PipelineInfo {
void PrintPipelineInfo() {
std::cout << "Print op_infos:" << std::endl;
for (size_t i = 0; i < op_infos.size(); i++) {
std::cout << i << " " << op_infos[i].group_size << " " << op_infos[i].order << " " << op_infos[i].stage << std::endl;
std::cout << i << " " << op_infos[i].group_size << " "
<< op_infos[i].order << " " << op_infos[i].stage << std::endl;
}
std::cout << "End of print" << std::endl;
}
};
class GroupOpRewriter : public StmtExprMutator {
public:
public:
GroupOpRewriter(PipelineInfo pipeline_info) : pipeline_info_(pipeline_info) {}
private:
Stmt VisitStmt_(const ForNode* op) final {
private:
Stmt VisitStmt_(const ForNode *op) final {
Map<String, ObjectRef> annotations;
annotations.Set(String("stmt_group"), Integer(1));
auto original_node = (op->body).as<SeqStmtNode>();
......@@ -385,19 +391,24 @@ class GroupOpRewriter : public StmtExprMutator {
Array<Stmt> new_body;
int cur_id = 0;
for (int i = 0; i < static_cast<int>(pipeline_info_.op_infos.size()); i++) {
if (pipeline_info_.op_infos[i].group_size == 0) continue;
if (pipeline_info_.op_infos[i].group_size == 0)
continue;
Array<Stmt> block_stmt;
for (int j = 0; j < static_cast<int>(pipeline_info_.op_infos[i].group_size); j++) {
for (int j = 0;
j < static_cast<int>(pipeline_info_.op_infos[i].group_size); j++) {
// ICHECK(group_info_[i][j].as<IntImmNode>());
// int index = static_cast<int>(group_info_[i][j].as<IntImmNode>()->value);
// int index =
// static_cast<int>(group_info_[i][j].as<IntImmNode>()->value);
ICHECK(original_node->seq[cur_id].as<BlockNode>());
auto block = original_node->seq[cur_id].as<BlockNode>();
// TODO: handle nested SeqStmt
block_stmt.push_back(block->body);
cur_id++;
}
new_body.push_back(
MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations));
new_body.push_back(MakeGroupBlock(block_stmt.size() == 1
? block_stmt[0]
: SeqStmt(std::move(block_stmt)),
annotations));
}
Array<Integer> order_anno;
Array<Integer> stage_anno;
......@@ -409,24 +420,26 @@ class GroupOpRewriter : public StmtExprMutator {
for_annotations.erase("tl_pipeline_group");
for_annotations.Set("software_pipeline_order", order_anno);
for_annotations.Set("software_pipeline_stage", stage_anno);
For new_for = For(op->loop_var, op->min, op->extent, op->kind, new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body)), op->thread_binding, for_annotations);
For new_for =
For(op->loop_var, op->min, op->extent, op->kind,
new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body)),
op->thread_binding, for_annotations);
return new_for;
}
PipelineInfo pipeline_info_;
};
class WSCodeEmitter : public StmtMutator {
public:
public:
WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv,
Map<Var, Buffer> buffer_data_to_buffer, const WarpSpecializedRoleMarker& marker)
Map<Var, Buffer> buffer_data_to_buffer,
const WarpSpecializedRoleMarker &marker)
: is_emitting_producer_(is_emitting_producer),
buffer_data_to_buffer_(buffer_data_to_buffer),
marker_(marker),
buffer_data_to_buffer_(buffer_data_to_buffer), marker_(marker),
thread_var_(thread_iv->var) {}
private:
template <typename NodeType>
Stmt FilterByRole(const NodeType* op) {
private:
template <typename NodeType> Stmt FilterByRole(const NodeType *op) {
Role role = marker_.GetRole(op);
if (role == Role::kBoth)
return StmtMutator::VisitStmt_(op);
......@@ -437,7 +450,7 @@ class WSCodeEmitter : public StmtMutator {
}
// TODO: only need to add block for ops in the loop
Stmt VisitStmt_(const SeqStmtNode* op) final {
Stmt VisitStmt_(const SeqStmtNode *op) final {
bool has_producer = false;
for (auto stmt : op->seq) {
if (marker_.GetRole(stmt) == Role::kProducer) {
......@@ -445,19 +458,24 @@ class WSCodeEmitter : public StmtMutator {
break;
}
}
bool need_producer_sync = has_producer && marker_.GetRole(op) == Role::kBoth;
if (!need_producer_sync) return FilterByRole(op);
bool need_producer_sync =
has_producer && marker_.GetRole(op) == Role::kBoth;
if (!need_producer_sync)
return FilterByRole(op);
auto seq_transformed = op->seq.Map([&](Stmt stmt) { return VisitStmt(stmt); });
auto seq_transformed =
op->seq.Map([&](Stmt stmt) { return VisitStmt(stmt); });
auto map = ExtractSyncPattern(op->seq);
// std::cout << "Print ExtractSyncPattern" << std::endl;
// for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
// std::cout << i << " " << map.acquire[i] << " " << map.release[i] << " " << map.release_after[i] << std::endl;
// std::cout << i << " " << map.acquire[i] << " " << map.release[i] << " "
// << map.release_after[i] << std::endl;
// }
// std::cout << "Print sync pattern" << std::endl;
// for (auto pattern : map.patterns) {
// std::cout << pattern.release_idx << " " << pattern.acquire_idx << std::endl;
// std::cout << pattern.release_idx << " " << pattern.acquire_idx <<
// std::endl;
// }
// std::cout << "End of ExtractSyncPattern" << std::endl;
// pipeline_info_.PrintPipelineInfo();
......@@ -465,29 +483,38 @@ class WSCodeEmitter : public StmtMutator {
Map<String, ObjectRef> annotations;
annotations.Set(String("stmt_group"), Integer(1));
if (is_emitting_producer_) { // producer case
if (is_emitting_producer_) { // producer case
ProducerTraitsCollector collector;
for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
Array<Stmt> block_stmt = {};
if (marker_.GetRole(op->seq[i]) == Role::kConsumer) continue;
if (marker_.GetRole(op->seq[i]) == Role::kConsumer)
continue;
if (marker_.GetRole(op->seq[i]) == Role::kBoth) {
block_stmt.push_back(seq_transformed[i]);
new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations));
new_body.push_back(MakeGroupBlock(
block_stmt.size() == 1 ? block_stmt[0]
: SeqStmt(std::move(block_stmt)),
annotations));
continue;
}
if (map.acquire[i] != -1) {
PrimExpr acquire_barrier_id = stage_ + num_barriers_ + num_stages_ * map.acquire[i];
PrimExpr parity =
map.is_loop_dependency(map.acquire[i]) ? bitwise_xor(parity_, 1) : parity_;
PrimExpr acquire_barrier_id =
stage_ + num_barriers_ + num_stages_ * map.acquire[i];
PrimExpr parity = map.is_loop_dependency(map.acquire[i])
? bitwise_xor(parity_, 1)
: parity_;
block_stmt.push_back(makeParityWait(acquire_barrier_id, parity));
}
ICHECK(map.release[i] >= 0);
PrimExpr release_barrier_id = stage_ + num_barriers_ + num_stages_ * map.release[i];
auto stmt = MbarrierRewriter::Rewrite(seq_transformed[i], release_barrier_id);
PrimExpr release_barrier_id =
stage_ + num_barriers_ + num_stages_ * map.release[i];
auto stmt =
MbarrierRewriter::Rewrite(seq_transformed[i], release_barrier_id);
collector.Collect(stmt);
if (!is_zero(collector.BulkCopyBytes())) {
auto expect_tx = IfThenElse(EQ(thread_var_, 0),
makeExpectTX(release_barrier_id, collector.BulkCopyBytes()));
auto expect_tx = IfThenElse(
EQ(thread_var_, 0),
makeExpectTX(release_barrier_id, collector.BulkCopyBytes()));
block_stmt.push_back(expect_tx);
}
block_stmt.push_back(stmt);
......@@ -497,39 +524,53 @@ class WSCodeEmitter : public StmtMutator {
if (map.release_after[i]) {
block_stmt.push_back(makeArriveBarrier(release_barrier_id));
for (int j = 0; j < num_stages_; j++) {
released_barrier_.insert(j + num_barriers_ + num_stages_ * map.release[i]);
released_barrier_.insert(j + num_barriers_ +
num_stages_ * map.release[i]);
}
}
collector.Clear();
new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations));
new_body.push_back(MakeGroupBlock(block_stmt.size() == 1
? block_stmt[0]
: SeqStmt(std::move(block_stmt)),
annotations));
}
} else { // consumer case
} else { // consumer case
for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
Array<Stmt> block_stmt = {};
if (marker_.GetRole(op->seq[i]) == Role::kProducer) continue;
if (marker_.GetRole(op->seq[i]) == Role::kProducer)
continue;
if (map.acquire[i] != -1) {
PrimExpr acquire_barrier_id = stage_ + num_barriers_ + num_stages_ * map.acquire[i];
PrimExpr parity =
map.is_loop_dependency(map.acquire[i]) ? bitwise_xor(parity_, 1) : parity_;
PrimExpr acquire_barrier_id =
stage_ + num_barriers_ + num_stages_ * map.acquire[i];
PrimExpr parity = map.is_loop_dependency(map.acquire[i])
? bitwise_xor(parity_, 1)
: parity_;
block_stmt.push_back(makeParityWait(acquire_barrier_id, parity));
}
block_stmt.push_back(seq_transformed[i]);
// new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations));
// new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ?
// block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations));
if (map.release_after[i]) {
PrimExpr release_barrier_id = stage_ + num_barriers_ + num_stages_ * map.release[i];
PrimExpr release_barrier_id =
stage_ + num_barriers_ + num_stages_ * map.release[i];
block_stmt.push_back(makeArriveBarrier(release_barrier_id));
for (int j = 0; j < num_stages_; j++) {
released_barrier_.insert(j + num_barriers_ + num_stages_ * map.release[i]);
released_barrier_.insert(j + num_barriers_ +
num_stages_ * map.release[i]);
}
// Update the pipeline info
// TODO: handle sync
}
new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ? block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations));
new_body.push_back(MakeGroupBlock(block_stmt.size() == 1
? block_stmt[0]
: SeqStmt(std::move(block_stmt)),
annotations));
}
// Filter out the producer stmts
int cur_id = 0;
PipelineInfo new_pipeline_info;
for (int i = 0; i < static_cast<int>(pipeline_info_.op_infos.size()); i++) {
for (int i = 0; i < static_cast<int>(pipeline_info_.op_infos.size());
i++) {
auto op_info = pipeline_info_.op_infos[i];
bool is_producer = false;
for (int j = 0; j < op_info.group_size; j++) {
......@@ -553,7 +594,7 @@ class WSCodeEmitter : public StmtMutator {
return new_body.size() == 1 ? new_body[0] : SeqStmt(std::move(new_body));
}
Stmt VisitStmt_(const ForNode* op) final {
Stmt VisitStmt_(const ForNode *op) final {
int num_stages = 1;
auto num_stages_anno = op->annotations.Get("num_stages");
if (num_stages_anno.defined()) {
......@@ -565,7 +606,7 @@ class WSCodeEmitter : public StmtMutator {
Array<Array<Integer>> group_info_array;
Array<Integer> order_info_array;
Array<Integer> stage_info_array;
auto group_anno = op->annotations.Get("tl_pipeline_group");
if (group_anno.defined()) {
group_info_array = Downcast<Array<Array<Integer>>>(group_anno);
......@@ -579,9 +620,11 @@ class WSCodeEmitter : public StmtMutator {
stage_info_array = Downcast<Array<Integer>>(stage_anno);
}
PipelineInfo pipeline_info(group_info_array, order_info_array, stage_info_array);
PipelineInfo pipeline_info(group_info_array, order_info_array,
stage_info_array);
if (pipeline_info.op_infos.size() > 0) {
ICHECK(pipeline_info_.op_infos.size() == 0) << "Nested pipeline not supported.";
ICHECK(pipeline_info_.op_infos.size() == 0)
<< "Nested pipeline not supported.";
}
PrimExpr parity_before = std::move(parity_);
......@@ -592,13 +635,15 @@ class WSCodeEmitter : public StmtMutator {
num_stages_ = num_stages;
pipeline_info_ = pipeline_info;
stage_ = FloorMod(op->loop_var - op->min, num_stages);
parity_ =
FloorMod(parity_before * op->extent + FloorDiv(op->loop_var - op->min, num_stages), 2);
parity_ = FloorMod(parity_before * op->extent +
FloorDiv(op->loop_var - op->min, num_stages),
2);
auto result = FilterByRole(op);
Stmt grouped_for_node;
if (result.as<ForNode>() && group_anno.defined() && group_info_array.size() > 0 && !is_emitting_producer_) {
if (result.as<ForNode>() && group_anno.defined() &&
group_info_array.size() > 0 && !is_emitting_producer_) {
GroupOpRewriter group_op_rewriter(pipeline_info_);
auto for_node = Downcast<For>(result);
grouped_for_node = group_op_rewriter(for_node);
......@@ -618,7 +663,8 @@ class WSCodeEmitter : public StmtMutator {
for_node.CopyOnWrite()->annotations.erase("tl_pipeline_order");
for_node.CopyOnWrite()->annotations.erase("tl_pipeline_stage");
}
if (is_emitting_producer_ || !group_anno.defined() ||group_info_array.size() == 0) {
if (is_emitting_producer_ || !group_anno.defined() ||
group_info_array.size() == 0) {
return for_node;
}
return grouped_for_node;
......@@ -626,17 +672,17 @@ class WSCodeEmitter : public StmtMutator {
return result;
}
Stmt VisitStmt_(const IfThenElseNode* op) final { return FilterByRole(op); }
Stmt VisitStmt_(const EvaluateNode* op) final { return FilterByRole(op); }
Stmt VisitStmt_(const AttrStmtNode* op) final { return FilterByRole(op); }
Stmt VisitStmt_(const BufferStoreNode* op) final { return FilterByRole(op); }
Stmt VisitStmt_(const LetStmtNode* op) final { return FilterByRole(op); }
Stmt VisitStmt_(const AssertStmtNode* op) final { return FilterByRole(op); }
Stmt VisitStmt_(const BlockNode* op) final {
Stmt VisitStmt_(const IfThenElseNode *op) final { return FilterByRole(op); }
Stmt VisitStmt_(const EvaluateNode *op) final { return FilterByRole(op); }
Stmt VisitStmt_(const AttrStmtNode *op) final { return FilterByRole(op); }
Stmt VisitStmt_(const BufferStoreNode *op) final { return FilterByRole(op); }
Stmt VisitStmt_(const LetStmtNode *op) final { return FilterByRole(op); }
Stmt VisitStmt_(const AssertStmtNode *op) final { return FilterByRole(op); }
Stmt VisitStmt_(const BlockNode *op) final {
ICHECK(0);
return Stmt();
}
Stmt VisitStmt_(const BlockRealizeNode* op) final {
Stmt VisitStmt_(const BlockRealizeNode *op) final {
ICHECK(0);
return Stmt();
}
......@@ -656,27 +702,32 @@ class WSCodeEmitter : public StmtMutator {
}
};
std::vector<SyncPattern> CreateBaseSyncPairs(Array<Stmt> seq_stmt,
const std::vector<bool>& is_producer) {
std::vector<SyncPattern>
CreateBaseSyncPairs(Array<Stmt> seq_stmt,
const std::vector<bool> &is_producer) {
const int n = seq_stmt.size();
std::vector<std::set<const BufferNode*>> reads, writes;
std::vector<std::set<const BufferNode *>> reads, writes;
reads.reserve(n);
writes.reserve(n);
for (int i = 0; i < n; i++) {
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
/*name_hint=*/"",
/*body*/ seq_stmt[i]);
auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_);
std::set<const BufferNode*> read_set, write_set;
for (auto region : access[0]) read_set.insert(region->buffer.get());
for (auto region : access[1]) write_set.insert(region->buffer.get());
std::set<const BufferNode *> read_set, write_set;
for (auto region : access[0])
read_set.insert(region->buffer.get());
for (auto region : access[1])
write_set.insert(region->buffer.get());
reads.push_back(std::move(read_set));
writes.push_back(std::move(write_set));
}
auto intersect_fn = [](const std::set<const BufferNode*>& lhs,
const std::set<const BufferNode*>& rhs) {
auto intersect_fn = [](const std::set<const BufferNode *> &lhs,
const std::set<const BufferNode *> &rhs) {
for (auto ptr : lhs)
if (rhs.count(ptr)) return true;
if (rhs.count(ptr))
return true;
return false;
};
......@@ -686,7 +737,8 @@ class WSCodeEmitter : public StmtMutator {
for (int i = 0; i < n; i++) {
for (int j = i + 1; j < n; j++) {
if (is_producer[i] != is_producer[j] &&
(intersect_fn(writes[i], reads[j]) || intersect_fn(reads[i], writes[j]))) {
(intersect_fn(writes[i], reads[j]) ||
intersect_fn(reads[i], writes[j]))) {
sync_patterns.push_back({i, j});
break;
}
......@@ -701,7 +753,8 @@ class WSCodeEmitter : public StmtMutator {
for (int i = 0; i < n; i++) {
for (int j = 0; j < i; j++) {
if (is_producer[i] != is_producer[j] &&
(intersect_fn(writes[i], reads[j]) || intersect_fn(reads[i], writes[j]))) {
(intersect_fn(writes[i], reads[j]) ||
intersect_fn(reads[i], writes[j]))) {
sync_patterns.push_back({i, j});
break;
}
......@@ -712,8 +765,9 @@ class WSCodeEmitter : public StmtMutator {
return sync_patterns;
}
static std::vector<SyncPattern> RemoveUnusedSyncPatterns(
const std::vector<SyncPattern>& sync_patterns, const std::vector<bool>& is_producer) {
static std::vector<SyncPattern>
RemoveUnusedSyncPatterns(const std::vector<SyncPattern> &sync_patterns,
const std::vector<bool> &is_producer) {
/*
Simplify multiple release-acquire pairs into one
------------------
......@@ -746,7 +800,8 @@ class WSCodeEmitter : public StmtMutator {
std::vector<SyncPattern> sync_pattern_cleaned;
sync_pattern_cleaned.reserve(M);
for (int i = 0; i < M; i++)
if (!removed[i]) sync_pattern_cleaned.push_back(sync_patterns[i]);
if (!removed[i])
sync_pattern_cleaned.push_back(sync_patterns[i]);
return sync_pattern_cleaned;
}
......@@ -760,10 +815,12 @@ class WSCodeEmitter : public StmtMutator {
}
auto sync_patterns_base = CreateBaseSyncPairs(seq_stmt, is_producer);
auto sync_patterns = RemoveUnusedSyncPatterns(sync_patterns_base, is_producer);
auto sync_patterns =
RemoveUnusedSyncPatterns(sync_patterns_base, is_producer);
// for (auto pattern : sync_patterns) {
// std::cout << pattern.release_idx << " " << pattern.acquire_idx << std::endl;
// std::cout << pattern.release_idx << " " << pattern.acquire_idx <<
// std::endl;
// }
SyncPatternMap map;
......@@ -799,7 +856,7 @@ class WSCodeEmitter : public StmtMutator {
const bool is_emitting_producer_;
Map<Var, Buffer> buffer_data_to_buffer_;
std::unordered_set<int> released_barrier_;
const WarpSpecializedRoleMarker& marker_;
const WarpSpecializedRoleMarker &marker_;
int num_barriers_ = 0;
PrimExpr parity_ = 0;
......@@ -811,17 +868,18 @@ class WSCodeEmitter : public StmtMutator {
};
class WarpSpecializedRewriter : public StmtExprMutator {
public:
public:
static PrimFunc Substitute(PrimFunc f) {
auto T = WarpSpecializedRewriter();
T.buffer_lca_ = DetectBufferAccessLCA(f);
for (auto [buffer, _] : T.buffer_lca_) T.buffer_data_to_buffer_.Set(buffer->data, buffer);
for (auto [buffer, _] : T.buffer_lca_)
T.buffer_data_to_buffer_.Set(buffer->data, buffer);
f.CopyOnWrite()->body = T(f->body);
return f;
}
private:
Stmt VisitStmt_(const AttrStmtNode* op) final {
private:
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent &&
Downcast<IterVar>(op->node)->thread_tag == "threadIdx.x") {
thread_iv_ = Downcast<IterVar>(op->node);
......@@ -839,9 +897,10 @@ class WarpSpecializedRewriter : public StmtExprMutator {
}
}
// If users define a thread binding, we will replace it with threadIdx.x.
// We require that the thread binding is threadIdx.x and that its extent matches the thread extent.
Stmt VisitStmt_(const ForNode* op) final {
// If users define a thread binding, we will replace it with threadIdx.x. We
// require that the thread binding is threadIdx.x and that its extent matches
// the thread extent.
Stmt VisitStmt_(const ForNode *op) final {
ICHECK(thread_iv_.defined());
For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
if (for_node->kind == ForKind::kThreadBinding) {
......@@ -849,14 +908,16 @@ class WarpSpecializedRewriter : public StmtExprMutator {
String thread_tag = for_node->thread_binding.value()->thread_tag;
ICHECK(thread_tag == "threadIdx.x") << "Only support threadIdx.x";
Var thread_iv = Downcast<Var>(for_node->loop_var);
Stmt new_body = ThreadIdxRewriter::Rewrite(for_node->body, thread_iv, thread_iv_);
Stmt new_body =
ThreadIdxRewriter::Rewrite(for_node->body, thread_iv, thread_iv_);
return new_body;
}
return for_node;
}
Stmt VisitStmt_(const BlockRealizeNode* op) final {
BlockRealize block_realize = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
Stmt VisitStmt_(const BlockRealizeNode *op) final {
BlockRealize block_realize =
Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));
if (!thread_iv_.defined()) {
return block_realize;
}
......@@ -877,17 +938,21 @@ class WarpSpecializedRewriter : public StmtExprMutator {
PrimExpr consumer_thread_extent = thread_iv_->dom->extent;
PrimExpr producer_thread_extent = thread_iv_->dom->extent;
// Need one warp-group for bulk-copy only case
if (!marker.HasSimtCopy()) producer_thread_extent = 128;
if (!marker.HasSimtCopy())
producer_thread_extent = 128;
// TODO: estimate the correct reg usage.
auto inc_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(), {240, 1}));
auto dec_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(), {24, 0}));
auto inc_reg_stmt =
Evaluate(Call(DataType::Handle(), SetMaxNReg(), {240, 1}));
auto dec_reg_stmt =
Evaluate(Call(DataType::Handle(), SetMaxNReg(), {24, 0}));
producer_code = SeqStmt({dec_reg_stmt, producer_code});
consumer_code = SeqStmt({inc_reg_stmt, consumer_code});
producer_code = ThreadIdxRewriter::Rewrite(producer_code, thread_iv_->var,
thread_iv_->var - consumer_thread_extent);
producer_code =
ThreadIdxRewriter::Rewrite(producer_code, thread_iv_->var,
thread_iv_->var - consumer_thread_extent);
updated_thread_extent_ = consumer_thread_extent + producer_thread_extent;
need_update_thread_extent_ = true;
......@@ -897,15 +962,16 @@ class WarpSpecializedRewriter : public StmtExprMutator {
Array<PrimExpr> barrier_num_threads;
barrier_num_threads.reserve(num_barriers);
for (int i = 0; i < num_barriers; i++) {
PrimExpr arrive_thread_count =
producer.released_barrier_.count(i) ? producer_thread_extent : consumer_thread_extent;
PrimExpr arrive_thread_count = producer.released_barrier_.count(i)
? producer_thread_extent
: consumer_thread_extent;
barrier_num_threads.push_back(arrive_thread_count);
}
Stmt init_barrier =
Evaluate(Call(DataType::Handle(), CreateListofMBarrierOp(), barrier_num_threads));
Stmt body =
IfThenElse(GE(thread_iv_->var, consumer_thread_extent), producer_code, consumer_code);
Stmt init_barrier = Evaluate(Call(
DataType::Handle(), CreateListofMBarrierOp(), barrier_num_threads));
Stmt body = IfThenElse(GE(thread_iv_->var, consumer_thread_extent),
producer_code, consumer_code);
// Add an attr here to handle the partial thread count in the ThreadSync pass.
Array<IntImm> ws_partition = {Downcast<IntImm>(producer_thread_extent),
Downcast<IntImm>(consumer_thread_extent)};
......@@ -935,7 +1001,8 @@ tvm::transform::Pass WarpSpecialized() {
return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {});
}
TVM_REGISTER_GLOBAL("tl.transform.WarpSpecialized").set_body_typed(WarpSpecialized);
TVM_REGISTER_GLOBAL("tl.transform.WarpSpecialized")
.set_body_typed(WarpSpecialized);
} // namespace tl
} // namespace tvm
} // namespace tl
} // namespace tvm
......@@ -5,7 +5,6 @@ import torch
import torch.backends
import tilelang.testing
from tilelang import tvm as tvm
from tvm import DataType
import tilelang as TL
import tilelang.language as T
from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
......
......@@ -30,13 +30,11 @@ def matmul(
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads
) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......@@ -169,9 +167,7 @@ def test_gemm_f32f32f32_nn():
def test_gemm_i8i8i32_nn():
run_gemm(
512, 1024, 768, False, False, "int8", "int8", "int32", 128, 128, 64
)
run_gemm(512, 1024, 768, False, False, "int8", "int8", "int32", 128, 128, 64)
def test_gemm_f16f16f16_tn():
......@@ -217,9 +213,7 @@ def test_gemm_i8i8i32_tn():
def test_gemm_f64f64f64_nt():
run_gemm(
512, 512, 512, False, True, "float64", "float64", "float64", 64, 32, 16
)
run_gemm(512, 512, 512, False, True, "float64", "float64", "float64", 64, 32, 16)
def test_gemm_f32f32f32_nt():
......
......@@ -10,8 +10,7 @@ import tilelang as TL
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,
)
TensorCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
torch.manual_seed(0)
......
......@@ -6,6 +6,7 @@ import tilelang.testing
import tilelang as tl
from tilelang import primitives as P
def matmul_ssr(
M,
N,
......@@ -30,13 +31,11 @@ def matmul_ssr(
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads
) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
......@@ -145,13 +144,11 @@ def matmul_rsr(
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads
) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
A_local = T.alloc_fragment(A_local_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
......@@ -264,13 +261,11 @@ def matmul_rrr(
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads
) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
A_local = T.alloc_fragment(A_local_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
......
......@@ -11,8 +11,7 @@ from .mma_macro_generator import (
TensorCoreIntrinEmitterWithLadderTransform, # noqa: F401
)
from .mma_layout import get_swizzle_layout # noqa: F401
from .mma_layout import get_swizzle_layout # noqa: F401
from .mma_layout import make_mma_swizzle_layout # noqa: F401
from .mfma_layout import make_mfma_swizzle_layout # noqa: F401
......@@ -14,6 +14,7 @@ from .utils import (
lift = convert
# TODO(lei): Add Typing for this file
class TensorCoreIntrinEmitter(object):
"""
......@@ -75,9 +76,11 @@ class TensorCoreIntrinEmitter(object):
self.reduce_k = reduce_k
self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
self.num_elems_per_byte = num_elems_per_byte
if self.warp_rows == 0 or self.warp_cols == 0:
raise ValueError(f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}")
raise ValueError(
f"Invalid threads configuration for this tile shape, {self.warp_rows} x {self.warp_cols} with threads {self.threads}"
)
def _initialize_k_dim(self, a_dtype="float16"):
if isinstance(a_dtype, str):
......@@ -272,12 +275,9 @@ class TensorCoreIntrinEmitter(object):
A_local_buf.data,
k_inner * warp_rows * local_size_a + i * local_size_a,
B_local_buf.data,
k_inner * warp_cols * local_size_b + j * local_size_b
+ lift(local_size_b) // 2,
k_inner * warp_cols * local_size_b + j * local_size_b + lift(local_size_b) // 2,
C_local_buf.data,
i * warp_cols * local_size_out
+ j * local_size_out
+ lift(local_size_out) // 2,
i * warp_cols * local_size_out + j * local_size_out + lift(local_size_out) // 2,
T.bool(False),
)
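The operand and accumulator offsets in the call above are plain flat-index arithmetic; a quick numeric check with hypothetical sizes (plain Python, not taken from the emitter) shows how the `+ local_size // 2` terms step into the second half of each per-thread fragment.
# Hypothetical sizes chosen only to make the arithmetic concrete.
warp_rows, warp_cols = 2, 2
local_size_a, local_size_b, local_size_out = 8, 4, 8
k_inner, i, j = 1, 1, 1
a_off = k_inner * warp_rows * local_size_a + i * local_size_a                        # 16 + 8 = 24
b_off = k_inner * warp_cols * local_size_b + j * local_size_b + local_size_b // 2    # 8 + 4 + 2 = 14
c_off = i * warp_cols * local_size_out + j * local_size_out + local_size_out // 2    # 16 + 8 + 4 = 28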
......@@ -328,7 +328,9 @@ class TensorCoreIntrinEmitter(object):
return (_warp_stmatrix_global(C_local_buf, C_buf, thread_bindings)
if is_global else _warp_stmatrix_shared(C_local_buf, C_buf, thread_bindings))
def make_mma_load_layout(self, local_buf: Buffer, matrix:Literal["A", "B"]="A") -> T.Fragment:
def make_mma_load_layout(self,
local_buf: Buffer,
matrix: Literal["A", "B"] = "A") -> T.Fragment:
"""
Create a layout function for loading MMA operands from a fragment buffer.
This layout is used in conjunction with `inverse_mma_load_layout` to
......@@ -372,12 +374,14 @@ class TensorCoreIntrinEmitter(object):
elif matrix == "A" and not transposed:
transform_func = ldmatrix_16x32_to_shared_16x32_layout_a
else:
raise ValueError("ldmatrix only supports B transposed and A non-transposed for int8")
raise ValueError(
"ldmatrix only supports B transposed and A non-transposed for int8")
else:
raise ValueError(f"Unsupported dtype {dtype}")
shape = local_buf.shape
assert is_fragment(local_buf), "local_buf must be a fragment, but got {}".format(local_buf.scope())
assert is_fragment(local_buf), "local_buf must be a fragment, but got {}".format(
local_buf.scope())
if matrix == "A":
micro_size_x, micro_size_y = self.micro_size_x, self.micro_size_k
......@@ -397,7 +401,8 @@ class TensorCoreIntrinEmitter(object):
transform_func = transform_func if not transposed else transform_func_trans
warp_size, local_size_a, local_size_b = self.WARP_SIZE, self.local_size_a, self.local_size_b
local_size = local_size_a if matrix == "A" else local_size_b
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32").inverse([warp_size, local_size])
inverse_mma_load_layout = IndexMap.from_func(
transform_func, index_dtype="int32").inverse([warp_size, local_size])
def forward_thread(i: int, j: int) -> int:
"""
......@@ -406,29 +411,19 @@ class TensorCoreIntrinEmitter(object):
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols
block_i, block_j = (i // micro_size_x) // warp_rows, (
j // micro_size_y
) // warp_cols
block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i, warp_j = (i // micro_size_x) % warp_rows, (
j // micro_size_y
) % warp_cols
warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i, mma_j = i % micro_size_x, j % micro_size_y
lane_id, _ = inverse_mma_load_layout.map_indices([mma_i, mma_j])
if is_m_first:
thread_id = (
block_i * (block_col_warps * warp_cols)
+ block_j * warp_rows
+ warp_i * warp_cols
+ warp_j
)
block_i * (block_col_warps * warp_cols) + block_j * warp_rows +
warp_i * warp_cols + warp_j)
else:
thread_id = (
block_j * (block_row_warps * warp_size)
+ block_i * warp_size
+ lane_id
)
block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id)
return thread_id
def forward_index(i: int, j: int) -> int:
......@@ -439,21 +434,13 @@ class TensorCoreIntrinEmitter(object):
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols
block_i, block_j = (i // micro_size_x) // warp_rows, (
j // micro_size_y
) // warp_cols
block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i, warp_j = (i // micro_size_x) % warp_rows, (
j // micro_size_y
) % warp_cols
warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i, mma_j = i % micro_size_x, j % micro_size_y
_, local_id = inverse_mma_load_layout.map_indices([mma_i, mma_j])
return (
warp_i * (warp_cols * local_size_out)
+ warp_j * local_size_out
+ local_id
)
return (warp_i * (warp_cols * local_size_out) + warp_j * local_size_out + local_id)
fragment = T.Fragment(
shape,
......@@ -465,9 +452,7 @@ class TensorCoreIntrinEmitter(object):
print(f"fragment.index: {fragment.index}")
return fragment
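The `IndexMap.from_func(...).inverse([warp_size, local_size])` call above inverts the ldmatrix mapping so a tile coordinate can be traced back to the owning lane and local slot. A hedged sketch with a toy bijection (not the real ldmatrix layout), assuming the pinned TVM exposes `IndexMap` exactly as used above:
from tvm import tir
from tvm.tir import IndexMap

# Toy map: 8 lanes x 2 locals -> a 2 x 8 tile. Purely illustrative.
toy_forward = IndexMap.from_func(
    lambda lane, local: [lane // 4, tir.floormod(lane, 4) * 2 + local],
    index_dtype="int32")
# Domain shape is (lane, local) = [8, 2]; the inverse maps a tile coordinate
# (row, col) back to the (lane, local) pair that owns it.
toy_inverse = toy_forward.inverse([8, 2])
lane_id, local_id = toy_inverse.map_indices([1, 3])  # expect lane 5, local 1
print(lane_id, local_id)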
def make_mma_store_layout(
self, local_buf: Buffer
) -> T.Fragment:
def make_mma_store_layout(self, local_buf: Buffer) -> T.Fragment:
"""
Create a layout function for storing MMA results into a fragment buffer.
This layout is used in conjunction with `inverse_mma_store_layout` to
......@@ -500,6 +485,7 @@ class TensorCoreIntrinEmitter(object):
warp_rows, warp_cols = self.warp_rows, self.warp_cols
warp_size = self.WARP_SIZE
is_m_first = self.is_m_first
def forward_thread(i: int, j: int) -> int:
"""
Given the row index `i` and column index `j` in the fragment,
......@@ -514,7 +500,8 @@ class TensorCoreIntrinEmitter(object):
mma_i, mma_j = i % micro_size_x, j % micro_size_y
lane_id, _ = inverse_mma_store_layout.map_indices([mma_i, mma_j])
if is_m_first:
thread_id = block_i * (block_col_warps * warp_cols) + block_j * warp_rows + warp_i * warp_cols + warp_j
thread_id = block_i * (
block_col_warps * warp_cols) + block_j * warp_rows + warp_i * warp_cols + warp_j
else:
thread_id = block_j * (block_row_warps * warp_size) + block_i * warp_size + lane_id
return thread_id
......@@ -527,13 +514,9 @@ class TensorCoreIntrinEmitter(object):
"""
# the upper bounds of i and j are block_row_warps * warp_rows * micro_size_x and block_col_warps * warp_cols * micro_size_y
# the upper bounds of block_row_warps and block_col_warps are warp_rows and warp_cols
block_i, block_j = (i // micro_size_x) // warp_rows, (
j // micro_size_y
) // warp_cols
block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols
# the upper bounds of warp_i and warp_j are warp_rows and warp_cols
warp_i, warp_j = (i // micro_size_x) % warp_rows, (
j // micro_size_y
) % warp_cols
warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols
# upper bounds of mma_i and mma_j are micro_size_x and micro_size_y
mma_i, mma_j = i % micro_size_x, j % micro_size_y
_, local_id = inverse_mma_store_layout.map_indices([mma_i, mma_j])
......@@ -545,6 +528,7 @@ class TensorCoreIntrinEmitter(object):
forward_index_fn=forward_index,
)
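Both `forward_thread` and `forward_index` rely on the same three-level decomposition of a fragment coordinate; a plain-Python check with hypothetical tile sizes makes the arithmetic concrete.
# Illustrative decomposition (hypothetical sizes) of the index math used by
# forward_thread/forward_index: a fragment row/col splits into block, warp and
# intra-tile (mma) coordinates.
def decompose(i, j, micro_size_x=16, micro_size_y=8, warp_rows=2, warp_cols=2):
    block_i, block_j = (i // micro_size_x) // warp_rows, (j // micro_size_y) // warp_cols
    warp_i, warp_j = (i // micro_size_x) % warp_rows, (j // micro_size_y) % warp_cols
    mma_i, mma_j = i % micro_size_x, j % micro_size_y
    return (block_i, block_j), (warp_i, warp_j), (mma_i, mma_j)

# e.g. row 40, col 20: 40 // 16 = 2 -> block_i 1, warp_i 0, mma_i 8;
#                      20 // 8  = 2 -> block_j 1, warp_j 0, mma_j 4.
print(decompose(40, 20))  # ((1, 1), (0, 0), (8, 4))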
class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
"""
To eliminate Python syntax within TIR Macro.
......
......@@ -4,6 +4,7 @@
from tvm.script import tir as T
def alloc_shared(shape, dtype, scope="shared.dyn"):
return T.alloc_buffer(shape, dtype, scope=scope)
......
......@@ -6,11 +6,10 @@ from typing import Union, List, Optional
from tvm import tir
from tvm.script import tir as T
def region(buffer: tir.BufferLoad, access_type: str, *args: tir.PrimExpr):
access_type = {"r": 1, "w": 2, "rw": 3}[access_type]
return tir.call_intrin(
"handle", tir.op.Op.get("tl.region"), buffer, access_type, *args
)
return tir.call_intrin("handle", tir.op.Op.get("tl.region"), buffer, access_type, *args)
def buffer_to_tile_region(buffer: tir.Buffer, access_type: str):
......@@ -19,20 +18,14 @@ def buffer_to_tile_region(buffer: tir.Buffer, access_type: str):
return region(T.BufferLoad(buffer, mins), access_type, *extents)
def buffer_load_to_tile_region(
load: tir.BufferLoad, access_type: str, extents: List[tir.PrimExpr]
):
def buffer_load_to_tile_region(load: tir.BufferLoad, access_type: str, extents: List[tir.PrimExpr]):
return region(load, access_type, *extents)
def buffer_region_to_tile_region(
buffer_region: tir.BufferRegion, access_type: str
):
def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str):
mins = [x.min for x in buffer_region.region]
extents = [x.extent for x in buffer_region.region]
return region(
T.BufferLoad(buffer_region.buffer, mins), access_type, *extents
)
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *extents)
def copy(
......@@ -71,9 +64,7 @@ def copy(
src = _to_region(src, "r")
dst = _to_region(dst, "w")
if coalesced_width is not None:
return tir.call_intrin(
"handle", tir.op.Op.get("tl.copy"), src, dst, coalesced_width
)
return tir.call_intrin("handle", tir.op.Op.get("tl.copy"), src, dst, coalesced_width)
else:
return tir.call_intrin("handle", tir.op.Op.get("tl.copy"), src, dst)
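For reference, a hedged sketch of what `buffer_region_to_tile_region` hands to `region()`: the access string becomes a small bitmask (r=1, w=2, rw=3) and the `BufferRegion` collapses into a `BufferLoad` anchored at the per-axis minimums plus the per-axis extents. The buffer and slice below are illustrative, and `tir.BufferLoad` stands in for the `T.BufferLoad` alias used above.
from tvm import tir
from tvm.ir import Range

A = tir.decl_buffer((128, 64), "float16", name="A")
reg = tir.BufferRegion(A, [Range.from_min_extent(0, 128), Range.from_min_extent(16, 32)])
mins = [r.min for r in reg.region]        # [0, 16]
extents = [r.extent for r in reg.region]  # [128, 32]
access_flag = {"r": 1, "w": 2, "rw": 3}["r"]   # read-only access
anchor = tir.BufferLoad(A, mins)          # passed to region(...) along with *extents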
......
......@@ -10,12 +10,8 @@ def atomic_add(dst, value):
def atomic_addx2(dst, value):
return T.call_extern(
"handle", "atomicAddx2", T.address_of(dst), T.address_of(value)
)
return T.call_extern("handle", "atomicAddx2", T.address_of(dst), T.address_of(value))
def dp4a(A, B, C):
return T.call_extern(
"handle", "DP4A", T.address_of(A), T.address_of(B), T.address_of(C)
)
return T.call_extern("handle", "DP4A", T.address_of(A), T.address_of(B), T.address_of(C))
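The `dp4a` wrapper presumably lowers to CUDA's four-way int8 dot-product-accumulate; a plain-Python reference of that semantics, for intuition only and not the extern implementation:
def dp4a_ref(a4, b4, c):
    """Accumulate the dot product of four int8 lanes of a4 and b4 into c."""
    assert len(a4) == len(b4) == 4
    return c + sum(int(a) * int(b) for a, b in zip(a4, b4))

assert dp4a_ref([1, 2, 3, 4], [5, 6, 7, 8], 10) == 80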