Commit fa511857 authored by Lei Wang, committed by GitHub

[Lint] Overall Typo and Linting Fixes (#13)

* README.md fixed

* update test ci

* Lint and Typo Fix

* Clang Format Lint Fix
parent be55163f
@@ -20,13 +20,14 @@ using namespace tir;
TIR_REGISTER_TL_OP(RegionOp, region)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
std::unique_ptr<Operator> ParseOperator(Call call, BufferMap vmap) {
auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder");
Op op = call->op.as<Op>().value();
if (op_map.count(op)) {
Operator* ptr = static_cast<Operator*>(op_map[op](call->args, vmap));
Operator *ptr = static_cast<Operator *>(op_map[op](call->args, vmap));
ICHECK(ptr != nullptr);
return std::unique_ptr<Operator>(ptr);
}
@@ -41,7 +42,7 @@ std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap) {
return nullptr;
}
Var GetVarFromAccessPtr(const PrimExpr& expr) {
Var GetVarFromAccessPtr(const PrimExpr &expr) {
auto call = expr.as<CallNode>();
ICHECK(call);
ICHECK(call->op.same_as(builtin::tvm_access_ptr()));
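The remainder of this helper falls outside the hunk; for context, TVM's tvm_access_ptr builtin packs its arguments as (dtype_placeholder, data, offset, extent, rw_mask), so the buffer variable would presumably be recovered from the second argument, roughly:

  // Sketch only; the actual tail of the function is not shown in this diff.
  Var var = Downcast<Var>(call->args[1]);
  return var;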
@@ -67,20 +68,27 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
bool RegionOp::IsFullRegion() const {
for (size_t i = 0; i < ranges_.size(); i++) {
if (!is_zero(ranges_[i]->min)) return false;
if (!StructuralEqual()(ranges_[i]->extent, buffer_->shape[i])) return false;
if (!is_zero(ranges_[i]->min))
return false;
if (!StructuralEqual()(ranges_[i]->extent, buffer_->shape[i]))
return false;
}
return true;
}
Stmt Operator::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
Stmt Operator::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
ICHECK(0) << "Not Implemented Lower method.";
return Evaluate(0);
}
Stmt Operator::Canonialize(const CanonializeArgs& T, arith::Analyzer* analyzer) const { return {}; }
Stmt Operator::Canonialize(const CanonializeArgs &T,
arith::Analyzer *analyzer) const {
return {};
}
LayoutMap Operator::InferLayout(const LayoutInferArgs& T, InferLevel level) { return {}; }
LayoutMap Operator::InferLayout(const LayoutInferArgs &T, InferLevel level) {
return {};
}
} // namespace tl
} // namespace tvm
@@ -25,17 +25,19 @@ using namespace tir;
using AddWorkspaceCallback = std::function<PrimExpr(int, DataType)>;
using LayoutMap = Map<Buffer, Layout>;
using BufferMap = Map<Var, Buffer>;
using OpBuilderFunc = TypedPackedFunc<void*(Array<PrimExpr>, BufferMap)>;
using OpBuilderFunc = TypedPackedFunc<void *(Array<PrimExpr>, BufferMap)>;
#define TIR_REGISTER_TL_OP(Entry, OpName) \
const Op& Entry::Get() { \
static const Op& op = Op::Get("tl." #OpName); \
const Op &Entry::Get() { \
static const Op &op = Op::Get("tl." #OpName); \
return op; \
} \
TVM_REGISTER_OP("tl." #OpName) \
.set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName) \
.set_attr<OpBuilderFunc>( \
"TLOpBuilder", [](Array<PrimExpr> a, BufferMap b) { return (void*)(new Entry(a, b)); })
.set_attr<OpBuilderFunc>("TLOpBuilder", \
[](Array<PrimExpr> a, BufferMap b) { \
return (void *)(new Entry(a, b)); \
})
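To make the macro concrete: a registration for a hypothetical FillOp (the op name and attributes here are invented for this sketch; RegionOp and ReduceOp are the real users in this diff) would look like:

  class FillOp : public Operator {
  public:
    FillOp(Array<PrimExpr> args, BufferMap vmap) { /* parse args via vmap */ }
    static const Op &Get();
  };

  TIR_REGISTER_TL_OP(FillOp, fill)
      .set_num_inputs(2)
      .set_attr<TCallEffectKind>("TCallEffectKind",
                                 Integer(CallEffectKind::kOpaque));

The macro supplies FillOp::Get() and registers a TLOpBuilder that heap-allocates a FillOp from a tl.fill call, which is exactly what ParseOperator above casts back out of the op attribute map.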
enum class InferLevel {
kFree = 0,
@@ -64,30 +66,31 @@ struct CanonializeArgs {
};
class Operator {
public:
virtual Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const;
virtual Stmt Canonialize(const CanonializeArgs& T, arith::Analyzer* analyzer) const;
virtual LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level);
public:
virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
virtual Stmt Canonialize(const CanonializeArgs &T,
arith::Analyzer *analyzer) const;
virtual LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level);
virtual ~Operator() = default;
};
class RegionOp : public Operator {
public:
public:
RegionOp(Array<PrimExpr> args, BufferMap vmap);
static const Op& Get();
static const Op &Get();
const Buffer& GetBuffer() const { return buffer_; }
const Array<Range>& GetRanges() const { return ranges_; }
const Buffer &GetBuffer() const { return buffer_; }
const Array<Range> &GetRanges() const { return ranges_; }
int GetAccessMask() const { return access_mask_; }
bool IsFullRegion() const;
private:
private:
Buffer buffer_;
Array<Range> ranges_;
int access_mask_;
};
Var GetVarFromAccessPtr(const PrimExpr& expr);
Var GetVarFromAccessPtr(const PrimExpr &expr);
std::unique_ptr<Operator> ParseOperator(Call call, BufferMap vmap);
std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap);
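These two overloads are the entry points by which a pass recovers an Operator from TIR; a minimal assumed call pattern (vmap, lower_args, and analyzer are provided by the surrounding pass):

  if (auto op = ParseOperator(GetRef<Stmt>(stmt_node), vmap)) {
    Stmt lowered = op->Lower(lower_args, &analyzer);
  }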
@@ -39,21 +39,22 @@ using namespace tir;
namespace attr {
/*! \brief Mark that how the loop is vectorized. */
constexpr const char *coalesced_width = "coalesced_width";
}
} // namespace attr
class IfBufferRemapLoopGenerator : public StmtExprMutator {
public:
public:
static For run(Stmt stmt, Map<Buffer, Buffer> buffer_remap,
Map<Buffer, Layout> layout_map) {
IfBufferRemapLoopGenerator generator(buffer_remap, layout_map);
return Downcast<For>(generator(std::move(stmt)));
}
private:
IfBufferRemapLoopGenerator(Map<Buffer, Buffer> buffer_remap, Map<Buffer, Layout> layout_map)
private:
IfBufferRemapLoopGenerator(Map<Buffer, Buffer> buffer_remap,
Map<Buffer, Layout> layout_map)
: buffer_remap_(buffer_remap), layout_map_(layout_map) {}
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
if (buffer_remap_.count(load->buffer)) {
@@ -65,7 +66,7 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {
return load;
}
Stmt VisitStmt_(const BufferStoreNode* op) final {
Stmt VisitStmt_(const BufferStoreNode *op) final {
auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
if (buffer_remap_.count(store->buffer)) {
auto new_indices = layout_map_[store->buffer]->Forward(store->indices);
@@ -79,18 +80,20 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {
Map<Buffer, Layout> layout_map_;
};
void ParallelLoopNestVisitor::VisitStmt_(const ForNode* op) {
void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
ICHECK(op->kind == ForKind::kParallel);
p->loop_vars_.push_back(IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar));
p->loop_vars_.push_back(
IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar));
p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
StmtExprVisitor::VisitStmt_(op);
}
void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode* op) {
void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
if (op->buffer.scope() == "local.fragment") {
if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
<< op->buffer << ": " << op->indices << " and " << p->indice_map_.at(op->buffer);
<< op->buffer << ": " << op->indices << " and "
<< p->indice_map_.at(op->buffer);
} else {
p->indice_map_.Set(op->buffer, op->indices);
}
@@ -99,11 +102,12 @@ void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode* op) {
StmtExprVisitor::VisitStmt_(op);
}
void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode* op) {
void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
if (op->buffer.scope() == "local.fragment") {
if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
<< op->buffer << ": " << op->indices << " and " << p->indice_map_.at(op->buffer);
<< op->buffer << ": " << op->indices << " and "
<< p->indice_map_.at(op->buffer);
} else {
p->indice_map_.Set(op->buffer, op->indices);
}
@@ -113,18 +117,20 @@ void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode* op) {
ParallelOp::ParallelOp(For root) : root_(root), V(this) { V.VisitStmt(root); }
bool ParallelOp::IsCommonAccessIndice(const Buffer& buffer) const {
auto common_indice = loop_vars_.Map([](const auto& iv) { return iv->var; });
bool ParallelOp::IsCommonAccessIndice(const Buffer &buffer) const {
auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; });
return StructuralEqual()(indice_map_[buffer], common_indice);
}
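For example, with a parallel nest over (i, j):

  // indice_map_[B] = {i, j}           -> common access indice (true)
  // indice_map_[B] = {j, i} or {i, 0} -> not common (false)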
LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
if (loop_layout_.defined()) return {};
if (level == InferLevel::kStrict) return {};
LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
if (loop_layout_.defined())
return {};
if (level == InferLevel::kStrict)
return {};
// Step 1: try to infer loop's partition from a source fragment
Buffer source_buffer, read_source_buffer;
for (const auto& [buffer, _] : indice_map_) {
for (const auto &[buffer, _] : indice_map_) {
if (T.layout_map.count(buffer)) {
auto frag = T.layout_map[buffer].as<Fragment>().value();
if (buffer_is_write_.count(buffer))
@@ -133,14 +139,16 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
read_source_buffer = buffer;
}
}
auto compute_loop_layout_from_buffer = [&](const Buffer& buffer) {
auto compute_loop_layout_from_buffer = [&](const Buffer &buffer) {
Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
if (IsCommonAccessIndice(buffer)) {
return src_layout;
} else {
Var rep;
auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep, IterVarType::kDataPar);
PrimExpr loop_var_to_thread = src_layout->ForwardThread(indice_map_[buffer], rep);
auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep,
IterVarType::kDataPar);
PrimExpr loop_var_to_thread =
src_layout->ForwardThread(indice_map_[buffer], rep);
return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter);
}
};
@@ -150,12 +158,14 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
if (read_source_buffer.defined()) {
loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
// Loop don't need to be replicated.
if (!is_one(loop_layout_->ReplicateExtent())) loop_layout_ = loop_layout_->DeReplicate();
if (!is_one(loop_layout_->ReplicateExtent()))
loop_layout_ = loop_layout_->DeReplicate();
// if still has replication, add a condition
if (!is_one(loop_layout_->ReplicateExtent())) {
auto inv = loop_layout_->Inverse();
Array<PrimExpr> fwd;
for (size_t i = 0; i < loop_layout_->OutputDim(); i++) fwd.push_back(0);
for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
fwd.push_back(0);
fwd.push_back(InputPlaceholder(0));
auto rep = inv->Forward(fwd).back();
AddPredicate(EQ(rep, 0));
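The fwd vector assembled above is zero in every spatial output slot with the thread placeholder appended last, e.g. fwd = {0, 0, InputPlaceholder(0)} for OutputDim() == 2, so inv->Forward(fwd).back() recovers the replica index a thread holds; predicating on rep == 0 then makes exactly one replica execute the loop body.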
@@ -163,17 +173,19 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
} else {
// Vectorize Size must be aware of the buffer_remap
// As the pass will do post processing to the layout
auto maybe_remapped_root_ = IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
auto maybe_remapped_root_ =
IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
int vector_size = GetVectorizeSize(maybe_remapped_root_);
// Check if coalesced_width is defined
if (auto coalesced_width = root_->annotations.Get(tl::attr::coalesced_width)) {
if (const auto* imm = coalesced_width.as<IntImmNode>()) {
if (auto coalesced_width =
root_->annotations.Get(tl::attr::coalesced_width)) {
if (const auto *imm = coalesced_width.as<IntImmNode>()) {
int expected = imm->value;
// Verify that vector_size is divisible by expected
if (vector_size % expected != 0) {
LOG(FATAL) << "Vector size " << vector_size << " is not divisible by coalesced width "
<< expected;
LOG(FATAL) << "Vector size " << vector_size
<< " is not divisible by coalesced width " << expected;
}
vector_size = expected;
} else {
@@ -184,31 +196,37 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
loop_layout_ = PlanLoopPartition(root_, T.block_size, vector_size);
}
PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();
if (!analyzer_.CanProveEqual(loop_thread_extent, static_cast<int>(T.block_size)))
if (!analyzer_.CanProveEqual(loop_thread_extent,
static_cast<int>(T.block_size)))
AddPredicate(LT(InputPlaceholder(0), loop_thread_extent));
} else {
return {};
}
// Step 2: Check that the loop's partition can correctly align with all source fragment
for (const auto& [buffer, _] : indice_map_) {
// Step 2: Check that the loop's partition can correctly align with all source
// fragment
for (const auto &[buffer, _] : indice_map_) {
if (T.layout_map.count(buffer)) {
auto fragment = T.layout_map[buffer].as<Fragment>().value();
// TODO: Add thread checks for replicated cases
// need to wildcard match the rhs with lhs
if (!is_one(loop_layout_->ReplicateExtent()) || !is_one(fragment->ReplicateExtent()))
if (!is_one(loop_layout_->ReplicateExtent()) ||
!is_one(fragment->ReplicateExtent()))
continue;
auto vars = loop_vars_.Map([](const IterVar& iv) { return PrimExpr(iv->var); });
auto vars =
loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
auto lhs = loop_layout_->ForwardThread(vars, NullOpt);
auto rhs = fragment->ForwardThread(indice_map_[buffer], NullOpt);
auto diff = analyzer_.Simplify(lhs - rhs);
ICHECK(is_zero(diff)) << "Layout infer conflict for " << buffer << " " << source_buffer
ICHECK(is_zero(diff))
<< "Layout infer conflict for " << buffer << " " << source_buffer
<< "\nLHS = " << lhs << "\nRHS = " << rhs;
}
}
// Step 3: Infer other fragment's layout from the loop's partition
LayoutMap results;
for (const auto& [buffer, _] : indice_map_) {
if (!T.layout_map.count(buffer)) results.Set(buffer, CompleteBufferFragment(buffer));
for (const auto &[buffer, _] : indice_map_) {
if (!T.layout_map.count(buffer))
results.Set(buffer, CompleteBufferFragment(buffer));
}
return results;
}
@@ -221,18 +239,20 @@ Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const {
}
}
Fragment ParallelOp::CompleteBufferFragment(const Buffer& buffer) {
Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) {
ICHECK(loop_layout_.defined());
if (IsCommonAccessIndice(buffer)) return loop_layout_;
if (IsCommonAccessIndice(buffer))
return loop_layout_;
PrimExpr rep_b =
MakeFlattenedExpression(DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
PrimExpr rep_b = MakeFlattenedExpression(
DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
auto bijective_indice = indice_map_[buffer];
bijective_indice.push_back(rep_b);
Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();
PrimExpr indice_rep_extent = ind_inv->InputShape().back(); // this is the size of rep_b
PrimExpr indice_rep_extent =
ind_inv->InputShape().back(); // this is the size of rep_b
PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
@@ -242,7 +262,8 @@ Fragment ParallelOp::CompleteBufferFragment(const Buffer& buffer) {
}
fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
PrimExpr thd_b = loop_layout_->ForwardThread(
ind_inv->Forward(fwd), FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
ind_inv->Forward(fwd),
FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent, NullOpt)
->CondenseReplicateVar();
@@ -23,30 +23,30 @@ using namespace tir;
class ParallelOp;
class ParallelLoopNestVisitor : public StmtExprVisitor {
private:
ParallelLoopNestVisitor(ParallelOp* op) : p(op){};
void VisitStmt_(const ForNode* op) final;
void VisitStmt_(const BufferStoreNode* op) final;
void VisitExpr_(const BufferLoadNode* op) final;
private:
ParallelLoopNestVisitor(ParallelOp *op) : p(op){};
void VisitStmt_(const ForNode *op) final;
void VisitStmt_(const BufferStoreNode *op) final;
void VisitExpr_(const BufferLoadNode *op) final;
ParallelOp* p;
ParallelOp *p;
friend class ParallelOp;
};
class ParallelOp : public Operator {
public:
public:
ParallelOp(For root);
LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level) final;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
Fragment GetLoopLayout() const { return loop_layout_; }
For GetRoot() const { return root_; }
Map<Buffer, Array<PrimExpr>> GetIndiceMap() const { return indice_map_; }
Optional<PrimExpr> GetPredicate(Var thread_var) const;
private:
Fragment CompleteBufferFragment(const Buffer& buffer);
bool IsCommonAccessIndice(const Buffer& buffer) const;
private:
Fragment CompleteBufferFragment(const Buffer &buffer);
bool IsCommonAccessIndice(const Buffer &buffer) const;
void AddPredicate(PrimExpr expr) {
predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
}
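Predicates accumulate by conjunction; e.g. after

  AddPredicate(EQ(rep, 0));
  AddPredicate(LT(thread, extent));

predicate_ holds And(LT(thread, extent), EQ(rep, 0)).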
@@ -54,7 +54,7 @@ PrimExpr ReduceOp::MakeInitValue() const {
}
}
PrimExpr ReduceOp::MakeReduce(const PrimExpr& a, const PrimExpr& b) const {
PrimExpr ReduceOp::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
PrimExpr lhs = a, rhs = b;
if (lhs->dtype != rhs->dtype) {
rhs = Cast(lhs->dtype, rhs);
@@ -90,8 +90,9 @@ std::string ReduceOp::MakeCodegenReducer() const {
}
}
Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
ICHECK(this->src.scope() == "local.fragment" && this->dst.scope() == "local.fragment")
Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
ICHECK(this->src.scope() == "local.fragment" &&
this->dst.scope() == "local.fragment")
<< "Reduce for shared memory not implemented.";
auto src_buffer = T.buffer_remap[this->src];
auto dst_buffer = T.buffer_remap[this->dst];
@@ -101,20 +102,24 @@ Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
Array<IterVar> dst_vars;
for (size_t i = 0; i < dst_layout->InputDim(); i++) {
Var var = Var(std::string{char('i' + i)});
dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var, IterVarType::kDataPar));
dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var,
IterVarType::kDataPar));
}
Array<IterVar> src_vars = dst_vars;
src_vars.insert(src_vars.begin() + this->dim, {Range(0, src_layout->InputShape()[this->dim]),
Var("rv"), IterVarType::kDataPar});
Array<PrimExpr> src_indices =
src_layout->Forward(src_vars.Map([](const auto& iv) { return PrimExpr(iv->var); }));
Array<PrimExpr> dst_indices =
dst_layout->Forward(dst_vars.Map([](const auto& iv) { return PrimExpr(iv->var); }));
src_vars.insert(src_vars.begin() + this->dim,
{Range(0, src_layout->InputShape()[this->dim]), Var("rv"),
IterVarType::kDataPar});
Array<PrimExpr> src_indices = src_layout->Forward(
src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
Array<PrimExpr> dst_indices = dst_layout->Forward(
dst_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
Array<Stmt> stmts;
// make reduce-init stmt
if (this->clear) stmts.push_back(BufferStore(dst_buffer, this->MakeInitValue(), dst_indices));
if (this->clear)
stmts.push_back(
BufferStore(dst_buffer, this->MakeInitValue(), dst_indices));
// make thread-local reduce
Array<PrimExpr> src_indice_compressed;
@@ -122,45 +127,50 @@ Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
for (size_t i = 0; i < src_layout->OutputDim(); i++) {
PrimExpr expr;
IterVar var;
std::tie(expr, var) =
CompressIterator(src_indices[i], src_vars, src_vars[this->dim]->var, analyzer);
std::tie(expr, var) = CompressIterator(src_indices[i], src_vars,
src_vars[this->dim]->var, analyzer);
src_indice_compressed.push_back(expr);
src_var_compressed.push_back(var);
}
Stmt reduce_local = BufferStore(dst_buffer,
Stmt reduce_local = BufferStore(
dst_buffer,
this->MakeReduce(BufferLoad(dst_buffer, dst_indices),
BufferLoad(src_buffer, src_indice_compressed)),
dst_indices);
for (int i = src_layout->OutputDim() - 1; i >= 0; i--) {
reduce_local =
For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent, ForKind::kUnrolled,
reduce_local, NullOpt, {{tir::attr::pragma_unroll_explicit, Bool(false)}});
For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent,
ForKind::kUnrolled, reduce_local, NullOpt,
{{tir::attr::pragma_unroll_explicit, Bool(false)}});
}
stmts.push_back(reduce_local);
// make inter-thread reduce
PrimExpr src_thread =
src_layout->ForwardThread(src_vars.Map([](const auto& iv) { return PrimExpr(iv->var); }), {});
auto iter_sum = arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer);
for (const auto& iter_split : iter_sum->args) {
PrimExpr src_thread = src_layout->ForwardThread(
src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }), {});
auto iter_sum =
arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer);
for (const auto &iter_split : iter_sum->args) {
auto mark = iter_split->source->source.as<Var>();
ICHECK(mark.defined());
if (mark.value().same_as(src_vars[this->dim]->var)) {
auto scale = as_const_int(iter_split->scale);
auto extent = as_const_int(iter_split->extent);
ICHECK(scale != nullptr && extent != nullptr);
if (*extent == 1) continue;
if (*extent == 1)
continue;
int reducing_threads = (*extent) * (*scale);
std::stringstream ss;
ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", " << reducing_threads << ", "
<< (*scale) << ">::run";
Array<PrimExpr> thread_reduce_args = {StringImm(ss.str()),
BufferLoad(dst_buffer, dst_indices)};
ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
<< reducing_threads << ", " << (*scale) << ">::run";
Array<PrimExpr> thread_reduce_args = {
StringImm(ss.str()), BufferLoad(dst_buffer, dst_indices)};
if (reducing_threads >= 32) {
PrimExpr workspace = T.AddWorkspace(T.block_size, dst_buffer->dtype);
thread_reduce_args.push_back(workspace);
}
auto call = Call(dst_buffer->dtype, builtin::call_extern(), thread_reduce_args);
auto call =
Call(dst_buffer->dtype, builtin::call_extern(), thread_reduce_args);
stmts.push_back(BufferStore(dst_buffer, call, dst_indices));
}
}
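For a sum reduction across 32 contiguous threads (scale 1), the extern call assembled here would read roughly

  dst[i] = tl::AllReduce<tl::SumOp, 32, 1>::run(dst[i], workspace);

where the reducer name is whatever MakeCodegenReducer() returns (tl::SumOp is assumed for this sketch) and workspace is only appended once 32 or more threads participate.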
@@ -170,15 +180,17 @@ Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
// make the outer spatial loop
Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0];
for (int i = dst_layout->InputDim() - 1; i >= 0; i--) {
body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent, ForKind::kParallel, body);
body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent,
ForKind::kParallel, body);
}
body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer, dst_layout);
return body;
}
LayoutMap ReduceOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
if (level >= InferLevel::kStrict) return {};
LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
if (level >= InferLevel::kStrict)
return {};
if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
T.layout_map.count(src) && !T.layout_map.count(dst)) {
auto src_layout = T.layout_map[src].as<Fragment>().value();
@@ -197,10 +209,11 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
fwd.push_back(InputPlaceholder(i - 1));
}
}
auto thd =
src_layout->ForwardThread(fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
auto thd = src_layout->ForwardThread(
fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
Fragment dst_layout =
Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, NullOpt)->CondenseReplicateVar();
Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, NullOpt)
->CondenseReplicateVar();
return {{dst, dst_layout}};
}
return {};
@@ -208,7 +221,8 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
TIR_REGISTER_TL_OP(ReduceOp, reduce)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm
\ No newline at end of file
@@ -18,13 +18,13 @@ namespace tl {
using namespace tir;
class ReduceOp : public Operator {
public:
public:
ReduceOp(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level) final;
static const Op& Get();
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
static const Op &Get();
private:
private:
tir::Buffer src, dst;
int dim;
enum class ReduceType {
@@ -36,7 +36,7 @@ class ReduceOp : public Operator {
bool clear;
PrimExpr MakeInitValue() const;
PrimExpr MakeReduce(const PrimExpr& a, const PrimExpr& b) const;
PrimExpr MakeReduce(const PrimExpr &a, const PrimExpr &b) const;
std::string MakeCodegenReducer() const;
};
@@ -17,12 +17,12 @@ namespace tl {
using namespace runtime;
template <typename T>
static std::string ArrayToStr(const T* ptr, size_t n) {
template <typename T> static std::string ArrayToStr(const T *ptr, size_t n) {
std::stringstream ss;
ss << "[";
for (size_t i = 0; i < n; i++) {
if (i > 0) ss << ", ";
if (i > 0)
ss << ", ";
ss << ptr[i];
}
ss << "]";
@@ -30,10 +30,10 @@ static std::string ArrayToStr(const T* ptr, size_t n) {
}
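Usage is as expected; e.g. ArrayToStr(dims, 3) with dims = {4, 8, 16} yields "[4, 8, 16]". The ToDebugString() printers below rely on it.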
struct TensorMapArgs {
CUtensorMap* map;
CUtensorMap *map;
CUtensorMapDataType type;
cuuint32_t tensorRank;
void* globalAddress;
void *globalAddress;
cuuint64_t globalDim[5], globalStride[5];
cuuint32_t boxDim[5], elementStrides[5];
CUtensorMapInterleave interleave;
@@ -45,8 +45,9 @@ struct TensorMapArgs {
TensorMapArgs T;
int idx = 0;
ICHECK(args.num_args >= 8);
T.map = reinterpret_cast<CUtensorMap*>(static_cast<void*>(args[idx++]));
T.type = static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++]));
T.map = reinterpret_cast<CUtensorMap *>(static_cast<void *>(args[idx++]));
T.type =
static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++]));
T.tensorRank = static_cast<cuuint32_t>(static_cast<int64_t>(args[idx++]));
T.globalAddress = args[idx++];
ICHECK(T.tensorRank >= 1 && T.tensorRank <= 5);
@@ -63,10 +64,14 @@ struct TensorMapArgs {
for (size_t i = 0; i < T.tensorRank; i++) {
T.elementStrides[i] = static_cast<cuuint64_t>(args[idx++]);
}
T.interleave = static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++]));
T.swizzle = static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
T.l2Promotion = static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++]));
T.oobFill = static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++]));
T.interleave =
static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++]));
T.swizzle =
static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
T.l2Promotion =
static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++]));
T.oobFill =
static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++]));
return T;
}
@@ -79,7 +84,8 @@ struct TensorMapArgs {
<< "globalDim " << ArrayToStr(globalDim, tensorRank) << std::endl
<< "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl
<< "boxDim " << ArrayToStr(boxDim, tensorRank) << std::endl
<< "elementStrides " << ArrayToStr(elementStrides, tensorRank) << std::endl
<< "elementStrides " << ArrayToStr(elementStrides, tensorRank)
<< std::endl
<< "interleave " << interleave << std::endl
<< "swizzle " << swizzle << std::endl
<< "l2Promotion " << l2Promotion << std::endl
@@ -89,23 +95,26 @@ struct TensorMapArgs {
};
// set device api
TVM_REGISTER_GLOBAL(tvm_tensormap_create_tiled).set_body([](TVMArgs args, TVMRetValue* ret) {
TVM_REGISTER_GLOBAL(tvm_tensormap_create_tiled)
.set_body([](TVMArgs args, TVMRetValue *ret) {
TensorMapArgs T = TensorMapArgs::Extract(args);
CUresult result = cuTensorMapEncodeTiled(
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim, T.globalStride + 1, T.boxDim,
T.elementStrides, T.interleave, T.swizzle, T.l2Promotion, T.oobFill);
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave,
T.swizzle, T.l2Promotion, T.oobFill);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to initialize the TMA descriptor " << result << std::endl
LOG_FATAL << "Failed to initialize the TMA descriptor " << result
<< std::endl
<< T.ToDebugString();
}
*ret = static_cast<int>(result);
});
});
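Note the T.globalStride + 1 in the call: cuTensorMapEncodeTiled takes only tensorRank - 1 stride entries, the innermost stride being implicitly the element size, so the first slot of the local globalStride array is skipped.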
struct TensorMapIm2ColArgs {
CUtensorMap* map;
CUtensorMap *map;
CUtensorMapDataType type;
cuuint32_t tensorRank;
void* globalAddress;
void *globalAddress;
cuuint64_t globalDim[5], globalStride[5];
cuuint32_t elementStrides[5];
int pixelBoxLowerCorner[3], pixelBoxUpperCorner[3];
@@ -119,8 +128,9 @@ struct TensorMapIm2ColArgs {
TensorMapIm2ColArgs T;
int idx = 0;
ICHECK(args.num_args >= 8);
T.map = reinterpret_cast<CUtensorMap*>(static_cast<void*>(args[idx++]));
T.type = static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++]));
T.map = reinterpret_cast<CUtensorMap *>(static_cast<void *>(args[idx++]));
T.type =
static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++]));
T.tensorRank = static_cast<cuuint32_t>(static_cast<int64_t>(args[idx++]));
T.globalAddress = args[idx++];
ICHECK(T.tensorRank >= 3 && T.tensorRank <= 5);
@@ -142,10 +152,14 @@ struct TensorMapIm2ColArgs {
}
T.smem_box_pixel = static_cast<cuuint64_t>(args[idx++]);
T.smem_box_channel = static_cast<cuuint64_t>(args[idx++]);
T.interleave = static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++]));
T.swizzle = static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
T.l2Promotion = static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++]));
T.oobFill = static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++]));
T.interleave =
static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++]));
T.swizzle =
static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
T.l2Promotion =
static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++]));
T.oobFill =
static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++]));
return T;
}
@@ -159,9 +173,12 @@ struct TensorMapIm2ColArgs {
<< "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl
<< "smem_box_pixel " << smem_box_pixel << std::endl
<< "smem_box_channel " << smem_box_channel << std::endl
<< "pixelBoxLowerCorner " << ArrayToStr(pixelBoxLowerCorner, tensorRank - 2) << std::endl
<< "pixelBoxUpperCorner " << ArrayToStr(pixelBoxUpperCorner, tensorRank - 2) << std::endl
<< "elementStrides " << ArrayToStr(elementStrides, tensorRank) << std::endl
<< "pixelBoxLowerCorner "
<< ArrayToStr(pixelBoxLowerCorner, tensorRank - 2) << std::endl
<< "pixelBoxUpperCorner "
<< ArrayToStr(pixelBoxUpperCorner, tensorRank - 2) << std::endl
<< "elementStrides " << ArrayToStr(elementStrides, tensorRank)
<< std::endl
<< "interleave " << interleave << std::endl
<< "swizzle " << swizzle << std::endl
<< "l2Promotion " << l2Promotion << std::endl
@@ -170,18 +187,21 @@ struct TensorMapIm2ColArgs {
}
};
TVM_REGISTER_GLOBAL(tvm_tensormap_create_im2col).set_body([](TVMArgs args, TVMRetValue* ret) {
TVM_REGISTER_GLOBAL(tvm_tensormap_create_im2col)
.set_body([](TVMArgs args, TVMRetValue *ret) {
TensorMapIm2ColArgs T = TensorMapIm2ColArgs::Extract(args);
CUresult result = cuTensorMapEncodeIm2col(
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim, T.globalStride + 1,
T.pixelBoxLowerCorner, T.pixelBoxUpperCorner, T.smem_box_channel, T.smem_box_pixel,
T.elementStrides, T.interleave, T.swizzle, T.l2Promotion, T.oobFill);
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
T.globalStride + 1, T.pixelBoxLowerCorner, T.pixelBoxUpperCorner,
T.smem_box_channel, T.smem_box_pixel, T.elementStrides, T.interleave,
T.swizzle, T.l2Promotion, T.oobFill);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to initialize the TMA descriptor " << result << std::endl
LOG_FATAL << "Failed to initialize the TMA descriptor " << result
<< std::endl
<< T.ToDebugString();
}
*ret = static_cast<int>(result);
});
});
} // namespace tl
} // namespace tvm
@@ -13,8 +13,10 @@
namespace tvm {
namespace tl {
constexpr const char* tvm_tensormap_create_tiled = "__tvm_tensormap_create_tiled";
constexpr const char* tvm_tensormap_create_im2col = "__tvm_tensormap_create_im2col";
constexpr const char *tvm_tensormap_create_tiled =
"__tvm_tensormap_create_tiled";
constexpr const char *tvm_tensormap_create_im2col =
"__tvm_tensormap_create_im2col";
} // namespace tl
} // namespace tvm
@@ -21,50 +21,58 @@ namespace tvm {
namespace codegen {
class CodeGenTileLangCUDA final : public CodeGenC {
public:
public:
CodeGenTileLangCUDA();
std::string Finish();
// override behavior
void PrintFuncPrefix(std::ostream& os) final;
void PrintExtraAttrs(const PrimFunc& f, std::ostream& os) final;
void VisitStmt_(const ForNode* op) final;
void PrintStorageSync(const CallNode* op) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
std::ostream& os) final; // NOLINT(*)
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void PrintVecElemLoad(const std::string& vec, DataType t, int i,
std::ostream& os) final; // NOLINT(*)
void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final;
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) final;
std::string CastFromTo(std::string value, DataType from, DataType target) final;
void PrintFuncPrefix(std::ostream &os) final;
void PrintExtraAttrs(const PrimFunc &f, std::ostream &os) final;
void VisitStmt_(const ForNode *op) final;
void PrintStorageSync(const CallNode *op) final;
void PrintStorageScope(const std::string &scope,
std::ostream &os) final; // NOLINT(*)
void PrintVecBinaryOp(const std::string &op, DataType t, PrimExpr lhs,
PrimExpr rhs,
std::ostream &os) final; // NOLINT(*)
void PrintType(DataType t, std::ostream &os) final; // NOLINT(*)
void PrintVecElemLoad(const std::string &vec, DataType t, int i,
std::ostream &os) final; // NOLINT(*)
void PrintVecElemStore(const std::string &vec, DataType t, int i,
const std::string &value) final;
void BindThreadIndex(const IterVar &iv) final; // NOLINT(*)
void PrintVecElemLoadExpr(DataType t, int i, const std::string &value,
std::ostream &os) final;
std::string CastFromTo(std::string value, DataType from,
DataType target) final;
// overload visitor
void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;
void VisitExpr_(const CallNode* op, std::ostream& os) final;
void VisitExpr_(const CastNode* op, std::ostream& os) final;
void VisitStmt_(const AllocateNode* op) final;
void VisitStmt_(const AttrStmtNode* op) final;
void VisitExpr_(const RampNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode *op, std::ostream &os) final;
void VisitExpr_(const CallNode *op, std::ostream &os) final;
void VisitExpr_(const CastNode *op, std::ostream &os) final;
void VisitStmt_(const AllocateNode *op) final;
void VisitStmt_(const AttrStmtNode *op) final;
// Override this as a work around for __grid_constant__ parameter
void AddFunction(const PrimFunc& f);
void AddFunction(const PrimFunc &f);
protected:
virtual std::string GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) final;
void PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
bool skip_first_arg, std::ostream& os) final; // NOLINT(*)
protected:
virtual std::string GetBufferRef(DataType t, const BufferNode *buffer,
PrimExpr index) final;
void PrintCallExtern(Type ret_type, String global_symbol,
const Array<PrimExpr> &args, bool skip_first_arg,
std::ostream &os) final; // NOLINT(*)
private:
private:
// Handle volatile loads
void HandleVolatileLoads(const std::string& value, const BufferLoadNode* op,
std::ostream& os) final;
void HandleVolatileLoads(const std::string &value, const BufferLoadNode *op,
std::ostream &os) final;
// Whether scope such as "__shared__" or "__constant__" is part of type.
bool IsScopePartOfType() const final { return false; }
friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLangCUDA* p);
friend void PrintConst(const FloatImmNode *op, std::ostream &os,
CodeGenTileLangCUDA *p);
// The size of the barrier array in shared memory
int barrier_count_ = -1;
// whether need mma.h
@@ -77,12 +85,14 @@ class CodeGenTileLangCUDA final : public CodeGenC {
// Set to 16 to maintain minimum alignment requirements for async bulk copy
const int barrier_alignment_bytes_ = 16;
std::unordered_map<const VarNode*, std::string> fragment_shapes;
std::unordered_map<const VarNode*, std::string> fragment_layouts;
friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLangCUDA* p);
void PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable,
std::ostream& os);
int32_t GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, int32_t size);
std::unordered_map<const VarNode *, std::string> fragment_shapes;
std::unordered_map<const VarNode *, std::string> fragment_layouts;
friend void PrintConst(const FloatImmNode *op, std::ostream &os,
CodeGenTileLangCUDA *p);
void PrintWmmaScope(const std::string &scope, DataType t,
const VarNode *variable, std::ostream &os);
int32_t GetWmmaFragmentSize(const std::string &scope, const VarNode *variable,
int32_t size);
};
} // namespace codegen
@@ -21,50 +21,58 @@ namespace tvm {
namespace codegen {
class CodeGenTileLangHIP final : public CodeGenC {
public:
public:
CodeGenTileLangHIP();
std::string Finish();
// override behavior
void PrintFuncPrefix(std::ostream& os) final;
void PrintExtraAttrs(const PrimFunc& f, std::ostream& os) final;
void VisitStmt_(const ForNode* op) final;
void PrintStorageSync(const CallNode* op) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
std::ostream& os) final; // NOLINT(*)
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void PrintVecElemLoad(const std::string& vec, DataType t, int i,
std::ostream& os) final; // NOLINT(*)
void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final;
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) final;
std::string CastFromTo(std::string value, DataType from, DataType target) final;
void PrintFuncPrefix(std::ostream &os) final;
void PrintExtraAttrs(const PrimFunc &f, std::ostream &os) final;
void VisitStmt_(const ForNode *op) final;
void PrintStorageSync(const CallNode *op) final;
void PrintStorageScope(const std::string &scope,
std::ostream &os) final; // NOLINT(*)
void PrintVecBinaryOp(const std::string &op, DataType t, PrimExpr lhs,
PrimExpr rhs,
std::ostream &os) final; // NOLINT(*)
void PrintType(DataType t, std::ostream &os) final; // NOLINT(*)
void PrintVecElemLoad(const std::string &vec, DataType t, int i,
std::ostream &os) final; // NOLINT(*)
void PrintVecElemStore(const std::string &vec, DataType t, int i,
const std::string &value) final;
void BindThreadIndex(const IterVar &iv) final; // NOLINT(*)
void PrintVecElemLoadExpr(DataType t, int i, const std::string &value,
std::ostream &os) final;
std::string CastFromTo(std::string value, DataType from,
DataType target) final;
// overload visitor
void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;
void VisitExpr_(const CallNode* op, std::ostream& os) final;
void VisitExpr_(const CastNode* op, std::ostream& os) final;
void VisitStmt_(const AllocateNode* op) final;
void VisitStmt_(const AttrStmtNode* op) final;
void VisitExpr_(const RampNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode *op, std::ostream &os) final;
void VisitExpr_(const CallNode *op, std::ostream &os) final;
void VisitExpr_(const CastNode *op, std::ostream &os) final;
void VisitStmt_(const AllocateNode *op) final;
void VisitStmt_(const AttrStmtNode *op) final;
// Override this as a work around for __grid_constant__ parameter
void AddFunction(const PrimFunc& f);
void AddFunction(const PrimFunc &f);
protected:
virtual std::string GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) final;
void PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
bool skip_first_arg, std::ostream& os) final; // NOLINT(*)
protected:
virtual std::string GetBufferRef(DataType t, const BufferNode *buffer,
PrimExpr index) final;
void PrintCallExtern(Type ret_type, String global_symbol,
const Array<PrimExpr> &args, bool skip_first_arg,
std::ostream &os) final; // NOLINT(*)
private:
private:
// Handle volatile loads
void HandleVolatileLoads(const std::string& value, const BufferLoadNode* op,
std::ostream& os) final;
void HandleVolatileLoads(const std::string &value, const BufferLoadNode *op,
std::ostream &os) final;
// Whether scope such as "__shared__" or "__constant__" is part of type.
bool IsScopePartOfType() const final { return false; }
friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLangHIP* p);
friend void PrintConst(const FloatImmNode *op, std::ostream &os,
CodeGenTileLangHIP *p);
// whether need math_constants.h
bool need_math_constants_h_{false};
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include "runtime/cuda/cuda_module.h"
#include "codegen_cuda.h"
#include "runtime/cuda/cuda_module.h"
namespace tvm {
namespace codegen {
static std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(const IRModule& mod) {
static std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule &mod) {
std::unordered_map<std::string, runtime::FunctionInfo> fmap;
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<tir::PrimFuncNode>()) << "Can only lower IR Module with PrimFuncs";
ICHECK(kv.second->IsInstance<tir::PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs";
auto f = Downcast<tir::PrimFunc>(kv.second);
runtime::FunctionInfo info;
@@ -26,7 +28,7 @@ static std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
info.arg_types.push_back(f->params[i].dtype());
}
if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
for (const auto& tag : opt.value()) {
for (const auto &tag : opt.value()) {
info.launch_param_tags.push_back(tag);
}
}
@@ -43,7 +45,8 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
cg.Init(output_ssa);
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenTileLangCUDA: Can only take PrimFunc";
ICHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenTileLangCUDA: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
@@ -51,14 +54,15 @@ }
}
std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
if (const auto *f = Registry::Get("tvm_callback_cuda_postproc")) {
code = (*f)(code, target).operator std::string();
}
std::string fmt = "ptx";
std::string ptx;
if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) {
if (const auto *f = Registry::Get("tvm_callback_cuda_compile")) {
ptx = (*f)(code, target).operator std::string();
if (ptx[0] != '/') fmt = "cubin";
if (ptx[0] != '/')
fmt = "cubin";
} else {
ICHECK(0);
}
@@ -72,7 +76,8 @@ String BuildTLDebug(IRModule mod, Target target) {
cg.Init(output_ssa);
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenTileLangCUDA: Can only take PrimFunc";
ICHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenTileLangCUDA: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
@@ -80,14 +85,16 @@ }
}
std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
if (const auto *f = Registry::Get("tvm_callback_cuda_postproc")) {
code = (*f)(code, target).operator std::string();
}
return String(code);
}
TVM_REGISTER_GLOBAL("target.build.tilelang_cuda").set_body_typed(BuildTileLangCUDA);
TVM_REGISTER_GLOBAL("target.build.tl_debug_codegen").set_body_typed(BuildTLDebug);
TVM_REGISTER_GLOBAL("target.build.tilelang_cuda")
.set_body_typed(BuildTileLangCUDA);
TVM_REGISTER_GLOBAL("target.build.tl_debug_codegen")
.set_body_typed(BuildTLDebug);
} // namespace codegen
} // namespace tvm
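With these globals registered, a caller can reach the builders the same way this file already looks up tvm_callback_cuda_postproc; a hedged sketch (mod and target assumed in scope):

  if (const auto *f = Registry::Get("target.build.tl_debug_codegen")) {
    String cuda_source = (*f)(mod, target);
  }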
@@ -11,13 +11,17 @@
namespace tvm {
namespace tl {
bool TargetIsCuda(Target target) { return target->GetTargetDeviceType() == kDLCUDA; }
bool TargetIsRocm(Target target) { return target->GetTargetDeviceType() == kDLROCM; }
bool TargetIsCuda(Target target) {
return target->GetTargetDeviceType() == kDLCUDA;
}
bool TargetIsRocm(Target target) {
return target->GetTargetDeviceType() == kDLROCM;
}
int GetArchInt(Target target) {
auto s = target->GetAttr<String>("arch");
ICHECK(s.defined());
const char* arch_str = s.value().c_str();
const char *arch_str = s.value().c_str();
ICHECK_EQ(arch_str[0], 's');
ICHECK_EQ(arch_str[1], 'm');
ICHECK_EQ(arch_str[2], '_');
@@ -25,31 +29,36 @@ int GetArchInt(Target target) {
}
bool TargetIsVolta(Target target) {
if (!TargetIsCuda(target)) return false;
if (!TargetIsCuda(target))
return false;
int arch = GetArchInt(target);
return arch >= 70 && arch < 75;
}
bool TargetIsTuring(Target target) {
if (!TargetIsCuda(target)) return false;
if (!TargetIsCuda(target))
return false;
int arch = GetArchInt(target);
return arch >= 75 && arch < 80;
}
bool TargetIsAmpere(Target target) {
if (!TargetIsCuda(target)) return false;
if (!TargetIsCuda(target))
return false;
int arch = GetArchInt(target);
return arch >= 80 && arch < 90;
}
bool TargetIsHopper(Target target) {
if (!TargetIsCuda(target)) return false;
if (!TargetIsCuda(target))
return false;
int arch = GetArchInt(target);
return arch >= 90;
}
bool TargetIsCDNA(Target target) {
if (!TargetIsRocm(target)) return false;
if (!TargetIsRocm(target))
return false;
if (target->attrs.count("mcpu")) {
std::string mcpu = Downcast<String>(target->attrs.at("mcpu"));
// if mcpu start with "gfx9", it is CDNA
@@ -78,13 +87,15 @@ bool TargetHasAsyncCopy(Target target) {
return false;
}
bool TargetHasLdmatrix(Target target) {
if (!TargetIsCuda(target)) return false;
if (!TargetIsCuda(target))
return false;
int arch = GetArchInt(target);
return arch >= 75;
}
bool TargetHasStmatrix(Target target) {
if (!TargetIsCuda(target)) return false;
if (!TargetIsCuda(target))
return false;
int arch = GetArchInt(target);
return arch >= 90;
}
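Putting the predicates together: for a CUDA target with arch "sm_90", GetArchInt returns 90, so TargetIsHopper, TargetHasLdmatrix, and TargetHasStmatrix all hold while TargetIsAmpere does not; for "sm_75", TargetIsTuring and TargetHasLdmatrix hold but TargetHasStmatrix does not.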