Commit fa511857 authored by Lei Wang, committed by GitHub

[Lint] Overall Typo and Linting Fixes (#13)

* Fix README.md

* Update test CI

* Fix lint issues and typos

* Fix clang-format lint issues
parent be55163f
@@ -20,13 +20,14 @@ using namespace tir;

TIR_REGISTER_TL_OP(RegionOp, region)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure));

std::unique_ptr<Operator> ParseOperator(Call call, BufferMap vmap) {
  auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder");
  Op op = call->op.as<Op>().value();
  if (op_map.count(op)) {
    Operator *ptr = static_cast<Operator *>(op_map[op](call->args, vmap));
    ICHECK(ptr != nullptr);
    return std::unique_ptr<Operator>(ptr);
  }
@@ -41,7 +42,7 @@ std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap) {
  return nullptr;
}

Var GetVarFromAccessPtr(const PrimExpr &expr) {
  auto call = expr.as<CallNode>();
  ICHECK(call);
  ICHECK(call->op.same_as(builtin::tvm_access_ptr()));
@@ -67,20 +68,27 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {

bool RegionOp::IsFullRegion() const {
  for (size_t i = 0; i < ranges_.size(); i++) {
    if (!is_zero(ranges_[i]->min))
      return false;
    if (!StructuralEqual()(ranges_[i]->extent, buffer_->shape[i]))
      return false;
  }
  return true;
}

Stmt Operator::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  ICHECK(0) << "Not Implemented Lower method.";
  return Evaluate(0);
}

Stmt Operator::Canonialize(const CanonializeArgs &T,
                           arith::Analyzer *analyzer) const {
  return {};
}

LayoutMap Operator::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  return {};
}

} // namespace tl
} // namespace tvm
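To make the dispatch path above concrete, here is a minimal assumed-usage sketch; the wrapper function is hypothetical and not part of this commit:

// Assumed usage (hypothetical helper, not from this commit): build the
// Operator registered for a tl.* call via the "TLOpBuilder" attr map,
// then lower it through the virtual Operator interface.
void LowerTileOp(const Call &call, const BufferMap &vmap, const LowerArgs &T,
                 arith::Analyzer *analyzer) {
  std::unique_ptr<Operator> op = ParseOperator(call, vmap);
  if (op != nullptr) {
    Stmt lowered = op->Lower(T, analyzer); // per-op virtual dispatch
    (void)lowered; // a real caller would splice this into the output IR
  }
}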
@@ -25,17 +25,19 @@ using namespace tir;

using AddWorkspaceCallback = std::function<PrimExpr(int, DataType)>;
using LayoutMap = Map<Buffer, Layout>;
using BufferMap = Map<Var, Buffer>;
using OpBuilderFunc = TypedPackedFunc<void *(Array<PrimExpr>, BufferMap)>;

#define TIR_REGISTER_TL_OP(Entry, OpName)                                    \
  const Op &Entry::Get() {                                                   \
    static const Op &op = Op::Get("tl." #OpName);                            \
    return op;                                                               \
  }                                                                          \
  TVM_REGISTER_OP("tl." #OpName)                                             \
      .set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName)           \
      .set_attr<OpBuilderFunc>("TLOpBuilder",                                \
                               [](Array<PrimExpr> a, BufferMap b) {          \
                                 return (void *)(new Entry(a, b));           \
                               })

enum class InferLevel {
  kFree = 0,
@@ -64,30 +66,31 @@ struct CanonializeArgs {
};

class Operator {
public:
  virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
  virtual Stmt Canonialize(const CanonializeArgs &T,
                           arith::Analyzer *analyzer) const;
  virtual LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level);
  virtual ~Operator() = default;
};

class RegionOp : public Operator {
public:
  RegionOp(Array<PrimExpr> args, BufferMap vmap);
  static const Op &Get();
  const Buffer &GetBuffer() const { return buffer_; }
  const Array<Range> &GetRanges() const { return ranges_; }
  int GetAccessMask() const { return access_mask_; }
  bool IsFullRegion() const;

private:
  Buffer buffer_;
  Array<Range> ranges_;
  int access_mask_;
};

Var GetVarFromAccessPtr(const PrimExpr &expr);

std::unique_ptr<Operator> ParseOperator(Call call, BufferMap vmap);
std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap);
...
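For context on the macro's contract, a hypothetical registration is sketched below; FillOp and the tl.fill op name are invented for illustration and do not exist in this commit:

// Hypothetical example: any class with an (Array<PrimExpr>, BufferMap)
// constructor and a declared static Get() can be registered. The macro
// defines Get() and installs the "TLOpBuilder" attr that ParseOperator
// uses to construct the op from a tl.fill call.
class FillOp : public Operator {
public:
  FillOp(Array<PrimExpr> args, BufferMap vmap);
  static const Op &Get();
};

TIR_REGISTER_TL_OP(FillOp, fill)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));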
@@ -39,21 +39,22 @@ using namespace tir;

namespace attr {
/*! \brief Marks how the loop is vectorized. */
constexpr const char *coalesced_width = "coalesced_width";
} // namespace attr

class IfBufferRemapLoopGenerator : public StmtExprMutator {
public:
  static For run(Stmt stmt, Map<Buffer, Buffer> buffer_remap,
                 Map<Buffer, Layout> layout_map) {
    IfBufferRemapLoopGenerator generator(buffer_remap, layout_map);
    return Downcast<For>(generator(std::move(stmt)));
  }

private:
  IfBufferRemapLoopGenerator(Map<Buffer, Buffer> buffer_remap,
                             Map<Buffer, Layout> layout_map)
      : buffer_remap_(buffer_remap), layout_map_(layout_map) {}

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    if (buffer_remap_.count(load->buffer)) {
@@ -65,7 +66,7 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {
    return load;
  }

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    if (buffer_remap_.count(store->buffer)) {
      auto new_indices = layout_map_[store->buffer]->Forward(store->indices);
@@ -79,18 +80,20 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {
  Map<Buffer, Layout> layout_map_;
};

void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
  ICHECK(op->kind == ForKind::kParallel);
  p->loop_vars_.push_back(
      IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar));
  p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
  StmtExprVisitor::VisitStmt_(op);
}

void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
  if (op->buffer.scope() == "local.fragment") {
    if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
      ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
          << op->buffer << ": " << op->indices << " and "
          << p->indice_map_.at(op->buffer);
    } else {
      p->indice_map_.Set(op->buffer, op->indices);
    }
@@ -99,11 +102,12 @@ void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
  StmtExprVisitor::VisitStmt_(op);
}

void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
  if (op->buffer.scope() == "local.fragment") {
    if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
      ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
          << op->buffer << ": " << op->indices << " and "
          << p->indice_map_.at(op->buffer);
    } else {
      p->indice_map_.Set(op->buffer, op->indices);
    }
@@ -113,18 +117,20 @@ void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {

ParallelOp::ParallelOp(For root) : root_(root), V(this) { V.VisitStmt(root); }

bool ParallelOp::IsCommonAccessIndice(const Buffer &buffer) const {
  auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; });
  return StructuralEqual()(indice_map_[buffer], common_indice);
}

LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (loop_layout_.defined())
    return {};
  if (level == InferLevel::kStrict)
    return {};
  // Step 1: try to infer the loop's partition from a source fragment
  Buffer source_buffer, read_source_buffer;
  for (const auto &[buffer, _] : indice_map_) {
    if (T.layout_map.count(buffer)) {
      auto frag = T.layout_map[buffer].as<Fragment>().value();
      if (buffer_is_write_.count(buffer))
@@ -133,14 +139,16 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
        read_source_buffer = buffer;
    }
  }
  auto compute_loop_layout_from_buffer = [&](const Buffer &buffer) {
    Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
    if (IsCommonAccessIndice(buffer)) {
      return src_layout;
    } else {
      Var rep;
      auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep,
                              IterVarType::kDataPar);
      PrimExpr loop_var_to_thread =
          src_layout->ForwardThread(indice_map_[buffer], rep);
      return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter);
    }
  };
@@ -150,12 +158,14 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (read_source_buffer.defined()) {
    loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
    // The loop does not need to be replicated.
    if (!is_one(loop_layout_->ReplicateExtent()))
      loop_layout_ = loop_layout_->DeReplicate();
    // If it still has replication, add a condition.
    if (!is_one(loop_layout_->ReplicateExtent())) {
      auto inv = loop_layout_->Inverse();
      Array<PrimExpr> fwd;
      for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
        fwd.push_back(0);
      fwd.push_back(InputPlaceholder(0));
      auto rep = inv->Forward(fwd).back();
      AddPredicate(EQ(rep, 0));
@@ -163,17 +173,19 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  } else {
    // The vectorize size must be aware of the buffer_remap,
    // as the pass will post-process the layout.
    auto maybe_remapped_root_ =
        IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
    int vector_size = GetVectorizeSize(maybe_remapped_root_);

    // Check if coalesced_width is defined.
    if (auto coalesced_width =
            root_->annotations.Get(tl::attr::coalesced_width)) {
      if (const auto *imm = coalesced_width.as<IntImmNode>()) {
        int expected = imm->value;
        // Verify that vector_size is divisible by the expected width.
        if (vector_size % expected != 0) {
          LOG(FATAL) << "Vector size " << vector_size
                     << " is not divisible by coalesced width " << expected;
        }
        vector_size = expected;
      } else {
@@ -184,31 +196,37 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
    loop_layout_ = PlanLoopPartition(root_, T.block_size, vector_size);
    }
    PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();
    if (!analyzer_.CanProveEqual(loop_thread_extent,
                                 static_cast<int>(T.block_size)))
      AddPredicate(LT(InputPlaceholder(0), loop_thread_extent));
  } else {
    return {};
  }
  // Step 2: check that the loop's partition correctly aligns with all
  // source fragments
  for (const auto &[buffer, _] : indice_map_) {
    if (T.layout_map.count(buffer)) {
      auto fragment = T.layout_map[buffer].as<Fragment>().value();
      // TODO: Add thread checks for replicated cases
      // need to wildcard match the rhs with lhs
      if (!is_one(loop_layout_->ReplicateExtent()) ||
          !is_one(fragment->ReplicateExtent()))
        continue;
      auto vars =
          loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
      auto lhs = loop_layout_->ForwardThread(vars, NullOpt);
      auto rhs = fragment->ForwardThread(indice_map_[buffer], NullOpt);
      auto diff = analyzer_.Simplify(lhs - rhs);
      ICHECK(is_zero(diff))
          << "Layout infer conflict for " << buffer << " " << source_buffer
          << "\nLHS = " << lhs << "\nRHS = " << rhs;
    }
  }
  // Step 3: infer the remaining fragments' layouts from the loop's partition
  LayoutMap results;
  for (const auto &[buffer, _] : indice_map_) {
    if (!T.layout_map.count(buffer))
      results.Set(buffer, CompleteBufferFragment(buffer));
  }
  return results;
}
@@ -221,18 +239,20 @@ Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const {
  }
}

Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) {
  ICHECK(loop_layout_.defined());
  if (IsCommonAccessIndice(buffer))
    return loop_layout_;
  PrimExpr rep_b = MakeFlattenedExpression(
      DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
  auto bijective_indice = indice_map_[buffer];
  bijective_indice.push_back(rep_b);
  Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();

  PrimExpr indice_rep_extent =
      ind_inv->InputShape().back(); // this is the size of rep_b
  PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
  PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
@@ -242,7 +262,8 @@ Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) {
  }
  fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
  PrimExpr thd_b = loop_layout_->ForwardThread(
      ind_inv->Forward(fwd),
      FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
  return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent, NullOpt)
      ->CondenseReplicateVar();
...
@@ -23,30 +23,30 @@ using namespace tir;

class ParallelOp;

class ParallelLoopNestVisitor : public StmtExprVisitor {
private:
  ParallelLoopNestVisitor(ParallelOp *op) : p(op){};
  void VisitStmt_(const ForNode *op) final;
  void VisitStmt_(const BufferStoreNode *op) final;
  void VisitExpr_(const BufferLoadNode *op) final;

  ParallelOp *p;

  friend class ParallelOp;
};

class ParallelOp : public Operator {
public:
  ParallelOp(For root);
  LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
  Fragment GetLoopLayout() const { return loop_layout_; }
  For GetRoot() const { return root_; }
  Map<Buffer, Array<PrimExpr>> GetIndiceMap() const { return indice_map_; }
  Optional<PrimExpr> GetPredicate(Var thread_var) const;

private:
  Fragment CompleteBufferFragment(const Buffer &buffer);
  bool IsCommonAccessIndice(const Buffer &buffer) const;
  void AddPredicate(PrimExpr expr) {
    predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
  }
...
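A rough sketch of how a lowering pass might drive this interface, using only the public methods declared above (the driver function itself is assumed, not part of this commit):

// Assumed driver (illustrative only): wrap a parallel loop nest, run
// layout inference, then fetch the accumulated thread predicate to
// guard the lowered body.
void InferAndGuard(For root, const LayoutInferArgs &T, Var thread_var) {
  ParallelOp op(root);
  LayoutMap inferred = op.InferLayout(T, InferLevel::kFree);
  (void)inferred; // the caller would merge this into its layout map
  if (Optional<PrimExpr> pred = op.GetPredicate(thread_var)) {
    // e.g. wrap the body: IfThenElse(pred.value(), lowered_body)
  }
}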
@@ -54,7 +54,7 @@ PrimExpr ReduceOp::MakeInitValue() const {
  }
}

PrimExpr ReduceOp::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
  PrimExpr lhs = a, rhs = b;
  if (lhs->dtype != rhs->dtype) {
    rhs = Cast(lhs->dtype, rhs);
@@ -90,8 +90,9 @@ std::string ReduceOp::MakeCodegenReducer() const {
  }
}

Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  ICHECK(this->src.scope() == "local.fragment" &&
         this->dst.scope() == "local.fragment")
      << "Reduce for shared memory not implemented.";
  auto src_buffer = T.buffer_remap[this->src];
  auto dst_buffer = T.buffer_remap[this->dst];
@@ -101,20 +102,24 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  Array<IterVar> dst_vars;
  for (size_t i = 0; i < dst_layout->InputDim(); i++) {
    Var var = Var(std::string{char('i' + i)});
    dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var,
                               IterVarType::kDataPar));
  }
  Array<IterVar> src_vars = dst_vars;
  src_vars.insert(src_vars.begin() + this->dim,
                  {Range(0, src_layout->InputShape()[this->dim]), Var("rv"),
                   IterVarType::kDataPar});
  Array<PrimExpr> src_indices = src_layout->Forward(
      src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
  Array<PrimExpr> dst_indices = dst_layout->Forward(
      dst_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));

  Array<Stmt> stmts;

  // make reduce-init stmt
  if (this->clear)
    stmts.push_back(
        BufferStore(dst_buffer, this->MakeInitValue(), dst_indices));

  // make thread-local reduce
  Array<PrimExpr> src_indice_compressed;
@@ -122,45 +127,50 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  for (size_t i = 0; i < src_layout->OutputDim(); i++) {
    PrimExpr expr;
    IterVar var;
    std::tie(expr, var) = CompressIterator(src_indices[i], src_vars,
                                           src_vars[this->dim]->var, analyzer);
    src_indice_compressed.push_back(expr);
    src_var_compressed.push_back(var);
  }
  Stmt reduce_local = BufferStore(
      dst_buffer,
      this->MakeReduce(BufferLoad(dst_buffer, dst_indices),
                       BufferLoad(src_buffer, src_indice_compressed)),
      dst_indices);
  for (int i = src_layout->OutputDim() - 1; i >= 0; i--) {
    reduce_local =
        For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent,
            ForKind::kUnrolled, reduce_local, NullOpt,
            {{tir::attr::pragma_unroll_explicit, Bool(false)}});
  }
  stmts.push_back(reduce_local);

  // make inter-thread reduce
  PrimExpr src_thread = src_layout->ForwardThread(
      src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }), {});
  auto iter_sum =
      arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer);
  for (const auto &iter_split : iter_sum->args) {
    auto mark = iter_split->source->source.as<Var>();
    ICHECK(mark.defined());
    if (mark.value().same_as(src_vars[this->dim]->var)) {
      auto scale = as_const_int(iter_split->scale);
      auto extent = as_const_int(iter_split->extent);
      ICHECK(scale != nullptr && extent != nullptr);
      if (*extent == 1)
        continue;

      int reducing_threads = (*extent) * (*scale);
      std::stringstream ss;
      ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
         << reducing_threads << ", " << (*scale) << ">::run";
      Array<PrimExpr> thread_reduce_args = {
          StringImm(ss.str()), BufferLoad(dst_buffer, dst_indices)};
      if (reducing_threads >= 32) {
        PrimExpr workspace = T.AddWorkspace(T.block_size, dst_buffer->dtype);
        thread_reduce_args.push_back(workspace);
      }
      auto call =
          Call(dst_buffer->dtype, builtin::call_extern(), thread_reduce_args);
      stmts.push_back(BufferStore(dst_buffer, call, dst_indices));
    }
  }
@@ -170,15 +180,17 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  // make the outer spatial loop
  Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0];
  for (int i = dst_layout->InputDim() - 1; i >= 0; i--) {
    body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent,
               ForKind::kParallel, body);
  }
  body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer, dst_layout);
  return body;
}
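The extern call emitted above names a device-side helper, tl::AllReduce<Reducer, Threads, Scale>::run, which lives in TileLang's CUDA headers rather than in this diff. As a rough intuition for what such a helper does, here is a minimal assumed sketch of a single-warp butterfly all-reduce; it is illustrative only, not the actual template:

// Assumed illustration (not this repo's implementation): a single-warp
// butterfly all-reduce using shuffle intrinsics. The real tl::AllReduce
// must also cover multi-warp cases, which is why a shared-memory
// workspace argument is appended above when reducing_threads >= 32.
template <int Threads> __device__ float AllReduceSum(float v) {
  static_assert(Threads <= 32, "sketch covers the single-warp case only");
#pragma unroll
  for (int offset = Threads / 2; offset > 0; offset /= 2)
    v += __shfl_xor_sync(0xffffffff, v, offset);
  return v; // every participating lane ends up with the full sum
}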
LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (level >= InferLevel::kStrict)
    return {};
  if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
      T.layout_map.count(src) && !T.layout_map.count(dst)) {
    auto src_layout = T.layout_map[src].as<Fragment>().value();
@@ -197,10 +209,11 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
        fwd.push_back(InputPlaceholder(i - 1));
      }
    }
    auto thd = src_layout->ForwardThread(
        fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
    Fragment dst_layout =
        Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, NullOpt)
            ->CondenseReplicateVar();
    return {{dst, dst_layout}};
  }
  return {};
@@ -208,7 +221,8 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {

TIR_REGISTER_TL_OP(ReduceOp, reduce)
    .set_num_inputs(4)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm
\ No newline at end of file
@@ -18,13 +18,13 @@ namespace tl {

using namespace tir;

class ReduceOp : public Operator {
public:
  ReduceOp(Array<PrimExpr> args, BufferMap vmap);
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
  LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
  static const Op &Get();

private:
  tir::Buffer src, dst;
  int dim;
  enum class ReduceType {
@@ -36,7 +36,7 @@ class ReduceOp : public Operator {
  bool clear;

  PrimExpr MakeInitValue() const;
  PrimExpr MakeReduce(const PrimExpr &a, const PrimExpr &b) const;
  std::string MakeCodegenReducer() const;
};
...
@@ -17,12 +17,12 @@ namespace tl {

using namespace runtime;

template <typename T> static std::string ArrayToStr(const T *ptr, size_t n) {
  std::stringstream ss;
  ss << "[";
  for (size_t i = 0; i < n; i++) {
    if (i > 0)
      ss << ", ";
    ss << ptr[i];
  }
  ss << "]";
@@ -30,10 +30,10 @@ static std::string ArrayToStr(const T *ptr, size_t n) {
}

struct TensorMapArgs {
  CUtensorMap *map;
  CUtensorMapDataType type;
  cuuint32_t tensorRank;
  void *globalAddress;
  cuuint64_t globalDim[5], globalStride[5];
  cuuint32_t boxDim[5], elementStrides[5];
  CUtensorMapInterleave interleave;
@@ -45,8 +45,9 @@ struct TensorMapArgs {
    TensorMapArgs T;
    int idx = 0;
    ICHECK(args.num_args >= 8);
    T.map = reinterpret_cast<CUtensorMap *>(static_cast<void *>(args[idx++]));
    T.type =
        static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++]));
    T.tensorRank = static_cast<cuuint32_t>(static_cast<int64_t>(args[idx++]));
    T.globalAddress = args[idx++];
    ICHECK(T.tensorRank >= 1 && T.tensorRank <= 5);
@@ -63,10 +64,14 @@ struct TensorMapArgs {
    for (size_t i = 0; i < T.tensorRank; i++) {
      T.elementStrides[i] = static_cast<cuuint64_t>(args[idx++]);
    }
    T.interleave =
        static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++]));
    T.swizzle =
        static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
    T.l2Promotion =
        static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++]));
    T.oobFill =
        static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++]));
    return T;
  }
@@ -79,7 +84,8 @@ struct TensorMapArgs {
       << "globalDim " << ArrayToStr(globalDim, tensorRank) << std::endl
       << "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl
       << "boxDim " << ArrayToStr(boxDim, tensorRank) << std::endl
       << "elementStrides " << ArrayToStr(elementStrides, tensorRank)
       << std::endl
       << "interleave " << interleave << std::endl
       << "swizzle " << swizzle << std::endl
       << "l2Promotion " << l2Promotion << std::endl
@@ -89,23 +95,26 @@
};

// set device api
TVM_REGISTER_GLOBAL(tvm_tensormap_create_tiled)
    .set_body([](TVMArgs args, TVMRetValue *ret) {
      TensorMapArgs T = TensorMapArgs::Extract(args);
      // cuTensorMapEncodeTiled takes only tensorRank - 1 global strides
      // (the innermost stride is implied by the element type), hence the
      // "+ 1" below.
      CUresult result = cuTensorMapEncodeTiled(
          T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
          T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave,
          T.swizzle, T.l2Promotion, T.oobFill);
      if (result != CUDA_SUCCESS) {
        LOG_FATAL << "Failed to initialize the TMA descriptor " << result
                  << std::endl
                  << T.ToDebugString();
      }
      *ret = static_cast<int>(result);
    });
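For a concrete picture of the driver call this packed function wraps, here is a standalone sketch that encodes a hypothetical 2D float32 tile descriptor; the shapes and names are invented, and it assumes the CUDA driver has already been initialized:

#include <cuda.h>

// Sketch only: describe a rows x cols row-major float32 tensor in global
// memory with a 64x64 box (the shared-memory tile shape). Only the outer
// dimension's stride is passed, matching the tensorRank - 1 convention
// noted above; the row stride must be a multiple of 16 bytes.
CUtensorMap DescribeTile(void *gmem, cuuint64_t rows, cuuint64_t cols) {
  CUtensorMap map;
  cuuint64_t global_dim[2] = {cols, rows}; // fastest-varying dim first
  cuuint64_t global_stride[1] = {cols * sizeof(float)};
  cuuint32_t box_dim[2] = {64, 64};
  cuuint32_t elem_stride[2] = {1, 1};
  CUresult res = cuTensorMapEncodeTiled(
      &map, CU_TENSOR_MAP_DATA_TYPE_FLOAT32, /*tensorRank=*/2, gmem,
      global_dim, global_stride, box_dim, elem_stride,
      CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_NONE,
      CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
  if (res != CUDA_SUCCESS) {
    // a real caller would report the error, as the registration above does
  }
  return map;
}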
struct TensorMapIm2ColArgs {
  CUtensorMap *map;
  CUtensorMapDataType type;
  cuuint32_t tensorRank;
  void *globalAddress;
  cuuint64_t globalDim[5], globalStride[5];
  cuuint32_t elementStrides[5];
  int pixelBoxLowerCorner[3], pixelBoxUpperCorner[3];
@@ -119,8 +128,9 @@ struct TensorMapIm2ColArgs {
    TensorMapIm2ColArgs T;
    int idx = 0;
    ICHECK(args.num_args >= 8);
    T.map = reinterpret_cast<CUtensorMap *>(static_cast<void *>(args[idx++]));
    T.type =
        static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++]));
    T.tensorRank = static_cast<cuuint32_t>(static_cast<int64_t>(args[idx++]));
    T.globalAddress = args[idx++];
    ICHECK(T.tensorRank >= 3 && T.tensorRank <= 5);
@@ -142,10 +152,14 @@ struct TensorMapIm2ColArgs {
    }
    T.smem_box_pixel = static_cast<cuuint64_t>(args[idx++]);
    T.smem_box_channel = static_cast<cuuint64_t>(args[idx++]);
    T.interleave =
        static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++]));
    T.swizzle =
        static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
    T.l2Promotion =
        static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++]));
    T.oobFill =
        static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++]));
    return T;
  }
@@ -159,9 +173,12 @@ struct TensorMapIm2ColArgs {
       << "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl
       << "smem_box_pixel " << smem_box_pixel << std::endl
       << "smem_box_channel " << smem_box_channel << std::endl
       << "pixelBoxLowerCorner "
       << ArrayToStr(pixelBoxLowerCorner, tensorRank - 2) << std::endl
       << "pixelBoxUpperCorner "
       << ArrayToStr(pixelBoxUpperCorner, tensorRank - 2) << std::endl
       << "elementStrides " << ArrayToStr(elementStrides, tensorRank)
       << std::endl
       << "interleave " << interleave << std::endl
       << "swizzle " << swizzle << std::endl
       << "l2Promotion " << l2Promotion << std::endl
@@ -170,18 +187,21 @@ struct TensorMapIm2ColArgs {
  }
};

TVM_REGISTER_GLOBAL(tvm_tensormap_create_im2col)
    .set_body([](TVMArgs args, TVMRetValue *ret) {
      TensorMapIm2ColArgs T = TensorMapIm2ColArgs::Extract(args);
      CUresult result = cuTensorMapEncodeIm2col(
          T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
          T.globalStride + 1, T.pixelBoxLowerCorner, T.pixelBoxUpperCorner,
          T.smem_box_channel, T.smem_box_pixel, T.elementStrides, T.interleave,
          T.swizzle, T.l2Promotion, T.oobFill);
      if (result != CUDA_SUCCESS) {
        LOG_FATAL << "Failed to initialize the TMA descriptor " << result
                  << std::endl
                  << T.ToDebugString();
      }
      *ret = static_cast<int>(result);
    });

} // namespace tl
} // namespace tvm
@@ -13,8 +13,10 @@
namespace tvm {
namespace tl {

constexpr const char *tvm_tensormap_create_tiled =
    "__tvm_tensormap_create_tiled";
constexpr const char *tvm_tensormap_create_im2col =
    "__tvm_tensormap_create_im2col";

} // namespace tl
} // namespace tvm
...
This diff is collapsed.
@@ -21,50 +21,58 @@ namespace tvm {
namespace codegen {

class CodeGenTileLangCUDA final : public CodeGenC {
public:
  CodeGenTileLangCUDA();
  std::string Finish();
  // override behavior
  void PrintFuncPrefix(std::ostream &os) final;
  void PrintExtraAttrs(const PrimFunc &f, std::ostream &os) final;
  void VisitStmt_(const ForNode *op) final;
  void PrintStorageSync(const CallNode *op) final;
  void PrintStorageScope(const std::string &scope,
                         std::ostream &os) final; // NOLINT(*)
  void PrintVecBinaryOp(const std::string &op, DataType t, PrimExpr lhs,
                        PrimExpr rhs,
                        std::ostream &os) final;      // NOLINT(*)
  void PrintType(DataType t, std::ostream &os) final; // NOLINT(*)
  void PrintVecElemLoad(const std::string &vec, DataType t, int i,
                        std::ostream &os) final; // NOLINT(*)
  void PrintVecElemStore(const std::string &vec, DataType t, int i,
                         const std::string &value) final;
  void BindThreadIndex(const IterVar &iv) final; // NOLINT(*)
  void PrintVecElemLoadExpr(DataType t, int i, const std::string &value,
                            std::ostream &os) final;
  std::string CastFromTo(std::string value, DataType from,
                         DataType target) final;
  // overload visitor
  void VisitExpr_(const RampNode *op, std::ostream &os) final; // NOLINT(*)
  void VisitExpr_(const BroadcastNode *op,
                  std::ostream &os) final; // NOLINT(*)
  void VisitExpr_(const FloatImmNode *op, std::ostream &os) final;
  void VisitExpr_(const CallNode *op, std::ostream &os) final;
  void VisitExpr_(const CastNode *op, std::ostream &os) final;
  void VisitStmt_(const AllocateNode *op) final;
  void VisitStmt_(const AttrStmtNode *op) final;

  // Override this as a workaround for the __grid_constant__ parameter
  void AddFunction(const PrimFunc &f);

protected:
  virtual std::string GetBufferRef(DataType t, const BufferNode *buffer,
                                   PrimExpr index) final;
  void PrintCallExtern(Type ret_type, String global_symbol,
                       const Array<PrimExpr> &args, bool skip_first_arg,
                       std::ostream &os) final; // NOLINT(*)

private:
  // Handle volatile loads
  void HandleVolatileLoads(const std::string &value, const BufferLoadNode *op,
                           std::ostream &os) final;

  // Whether scope such as "__shared__" or "__constant__" is part of type.
  bool IsScopePartOfType() const final { return false; }

  friend void PrintConst(const FloatImmNode *op, std::ostream &os,
                         CodeGenTileLangCUDA *p);

  // The size of the barrier array in shared memory
  int barrier_count_ = -1;
  // whether mma.h is needed
@@ -77,12 +85,14 @@ class CodeGenTileLangCUDA final : public CodeGenC {
  // Set to 16 to maintain minimum alignment requirements for async bulk copy
  const int barrier_alignment_bytes_ = 16;

  std::unordered_map<const VarNode *, std::string> fragment_shapes;
  std::unordered_map<const VarNode *, std::string> fragment_layouts;
  friend void PrintConst(const FloatImmNode *op, std::ostream &os,
                         CodeGenTileLangCUDA *p);
  void PrintWmmaScope(const std::string &scope, DataType t,
                      const VarNode *variable, std::ostream &os);
  int32_t GetWmmaFragmentSize(const std::string &scope, const VarNode *variable,
                              int32_t size);
};

} // namespace codegen
...
This diff is collapsed.
@@ -21,50 +21,58 @@ namespace tvm {
namespace codegen {

class CodeGenTileLangHIP final : public CodeGenC {
public:
  CodeGenTileLangHIP();
  std::string Finish();
  // override behavior
  void PrintFuncPrefix(std::ostream &os) final;
  void PrintExtraAttrs(const PrimFunc &f, std::ostream &os) final;
  void VisitStmt_(const ForNode *op) final;
  void PrintStorageSync(const CallNode *op) final;
  void PrintStorageScope(const std::string &scope,
                         std::ostream &os) final; // NOLINT(*)
  void PrintVecBinaryOp(const std::string &op, DataType t, PrimExpr lhs,
                        PrimExpr rhs,
                        std::ostream &os) final;      // NOLINT(*)
  void PrintType(DataType t, std::ostream &os) final; // NOLINT(*)
  void PrintVecElemLoad(const std::string &vec, DataType t, int i,
                        std::ostream &os) final; // NOLINT(*)
  void PrintVecElemStore(const std::string &vec, DataType t, int i,
                         const std::string &value) final;
  void BindThreadIndex(const IterVar &iv) final; // NOLINT(*)
  void PrintVecElemLoadExpr(DataType t, int i, const std::string &value,
                            std::ostream &os) final;
  std::string CastFromTo(std::string value, DataType from,
                         DataType target) final;
  // overload visitor
  void VisitExpr_(const RampNode *op, std::ostream &os) final; // NOLINT(*)
  void VisitExpr_(const BroadcastNode *op,
                  std::ostream &os) final; // NOLINT(*)
  void VisitExpr_(const FloatImmNode *op, std::ostream &os) final;
  void VisitExpr_(const CallNode *op, std::ostream &os) final;
  void VisitExpr_(const CastNode *op, std::ostream &os) final;
  void VisitStmt_(const AllocateNode *op) final;
  void VisitStmt_(const AttrStmtNode *op) final;
  // Override this as a workaround for the __grid_constant__ parameter
  void AddFunction(const PrimFunc &f);

protected:
  virtual std::string GetBufferRef(DataType t, const BufferNode *buffer,
                                   PrimExpr index) final;
  void PrintCallExtern(Type ret_type, String global_symbol,
                       const Array<PrimExpr> &args, bool skip_first_arg,
                       std::ostream &os) final; // NOLINT(*)

private:
  // Handle volatile loads
  void HandleVolatileLoads(const std::string &value, const BufferLoadNode *op,
                           std::ostream &os) final;

  // Whether scope such as "__shared__" or "__constant__" is part of type.
  bool IsScopePartOfType() const final { return false; }

  friend void PrintConst(const FloatImmNode *op, std::ostream &os,
                         CodeGenTileLangHIP *p);

  // whether math_constants.h is needed
  bool need_math_constants_h_{false};
...
This diff is collapsed.
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#include "codegen_cuda.h"
#include "runtime/cuda/cuda_module.h"

namespace tvm {
namespace codegen {

static std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule &mod) {
  std::unordered_map<std::string, runtime::FunctionInfo> fmap;
  for (auto kv : mod->functions) {
    ICHECK(kv.second->IsInstance<tir::PrimFuncNode>())
        << "Can only lower IR Module with PrimFuncs";
    auto f = Downcast<tir::PrimFunc>(kv.second);
    runtime::FunctionInfo info;
@@ -26,7 +28,7 @@ static std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
      info.arg_types.push_back(f->params[i].dtype());
    }
    if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
      for (const auto &tag : opt.value()) {
        info.launch_param_tags.push_back(tag);
      }
    }
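// Aside: a hedged, self-contained sketch of inspecting the
// runtime::FunctionInfo this loop populates. The tag values
// ("blockIdx.x", "threadIdx.x", ...) follow TVM's usual kernel-launch
// conventions and are an assumption, as is the DumpFunctionInfo helper;
// the src-relative include mirrors "runtime/cuda/cuda_module.h" above.
#include "runtime/meta_data.h" // runtime::FunctionInfo
#include <iostream>

void DumpFunctionInfo(const std::string &name,
                      const tvm::runtime::FunctionInfo &info) {
  std::cout << name << ": " << info.arg_types.size() << " kernel args\n";
  for (const auto &tag : info.launch_param_tags)
    std::cout << "  launch param: " << tag << "\n"; // e.g. "blockIdx.x"
}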
@@ -43,7 +45,8 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
  cg.Init(output_ssa);
  for (auto kv : mod->functions) {
    ICHECK(kv.second->IsInstance<PrimFuncNode>())
        << "CodeGenTileLangCUDA: Can only take PrimFunc";
    auto f = Downcast<PrimFunc>(kv.second);
    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
@@ -51,14 +54,15 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
  }
  std::string code = cg.Finish();

  if (const auto *f = Registry::Get("tvm_callback_cuda_postproc")) {
    code = (*f)(code, target).operator std::string();
  }
  std::string fmt = "ptx";
  std::string ptx;
  if (const auto *f = Registry::Get("tvm_callback_cuda_compile")) {
    ptx = (*f)(code, target).operator std::string();
    if (ptx[0] != '/')
      fmt = "cubin";
  } else {
    ICHECK(0);
  }
@@ -72,7 +76,8 @@ String BuildTLDebug(IRModule mod, Target target) {
  cg.Init(output_ssa);
  for (auto kv : mod->functions) {
    ICHECK(kv.second->IsInstance<PrimFuncNode>())
        << "CodeGenTileLangCUDA: Can only take PrimFunc";
    auto f = Downcast<PrimFunc>(kv.second);
    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
@@ -80,14 +85,16 @@ String BuildTLDebug(IRModule mod, Target target) {
  }
  std::string code = cg.Finish();
  if (const auto *f = Registry::Get("tvm_callback_cuda_postproc")) {
    code = (*f)(code, target).operator std::string();
  }
  return String(code);
}

TVM_REGISTER_GLOBAL("target.build.tilelang_cuda")
    .set_body_typed(BuildTileLangCUDA);
TVM_REGISTER_GLOBAL("target.build.tl_debug_codegen")
    .set_body_typed(BuildTLDebug);

} // namespace codegen
} // namespace tvm
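The two registrations above expose the builders by name, so callers resolve them through the same Registry mechanism this file already uses for the postproc and compile callbacks. A hedged sketch (BuildWithTileLangCUDA is a hypothetical helper, assuming standard TVM headers):

#include <tvm/ir/module.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>

// Hypothetical caller: resolve the builder by its registered name.
tvm::runtime::Module BuildWithTileLangCUDA(tvm::IRModule mod,
                                           tvm::Target target) {
  const auto *build = tvm::runtime::Registry::Get("target.build.tilelang_cuda");
  ICHECK(build != nullptr) << "tilelang_cuda codegen is not linked in";
  return (*build)(mod, target);
}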
@@ -11,13 +11,17 @@
namespace tvm {
namespace tl {

bool TargetIsCuda(Target target) {
  return target->GetTargetDeviceType() == kDLCUDA;
}
bool TargetIsRocm(Target target) {
  return target->GetTargetDeviceType() == kDLROCM;
}

int GetArchInt(Target target) {
  auto s = target->GetAttr<String>("arch");
  ICHECK(s.defined());
  const char *arch_str = s.value().c_str();
  ICHECK_EQ(arch_str[0], 's');
  ICHECK_EQ(arch_str[1], 'm');
  ICHECK_EQ(arch_str[2], '_');
@@ -25,31 +29,36 @@ int GetArchInt(Target target) {
}
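// Aside: the collapsed remainder of GetArchInt presumably parses the
// digits after the "sm_" prefix checked above. A minimal stand-alone
// sketch of that step (ParseArchDigits is a hypothetical name):
#include <cassert>
#include <cstdlib>
#include <string>

int ParseArchDigits(const std::string &arch) {
  assert(arch.rfind("sm_", 0) == 0); // must start with "sm_"
  return std::atoi(arch.c_str() + 3); // "sm_80" -> 80, "sm_90" -> 90
}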
bool TargetIsVolta(Target target) {
  if (!TargetIsCuda(target))
    return false;
  int arch = GetArchInt(target);
  return arch >= 70 && arch < 75;
}

bool TargetIsTuring(Target target) {
  if (!TargetIsCuda(target))
    return false;
  int arch = GetArchInt(target);
  return arch >= 75 && arch < 80;
}

bool TargetIsAmpere(Target target) {
  if (!TargetIsCuda(target))
    return false;
  int arch = GetArchInt(target);
  return arch >= 80 && arch < 90;
}

bool TargetIsHopper(Target target) {
  if (!TargetIsCuda(target))
    return false;
  int arch = GetArchInt(target);
  return arch >= 90;
}

bool TargetIsCDNA(Target target) {
  if (!TargetIsRocm(target))
    return false;
  if (target->attrs.count("mcpu")) {
    std::string mcpu = Downcast<String>(target->attrs.at("mcpu"));
    // If mcpu starts with "gfx9", it is CDNA.
@@ -78,13 +87,15 @@ bool TargetHasAsyncCopy(Target target) {
  return false;
}

bool TargetHasLdmatrix(Target target) {
  if (!TargetIsCuda(target))
    return false;
  int arch = GetArchInt(target);
  return arch >= 75;
}

bool TargetHasStmatrix(Target target) {
  if (!TargetIsCuda(target))
    return false;
  int arch = GetArchInt(target);
  return arch >= 90;
}
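These predicates are the natural gate for architecture-specific lowering decisions. A hedged usage sketch (SelectCopyInstruction is hypothetical; the predicate declarations are assumed to come from this file's header):

#include <string>
#include <tvm/target/target.h>

// Hypothetical dispatch built on the predicates above: pick a copy
// mechanism by architecture capability.
std::string SelectCopyInstruction(tvm::Target target) {
  using namespace tvm::tl;
  if (TargetIsHopper(target))
    return "tma";         // sm_90+: bulk tensor-memory copies
  if (TargetHasAsyncCopy(target))
    return "cp.async";    // async global->shared copies where supported
  if (TargetIsCDNA(target))
    return "buffer_load"; // CDNA: global-to-LDS buffer loads
  return "ld.global";     // fallback: plain global loads
}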