Commit fa511857 authored by Lei Wang, committed by GitHub

[Lint] Overall Typo and Linting Fixes (#13)

* README.md fixed

* update test ci

* Lint and Typo Fix

* Clang Format Lint Fix
parent be55163f
@@ -20,13 +20,14 @@ using namespace tir;

TIR_REGISTER_TL_OP(RegionOp, region)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure));

std::unique_ptr<Operator> ParseOperator(Call call, BufferMap vmap) {
  auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder");
  Op op = call->op.as<Op>().value();
  if (op_map.count(op)) {
    Operator *ptr = static_cast<Operator *>(op_map[op](call->args, vmap));
    ICHECK(ptr != nullptr);
    return std::unique_ptr<Operator>(ptr);
  }
@@ -41,7 +42,7 @@ std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap) {
  return nullptr;
}

Var GetVarFromAccessPtr(const PrimExpr &expr) {
  auto call = expr.as<CallNode>();
  ICHECK(call);
  ICHECK(call->op.same_as(builtin::tvm_access_ptr()));
@@ -67,20 +68,27 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {

bool RegionOp::IsFullRegion() const {
  for (size_t i = 0; i < ranges_.size(); i++) {
    if (!is_zero(ranges_[i]->min))
      return false;
    if (!StructuralEqual()(ranges_[i]->extent, buffer_->shape[i]))
      return false;
  }
  return true;
}
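A region is full only when every dimension starts at zero and spans the buffer's whole extent. A minimal illustration, with hypothetical shapes:

    // Buffer with shape [64, 32]:
    //   ranges = {Range(0, 64), Range(0, 32)}  ->  IsFullRegion() == true
    //   ranges = {Range(0, 64), Range(0, 16)}  ->  false (partial extent)
    //   ranges = {Range(8, 56), Range(0, 32)}  ->  false (non-zero min)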
Stmt Operator::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  ICHECK(0) << "Not Implemented Lower method.";
  return Evaluate(0);
}

Stmt Operator::Canonialize(const CanonializeArgs &T,
                           arith::Analyzer *analyzer) const {
  return {};
}

LayoutMap Operator::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  return {};
}

} // namespace tl
} // namespace tvm
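The "TLOpBuilder" attribute map makes parsing table-driven: each registered tl op carries a builder that constructs its concrete Operator subclass. A hedged usage sketch, where args and vmap stand in for the call arguments and the enclosing buffer map:

    // Dispatches through the TLOpBuilder table and returns an owned RegionOp.
    Call call(DataType::Handle(), RegionOp::Get(), args);
    std::unique_ptr<Operator> op = ParseOperator(call, vmap);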
@@ -25,17 +25,19 @@ using namespace tir;

using AddWorkspaceCallback = std::function<PrimExpr(int, DataType)>;
using LayoutMap = Map<Buffer, Layout>;
using BufferMap = Map<Var, Buffer>;
using OpBuilderFunc = TypedPackedFunc<void *(Array<PrimExpr>, BufferMap)>;

#define TIR_REGISTER_TL_OP(Entry, OpName)                                      \
  const Op &Entry::Get() {                                                     \
    static const Op &op = Op::Get("tl." #OpName);                              \
    return op;                                                                 \
  }                                                                            \
  TVM_REGISTER_OP("tl." #OpName)                                               \
      .set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName)             \
      .set_attr<OpBuilderFunc>("TLOpBuilder",                                  \
                               [](Array<PrimExpr> a, BufferMap b) {            \
                                 return (void *)(new Entry(a, b));             \
                               })

enum class InferLevel {
  kFree = 0,
@@ -64,35 +66,36 @@ struct CanonializeArgs {
};

class Operator {
public:
  virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
  virtual Stmt Canonialize(const CanonializeArgs &T,
                           arith::Analyzer *analyzer) const;
  virtual LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level);
  virtual ~Operator() = default;
};
class RegionOp : public Operator {
public:
  RegionOp(Array<PrimExpr> args, BufferMap vmap);
  static const Op &Get();
  const Buffer &GetBuffer() const { return buffer_; }
  const Array<Range> &GetRanges() const { return ranges_; }
  int GetAccessMask() const { return access_mask_; }
  bool IsFullRegion() const;

private:
  Buffer buffer_;
  Array<Range> ranges_;
  int access_mask_;
};

Var GetVarFromAccessPtr(const PrimExpr &expr);

std::unique_ptr<Operator> ParseOperator(Call call, BufferMap vmap);
std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap);

} // namespace tl
} // namespace tvm

#endif // TVM_TL_OP_OP_H_
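TIR_REGISTER_TL_OP both defines Entry::Get() and registers the op together with its builder in a single expansion, so adding an op needs only a class plus one macro invocation. A hedged sketch for a hypothetical FillOp (not part of this commit):

    TIR_REGISTER_TL_OP(FillOp, fill)
        .set_num_inputs(2)
        .set_attr<TCallEffectKind>("TCallEffectKind",
                                   Integer(CallEffectKind::kOpaque));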
@@ -39,21 +39,22 @@ using namespace tir;

namespace attr {
/*! \brief Mark how the loop is vectorized. */
constexpr const char *coalesced_width = "coalesced_width";
} // namespace attr

class IfBufferRemapLoopGenerator : public StmtExprMutator {
public:
  static For run(Stmt stmt, Map<Buffer, Buffer> buffer_remap,
                 Map<Buffer, Layout> layout_map) {
    IfBufferRemapLoopGenerator generator(buffer_remap, layout_map);
    return Downcast<For>(generator(std::move(stmt)));
  }
private:
  IfBufferRemapLoopGenerator(Map<Buffer, Buffer> buffer_remap,
                             Map<Buffer, Layout> layout_map)
      : buffer_remap_(buffer_remap), layout_map_(layout_map) {}

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    if (buffer_remap_.count(load->buffer)) {

@@ -65,7 +66,7 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {
    return load;
  }

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    if (buffer_remap_.count(store->buffer)) {
      auto new_indices = layout_map_[store->buffer]->Forward(store->indices);
@@ -79,18 +80,20 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {
  Map<Buffer, Layout> layout_map_;
};

void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
  ICHECK(op->kind == ForKind::kParallel);
  p->loop_vars_.push_back(
      IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar));
  p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
  StmtExprVisitor::VisitStmt_(op);
}

void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
  if (op->buffer.scope() == "local.fragment") {
    if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
      ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
          << op->buffer << ": " << op->indices << " and "
          << p->indice_map_.at(op->buffer);
    } else {
      p->indice_map_.Set(op->buffer, op->indices);
    }
@@ -99,11 +102,12 @@ void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
  StmtExprVisitor::VisitStmt_(op);
}

void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
  if (op->buffer.scope() == "local.fragment") {
    if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
      ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
          << op->buffer << ": " << op->indices << " and "
          << p->indice_map_.at(op->buffer);
    } else {
      p->indice_map_.Set(op->buffer, op->indices);
    }
@@ -113,18 +117,20 @@ void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {

ParallelOp::ParallelOp(For root) : root_(root), V(this) { V.VisitStmt(root); }

bool ParallelOp::IsCommonAccessIndice(const Buffer &buffer) const {
  auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; });
  return StructuralEqual()(indice_map_[buffer], common_indice);
}

LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (loop_layout_.defined())
    return {};
  if (level == InferLevel::kStrict)
    return {};

  // Step 1: try to infer the loop's partition from a source fragment
  Buffer source_buffer, read_source_buffer;
  for (const auto &[buffer, _] : indice_map_) {
    if (T.layout_map.count(buffer)) {
      auto frag = T.layout_map[buffer].as<Fragment>().value();
      if (buffer_is_write_.count(buffer))
@@ -133,14 +139,16 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
        read_source_buffer = buffer;
    }
  }
  auto compute_loop_layout_from_buffer = [&](const Buffer &buffer) {
    Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
    if (IsCommonAccessIndice(buffer)) {
      return src_layout;
    } else {
      Var rep;
      auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep,
                              IterVarType::kDataPar);
      PrimExpr loop_var_to_thread =
          src_layout->ForwardThread(indice_map_[buffer], rep);
      return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter);
    }
  };
@@ -150,12 +158,14 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
    if (read_source_buffer.defined()) {
      loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
      // The loop does not need to be replicated.
      if (!is_one(loop_layout_->ReplicateExtent()))
        loop_layout_ = loop_layout_->DeReplicate();
      // If replication remains, add a predicate so only one replica runs.
      if (!is_one(loop_layout_->ReplicateExtent())) {
        auto inv = loop_layout_->Inverse();
        Array<PrimExpr> fwd;
        for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
          fwd.push_back(0);
        fwd.push_back(InputPlaceholder(0));
        auto rep = inv->Forward(fwd).back();
        AddPredicate(EQ(rep, 0));
@@ -163,17 +173,19 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
    } else {
      // The vectorize size must be aware of the buffer remap,
      // as later passes post-process the layout.
      auto maybe_remapped_root_ =
          IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
      int vector_size = GetVectorizeSize(maybe_remapped_root_);

      // Check if coalesced_width is defined
      if (auto coalesced_width =
              root_->annotations.Get(tl::attr::coalesced_width)) {
        if (const auto *imm = coalesced_width.as<IntImmNode>()) {
          int expected = imm->value;
          // Verify that vector_size is divisible by the expected width
          if (vector_size % expected != 0) {
            LOG(FATAL) << "Vector size " << vector_size
                       << " is not divisible by coalesced width " << expected;
          }
          vector_size = expected;
        } else {
@@ -184,31 +196,37 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
      loop_layout_ = PlanLoopPartition(root_, T.block_size, vector_size);
    }
    PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();
    if (!analyzer_.CanProveEqual(loop_thread_extent,
                                 static_cast<int>(T.block_size)))
      AddPredicate(LT(InputPlaceholder(0), loop_thread_extent));
  } else {
    return {};
  }
  // Step 2: check that the loop's partition aligns with all source fragments
  for (const auto &[buffer, _] : indice_map_) {
    if (T.layout_map.count(buffer)) {
      auto fragment = T.layout_map[buffer].as<Fragment>().value();
      // TODO: add thread checks for the replicated cases;
      // need to wildcard match the rhs with the lhs
      if (!is_one(loop_layout_->ReplicateExtent()) ||
          !is_one(fragment->ReplicateExtent()))
        continue;
      auto vars =
          loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
      auto lhs = loop_layout_->ForwardThread(vars, NullOpt);
      auto rhs = fragment->ForwardThread(indice_map_[buffer], NullOpt);
      auto diff = analyzer_.Simplify(lhs - rhs);
      ICHECK(is_zero(diff))
          << "Layout infer conflict for " << buffer << " " << source_buffer
          << "\nLHS = " << lhs << "\nRHS = " << rhs;
    }
  }
  // Step 3: infer the other fragments' layouts from the loop's partition
  LayoutMap results;
  for (const auto &[buffer, _] : indice_map_) {
    if (!T.layout_map.count(buffer))
      results.Set(buffer, CompleteBufferFragment(buffer));
  }
  return results;
}
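The Step 2 check compares the thread mapping implied by the loop partition against the thread mapping of each already-known fragment. A hedged illustration with made-up layouts:

    // loop_layout_ maps loop vars (i, j)      -> thread i * 4 + j
    // fragment     maps buffer indices (i, j) -> thread i * 4 + j
    // lhs - rhs simplifies to 0, so the partition is consistent; any
    // nonzero difference trips the ICHECK and prints both mappings.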
@@ -221,18 +239,20 @@ Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const {
  }
}

Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) {
  ICHECK(loop_layout_.defined());
  if (IsCommonAccessIndice(buffer))
    return loop_layout_;
  PrimExpr rep_b = MakeFlattenedExpression(
      DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
  auto bijective_indice = indice_map_[buffer];
  bijective_indice.push_back(rep_b);
  Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();
  PrimExpr indice_rep_extent =
      ind_inv->InputShape().back(); // this is the size of rep_b
  PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
  PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
@@ -242,11 +262,12 @@ Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) {
  }
  fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
  PrimExpr thd_b = loop_layout_->ForwardThread(
      ind_inv->Forward(fwd),
      FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
  return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent, NullOpt)
      ->CondenseReplicateVar();
}

} // namespace tl
} // namespace tvm
@@ -23,30 +23,30 @@ using namespace tir;

class ParallelOp;

class ParallelLoopNestVisitor : public StmtExprVisitor {
private:
  ParallelLoopNestVisitor(ParallelOp *op) : p(op){};
  void VisitStmt_(const ForNode *op) final;
  void VisitStmt_(const BufferStoreNode *op) final;
  void VisitExpr_(const BufferLoadNode *op) final;

  ParallelOp *p;

  friend class ParallelOp;
};

class ParallelOp : public Operator {
public:
  ParallelOp(For root);
  LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;

  Fragment GetLoopLayout() const { return loop_layout_; }
  For GetRoot() const { return root_; }
  Map<Buffer, Array<PrimExpr>> GetIndiceMap() const { return indice_map_; }
  Optional<PrimExpr> GetPredicate(Var thread_var) const;

private:
  Fragment CompleteBufferFragment(const Buffer &buffer);
  bool IsCommonAccessIndice(const Buffer &buffer) const;
  void AddPredicate(PrimExpr expr) {
    predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
  }

@@ -66,7 +66,7 @@ class ParallelOp : public Operator {
  friend class ParallelLoopNestVisitor;
};

} // namespace tl
} // namespace tvm

#endif // TVM_TL_OP_PARALLEL_H_
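AddPredicate folds successive guards into a single conjunction, which GetPredicate later instantiates with the concrete thread variable. A hedged sketch of the accumulation (values illustrative):

    AddPredicate(LT(InputPlaceholder(0), 128)); // partial thread extent
    AddPredicate(EQ(rep, 0));                   // only one replica writes
    // predicate_ == And(EQ(rep, 0), LT(InputPlaceholder(0), 128))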
@@ -41,57 +41,58 @@ ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {

PrimExpr ReduceOp::MakeInitValue() const {
  switch (type) {
  case ReduceType::kSum:
    return make_zero(dst->dtype);
  case ReduceType::kAbsSum:
    return make_zero(dst->dtype);
  case ReduceType::kMax:
    return make_const(dst->dtype, -INFINITY);
  case ReduceType::kMin:
    return make_const(dst->dtype, INFINITY);
  default:
    ICHECK(0);
  }
}
PrimExpr ReduceOp::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
  PrimExpr lhs = a, rhs = b;
  if (lhs->dtype != rhs->dtype) {
    rhs = Cast(lhs->dtype, rhs);
  }
  switch (type) {
  case ReduceType::kSum:
    return lhs + rhs;
  case ReduceType::kAbsSum:
    return lhs + Max(rhs, -rhs);
  case ReduceType::kMax:
    return Max(lhs, rhs);
  case ReduceType::kMin:
    return Min(lhs, rhs);
  default:
    ICHECK(0);
    return PrimExpr(0);
  }
}
std::string ReduceOp::MakeCodegenReducer() const {
  switch (type) {
  case ReduceType::kSum:
    return "tl::SumOp";
  case ReduceType::kAbsSum:
    return "tl::SumOp";
  case ReduceType::kMax:
    return "tl::MaxOp";
  case ReduceType::kMin:
    return "tl::MinOp";
  default:
    ICHECK(0);
    return "";
  }
}
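Note that kAbsSum reuses tl::SumOp at codegen time: the absolute value is already applied element-wise in MakeReduce, where |x| is written without an abs intrinsic:

    PrimExpr abs_rhs = Max(rhs, -rhs); // kAbsSum accumulates lhs + |rhs|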
Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  ICHECK(this->src.scope() == "local.fragment" &&
         this->dst.scope() == "local.fragment")
      << "Reduce for shared memory not implemented.";
  auto src_buffer = T.buffer_remap[this->src];
  auto dst_buffer = T.buffer_remap[this->dst];
@@ -101,20 +102,24 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  Array<IterVar> dst_vars;
  for (size_t i = 0; i < dst_layout->InputDim(); i++) {
    Var var = Var(std::string{char('i' + i)});
    dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var,
                               IterVarType::kDataPar));
  }
  Array<IterVar> src_vars = dst_vars;
  src_vars.insert(src_vars.begin() + this->dim,
                  {Range(0, src_layout->InputShape()[this->dim]), Var("rv"),
                   IterVarType::kDataPar});
  Array<PrimExpr> src_indices = src_layout->Forward(
      src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
  Array<PrimExpr> dst_indices = dst_layout->Forward(
      dst_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));

  Array<Stmt> stmts;

  // make reduce-init stmt
  if (this->clear)
    stmts.push_back(
        BufferStore(dst_buffer, this->MakeInitValue(), dst_indices));

  // make thread-local reduce
  Array<PrimExpr> src_indice_compressed;
@@ -122,45 +127,50 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  for (size_t i = 0; i < src_layout->OutputDim(); i++) {
    PrimExpr expr;
    IterVar var;
    std::tie(expr, var) = CompressIterator(src_indices[i], src_vars,
                                           src_vars[this->dim]->var, analyzer);
    src_indice_compressed.push_back(expr);
    src_var_compressed.push_back(var);
  }
  Stmt reduce_local = BufferStore(
      dst_buffer,
      this->MakeReduce(BufferLoad(dst_buffer, dst_indices),
                       BufferLoad(src_buffer, src_indice_compressed)),
      dst_indices);
  for (int i = src_layout->OutputDim() - 1; i >= 0; i--) {
    reduce_local =
        For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent,
            ForKind::kUnrolled, reduce_local, NullOpt,
            {{tir::attr::pragma_unroll_explicit, Bool(false)}});
  }
  stmts.push_back(reduce_local);

  // make inter-thread reduce
  PrimExpr src_thread = src_layout->ForwardThread(
      src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }), {});
  auto iter_sum =
      arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer);
  for (const auto &iter_split : iter_sum->args) {
    auto mark = iter_split->source->source.as<Var>();
    ICHECK(mark.defined());
    if (mark.value().same_as(src_vars[this->dim]->var)) {
      auto scale = as_const_int(iter_split->scale);
      auto extent = as_const_int(iter_split->extent);
      ICHECK(scale != nullptr && extent != nullptr);
      if (*extent == 1)
        continue;

      int reducing_threads = (*extent) * (*scale);
      std::stringstream ss;
      ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
         << reducing_threads << ", " << (*scale) << ">::run";
      Array<PrimExpr> thread_reduce_args = {
          StringImm(ss.str()), BufferLoad(dst_buffer, dst_indices)};
      if (reducing_threads >= 32) {
        PrimExpr workspace = T.AddWorkspace(T.block_size, dst_buffer->dtype);
        thread_reduce_args.push_back(workspace);
      }
      auto call =
          Call(dst_buffer->dtype, builtin::call_extern(), thread_reduce_args);
      stmts.push_back(BufferStore(dst_buffer, call, dst_indices));
    }
  }
@@ -170,15 +180,17 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  // make the outer spatial loop
  Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0];
  for (int i = dst_layout->InputDim() - 1; i >= 0; i--) {
    body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent,
               ForKind::kParallel, body);
  }
  body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer, dst_layout);
  return body;
}
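The inter-thread step becomes an extern call whose template arguments encode the reducer, the participating thread count, and the thread stride. For a sum over 32 threads at stride 1, the stringstream above yields a call of the form (names illustrative):

    // The workspace argument is appended because reducing_threads >= 32.
    tl::AllReduce<tl::SumOp, 32, 1>::run(value, workspace);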
LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  if (level >= InferLevel::kStrict)
    return {};
  if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
      T.layout_map.count(src) && !T.layout_map.count(dst)) {
    auto src_layout = T.layout_map[src].as<Fragment>().value();
@@ -197,10 +209,11 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
        fwd.push_back(InputPlaceholder(i - 1));
      }
    }
    auto thd = src_layout->ForwardThread(
        fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
    Fragment dst_layout =
        Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, NullOpt)
            ->CondenseReplicateVar();
    return {{dst, dst_layout}};
  }
  return {};

@@ -208,7 +221,8 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
TIR_REGISTER_TL_OP(ReduceOp, reduce)
    .set_num_inputs(4)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm
\ No newline at end of file
@@ -18,13 +18,13 @@ namespace tl {
using namespace tir;

class ReduceOp : public Operator {
public:
  ReduceOp(Array<PrimExpr> args, BufferMap vmap);
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
  LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
  static const Op &Get();

private:
  tir::Buffer src, dst;
  int dim;
  enum class ReduceType {

@@ -36,11 +36,11 @@ class ReduceOp : public Operator {
  bool clear;

  PrimExpr MakeInitValue() const;
  PrimExpr MakeReduce(const PrimExpr &a, const PrimExpr &b) const;
  std::string MakeCodegenReducer() const;
};

} // namespace tl
} // namespace tvm

#endif // TVM_TL_OP_REDUCE_H_
\ No newline at end of file
@@ -17,12 +17,12 @@ namespace tl {
using namespace runtime;

template <typename T> static std::string ArrayToStr(const T *ptr, size_t n) {
  std::stringstream ss;
  ss << "[";
  for (size_t i = 0; i < n; i++) {
    if (i > 0)
      ss << ", ";
    ss << ptr[i];
  }
  ss << "]";
@@ -30,10 +30,10 @@ static std::string ArrayToStr(const T *ptr, size_t n) {
}
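ArrayToStr is a small debugging helper used by the ToDebugString methods below; it renders the first n entries of a raw array. For example (values illustrative):

    cuuint64_t dims[3] = {2, 4, 8};
    ArrayToStr(dims, 3); // "[2, 4, 8]"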
struct TensorMapArgs { struct TensorMapArgs {
CUtensorMap* map; CUtensorMap *map;
CUtensorMapDataType type; CUtensorMapDataType type;
cuuint32_t tensorRank; cuuint32_t tensorRank;
void* globalAddress; void *globalAddress;
cuuint64_t globalDim[5], globalStride[5]; cuuint64_t globalDim[5], globalStride[5];
cuuint32_t boxDim[5], elementStrides[5]; cuuint32_t boxDim[5], elementStrides[5];
CUtensorMapInterleave interleave; CUtensorMapInterleave interleave;
@@ -45,8 +45,9 @@ struct TensorMapArgs {
  TensorMapArgs T;
  int idx = 0;
  ICHECK(args.num_args >= 8);
  T.map = reinterpret_cast<CUtensorMap *>(static_cast<void *>(args[idx++]));
  T.type =
      static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++]));
  T.tensorRank = static_cast<cuuint32_t>(static_cast<int64_t>(args[idx++]));
  T.globalAddress = args[idx++];
  ICHECK(T.tensorRank >= 1 && T.tensorRank <= 5);
@@ -63,10 +64,14 @@ struct TensorMapArgs {
  for (size_t i = 0; i < T.tensorRank; i++) {
    T.elementStrides[i] = static_cast<cuuint64_t>(args[idx++]);
  }
  T.interleave =
      static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++]));
  T.swizzle =
      static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
  T.l2Promotion =
      static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++]));
  T.oobFill =
      static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++]));
  return T;
}
@@ -79,7 +84,8 @@ struct TensorMapArgs {
     << "globalDim " << ArrayToStr(globalDim, tensorRank) << std::endl
     << "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl
     << "boxDim " << ArrayToStr(boxDim, tensorRank) << std::endl
     << "elementStrides " << ArrayToStr(elementStrides, tensorRank)
     << std::endl
     << "interleave " << interleave << std::endl
     << "swizzle " << swizzle << std::endl
     << "l2Promotion " << l2Promotion << std::endl

@@ -89,23 +95,26 @@ struct TensorMapArgs {
};
// set device api
TVM_REGISTER_GLOBAL(tvm_tensormap_create_tiled)
    .set_body([](TVMArgs args, TVMRetValue *ret) {
      TensorMapArgs T = TensorMapArgs::Extract(args);
      CUresult result = cuTensorMapEncodeTiled(
          T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
          T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave,
          T.swizzle, T.l2Promotion, T.oobFill);
      if (result != CUDA_SUCCESS) {
        LOG_FATAL << "Failed to initialize the TMA descriptor " << result
                  << std::endl
                  << T.ToDebugString();
      }
      *ret = static_cast<int>(result);
    });
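Two details worth noting. First, cuTensorMapEncodeTiled takes only tensorRank - 1 strides, since the innermost stride is implied by the element type, which is presumably why the call passes T.globalStride + 1. Second, TVM_REGISTER_GLOBAL exposes the encoder through the runtime registry under the name held in tvm_tensormap_create_tiled; a hedged lookup sketch, with argument packing elided (see TensorMapArgs::Extract above for the expected order):

    const runtime::PackedFunc *f =
        runtime::Registry::Get("__tvm_tensormap_create_tiled");
    ICHECK(f != nullptr);
    // (*f)(map, type, rank, addr, ...) returns the CUresult as an int.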
struct TensorMapIm2ColArgs {
  CUtensorMap *map;
  CUtensorMapDataType type;
  cuuint32_t tensorRank;
  void *globalAddress;
  cuuint64_t globalDim[5], globalStride[5];
  cuuint32_t elementStrides[5];
  int pixelBoxLowerCorner[3], pixelBoxUpperCorner[3];
@@ -119,8 +128,9 @@ struct TensorMapIm2ColArgs {
  TensorMapIm2ColArgs T;
  int idx = 0;
  ICHECK(args.num_args >= 8);
  T.map = reinterpret_cast<CUtensorMap *>(static_cast<void *>(args[idx++]));
  T.type =
      static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++]));
  T.tensorRank = static_cast<cuuint32_t>(static_cast<int64_t>(args[idx++]));
  T.globalAddress = args[idx++];
  ICHECK(T.tensorRank >= 3 && T.tensorRank <= 5);
@@ -142,10 +152,14 @@ struct TensorMapIm2ColArgs {
  }
  T.smem_box_pixel = static_cast<cuuint64_t>(args[idx++]);
  T.smem_box_channel = static_cast<cuuint64_t>(args[idx++]);
  T.interleave =
      static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++]));
  T.swizzle =
      static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
  T.l2Promotion =
      static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++]));
  T.oobFill =
      static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++]));
  return T;
}
@@ -159,9 +173,12 @@ struct TensorMapIm2ColArgs {
     << "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl
     << "smem_box_pixel " << smem_box_pixel << std::endl
     << "smem_box_channel " << smem_box_channel << std::endl
     << "pixelBoxLowerCorner "
     << ArrayToStr(pixelBoxLowerCorner, tensorRank - 2) << std::endl
     << "pixelBoxUpperCorner "
     << ArrayToStr(pixelBoxUpperCorner, tensorRank - 2) << std::endl
     << "elementStrides " << ArrayToStr(elementStrides, tensorRank)
     << std::endl
     << "interleave " << interleave << std::endl
     << "swizzle " << swizzle << std::endl
     << "l2Promotion " << l2Promotion << std::endl

@@ -170,18 +187,21 @@ struct TensorMapIm2ColArgs {
  }
};
TVM_REGISTER_GLOBAL(tvm_tensormap_create_im2col)
    .set_body([](TVMArgs args, TVMRetValue *ret) {
      TensorMapIm2ColArgs T = TensorMapIm2ColArgs::Extract(args);
      CUresult result = cuTensorMapEncodeIm2col(
          T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
          T.globalStride + 1, T.pixelBoxLowerCorner, T.pixelBoxUpperCorner,
          T.smem_box_channel, T.smem_box_pixel, T.elementStrides, T.interleave,
          T.swizzle, T.l2Promotion, T.oobFill);
      if (result != CUDA_SUCCESS) {
        LOG_FATAL << "Failed to initialize the TMA descriptor " << result
                  << std::endl
                  << T.ToDebugString();
      }
      *ret = static_cast<int>(result);
    });

} // namespace tl
} // namespace tvm
@@ -13,9 +13,11 @@
namespace tvm {
namespace tl {

constexpr const char *tvm_tensormap_create_tiled =
    "__tvm_tensormap_create_tiled";
constexpr const char *tvm_tensormap_create_im2col =
    "__tvm_tensormap_create_im2col";

} // namespace tl
} // namespace tvm

#endif // TVM_TL_RUNTIME_RUNTIME_H_
\ No newline at end of file
@@ -21,50 +21,58 @@ namespace tvm {
namespace codegen {

class CodeGenTileLangCUDA final : public CodeGenC {
public:
  CodeGenTileLangCUDA();
  std::string Finish();
  // override behavior
  void PrintFuncPrefix(std::ostream &os) final;
  void PrintExtraAttrs(const PrimFunc &f, std::ostream &os) final;
  void VisitStmt_(const ForNode *op) final;
  void PrintStorageSync(const CallNode *op) final;
  void PrintStorageScope(const std::string &scope,
                         std::ostream &os) final; // NOLINT(*)
  void PrintVecBinaryOp(const std::string &op, DataType t, PrimExpr lhs,
                        PrimExpr rhs,
                        std::ostream &os) final;  // NOLINT(*)
  void PrintType(DataType t, std::ostream &os) final; // NOLINT(*)
  void PrintVecElemLoad(const std::string &vec, DataType t, int i,
                        std::ostream &os) final; // NOLINT(*)
  void PrintVecElemStore(const std::string &vec, DataType t, int i,
                         const std::string &value) final;
  void BindThreadIndex(const IterVar &iv) final; // NOLINT(*)
  void PrintVecElemLoadExpr(DataType t, int i, const std::string &value,
                            std::ostream &os) final;
  std::string CastFromTo(std::string value, DataType from,
                         DataType target) final;
  // overload visitor
  void VisitExpr_(const RampNode *op, std::ostream &os) final;      // NOLINT(*)
  void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*)
  void VisitExpr_(const FloatImmNode *op, std::ostream &os) final;
  void VisitExpr_(const CallNode *op, std::ostream &os) final;
  void VisitExpr_(const CastNode *op, std::ostream &os) final;
  void VisitStmt_(const AllocateNode *op) final;
  void VisitStmt_(const AttrStmtNode *op) final;
  // Override this as a workaround for the __grid_constant__ parameter
  void AddFunction(const PrimFunc &f);

protected:
  virtual std::string GetBufferRef(DataType t, const BufferNode *buffer,
                                   PrimExpr index) final;
  void PrintCallExtern(Type ret_type, String global_symbol,
                       const Array<PrimExpr> &args, bool skip_first_arg,
                       std::ostream &os) final; // NOLINT(*)

private:
  // Handle volatile loads
  void HandleVolatileLoads(const std::string &value, const BufferLoadNode *op,
                           std::ostream &os) final;

  // Whether scope such as "__shared__" or "__constant__" is part of type.
  bool IsScopePartOfType() const final { return false; }

  friend void PrintConst(const FloatImmNode *op, std::ostream &os,
                         CodeGenTileLangCUDA *p);

  // The size of the barrier array in shared memory
  int barrier_count_ = -1;
  // whether mma.h is needed

@@ -77,15 +85,17 @@ class CodeGenTileLangCUDA final : public CodeGenC {
  // Set to 16 to maintain minimum alignment requirements for async bulk copy
  const int barrier_alignment_bytes_ = 16;

  std::unordered_map<const VarNode *, std::string> fragment_shapes;
  std::unordered_map<const VarNode *, std::string> fragment_layouts;
  friend void PrintConst(const FloatImmNode *op, std::ostream &os,
                         CodeGenTileLangCUDA *p);
  void PrintWmmaScope(const std::string &scope, DataType t,
                      const VarNode *variable, std::ostream &os);
  int32_t GetWmmaFragmentSize(const std::string &scope, const VarNode *variable,
                              int32_t size);
};

} // namespace codegen
} // namespace tvm

#endif // TVM_TL_TARGET_CODEGEN_CUDA_H_
@@ -21,50 +21,58 @@ namespace tvm {
namespace codegen {

class CodeGenTileLangHIP final : public CodeGenC {
public:
  CodeGenTileLangHIP();
  std::string Finish();
  // override behavior
  void PrintFuncPrefix(std::ostream &os) final;
  void PrintExtraAttrs(const PrimFunc &f, std::ostream &os) final;
  void VisitStmt_(const ForNode *op) final;
  void PrintStorageSync(const CallNode *op) final;
  void PrintStorageScope(const std::string &scope,
                         std::ostream &os) final; // NOLINT(*)
  void PrintVecBinaryOp(const std::string &op, DataType t, PrimExpr lhs,
                        PrimExpr rhs,
                        std::ostream &os) final;      // NOLINT(*)
  void PrintType(DataType t, std::ostream &os) final; // NOLINT(*)
  void PrintVecElemLoad(const std::string &vec, DataType t, int i,
                        std::ostream &os) final; // NOLINT(*)
  void PrintVecElemStore(const std::string &vec, DataType t, int i,
                         const std::string &value) final;
  void BindThreadIndex(const IterVar &iv) final; // NOLINT(*)
  void PrintVecElemLoadExpr(DataType t, int i, const std::string &value,
                            std::ostream &os) final;
  std::string CastFromTo(std::string value, DataType from,
                         DataType target) final;
  // overload visitor
  void VisitExpr_(const RampNode *op, std::ostream &os) final;      // NOLINT(*)
  void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*)
  void VisitExpr_(const FloatImmNode *op, std::ostream &os) final;
  void VisitExpr_(const CallNode *op, std::ostream &os) final;
  void VisitExpr_(const CastNode *op, std::ostream &os) final;
  void VisitStmt_(const AllocateNode *op) final;
  void VisitStmt_(const AttrStmtNode *op) final;
  // Override this as a workaround for the __grid_constant__ parameter
  void AddFunction(const PrimFunc &f);

protected:
  virtual std::string GetBufferRef(DataType t, const BufferNode *buffer,
                                   PrimExpr index) final;
  void PrintCallExtern(Type ret_type, String global_symbol,
                       const Array<PrimExpr> &args, bool skip_first_arg,
                       std::ostream &os) final; // NOLINT(*)

private:
  // Handle volatile loads
  void HandleVolatileLoads(const std::string &value, const BufferLoadNode *op,
                           std::ostream &os) final;
  // Whether scope such as "__shared__" or "__constant__" is part of type.
  bool IsScopePartOfType() const final { return false; }
  friend void PrintConst(const FloatImmNode *op, std::ostream &os,
                         CodeGenTileLangHIP *p);
  // whether math_constants.h is needed
  bool need_math_constants_h_{false};
@@ -83,7 +91,7 @@ class CodeGenTileLangHIP final : public CodeGenC {
  const int barrier_alignment_bytes_ = 16;
};
} // namespace codegen
} // namespace tvm
#endif // TVM_TL_TARGET_CODEGEN_HIP_H_
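
For orientation, a minimal sketch of how this codegen class is driven; it mirrors the BuildTileLangHIP function later in this diff, and `f` is assumed to be a PrimFunc with the device-kernel calling convention:

// Minimal sketch (not part of this commit); assumes a valid PrimFunc `f`.
CodeGenTileLangHIP cg;
cg.Init(/*output_ssa=*/false);    // same flag the build functions below pass
cg.AddFunction(f);                // emit the kernel body
std::string source = cg.Finish(); // retrieve the generated HIP source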
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#include "codegen_cuda.h"
#include "runtime/cuda/cuda_module.h"

namespace tvm {
namespace codegen {

static std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule &mod) {
  std::unordered_map<std::string, runtime::FunctionInfo> fmap;
  for (auto kv : mod->functions) {
    ICHECK(kv.second->IsInstance<tir::PrimFuncNode>())
        << "Can only lower IR Module with PrimFuncs";
    auto f = Downcast<tir::PrimFunc>(kv.second);
    runtime::FunctionInfo info;
@@ -26,7 +28,7 @@
      info.arg_types.push_back(f->params[i].dtype());
    }
    if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
      for (const auto &tag : opt.value()) {
        info.launch_param_tags.push_back(tag);
      }
    }
@@ -43,7 +45,8 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
  cg.Init(output_ssa);

  for (auto kv : mod->functions) {
    ICHECK(kv.second->IsInstance<PrimFuncNode>())
        << "CodeGenTileLangCUDA: Can only take PrimFunc";
    auto f = Downcast<PrimFunc>(kv.second);
    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
@@ -51,14 +54,15 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
  }

  std::string code = cg.Finish();
  if (const auto *f = Registry::Get("tvm_callback_cuda_postproc")) {
    code = (*f)(code, target).operator std::string();
  }
  std::string fmt = "ptx";
  std::string ptx;
  if (const auto *f = Registry::Get("tvm_callback_cuda_compile")) {
    ptx = (*f)(code, target).operator std::string();
    if (ptx[0] != '/')
      fmt = "cubin";
  } else {
    ICHECK(0);
  }
@@ -72,7 +76,8 @@ String BuildTLDebug(IRModule mod, Target target) {
  cg.Init(output_ssa);

  for (auto kv : mod->functions) {
    ICHECK(kv.second->IsInstance<PrimFuncNode>())
        << "CodeGenTileLangCUDA: Can only take PrimFunc";
    auto f = Downcast<PrimFunc>(kv.second);
    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
@@ -80,14 +85,16 @@ String BuildTLDebug(IRModule mod, Target target) {
  }

  std::string code = cg.Finish();
  if (const auto *f = Registry::Get("tvm_callback_cuda_postproc")) {
    code = (*f)(code, target).operator std::string();
  }
  return String(code);
}

TVM_REGISTER_GLOBAL("target.build.tilelang_cuda")
    .set_body_typed(BuildTileLangCUDA);
TVM_REGISTER_GLOBAL("target.build.tl_debug_codegen")
    .set_body_typed(BuildTLDebug);

} // namespace codegen
} // namespace tvm
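
The two TVM_REGISTER_GLOBAL calls above expose the builders through TVM's global registry; a hedged sketch of retrieving and invoking one, assuming `mod` and `target` are already constructed:

// Illustrative only; mirrors the Registry::Get pattern used above for the
// postproc/compile callbacks.
const runtime::PackedFunc *build =
    runtime::Registry::Get("target.build.tilelang_cuda");
ICHECK(build != nullptr);
runtime::Module m = (*build)(mod, target);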
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#if defined(__linux__)
#include <sys/stat.h>
#endif
@@ -8,28 +8,28 @@
#include <hip/hip_runtime.h>
#include <hip/hiprtc.h>

#include "codegen_hip.h"
#include "runtime/rocm/rocm_module.h"

namespace tvm {
namespace codegen {
#define HIPRTC_CALL(x)                                                         \
  {                                                                            \
    hiprtcResult result = x;                                                   \
    if (result != HIPRTC_SUCCESS) {                                            \
      LOG(FATAL) << "HiprtcError: " #x " failed with error: "                  \
                 << hiprtcGetErrorString(result);                              \
    }                                                                          \
  }
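
The macro wraps any hipRTC entry point that returns hiprtcResult, for example (illustrative call site, not part of this diff):

// Hypothetical use of the macro on another hipRTC call:
HIPRTC_CALL(hiprtcDestroyProgram(&prog));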
static std::string FindHIPIncludePath() {
@@ -39,7 +39,7 @@ static std::string FindHIPIncludePath() {
  const std::string delimiter = "/";
#endif
  std::string hip_include_path;
  const char *hip_path_env = std::getenv("HIP_PATH");
  if (hip_path_env != nullptr) {
    hip_include_path += hip_path_env;
    hip_include_path += delimiter + "include";
@@ -58,19 +58,24 @@ static std::string FindHIPIncludePath() {
  }
#endif
  LOG(FATAL) << "Cannot find HIP include path. "
             << "HIP_PATH is not set or ROCm is not installed in the default "
                "installation path. "
             << "On platforms other than Linux, HIP_PATH must be set.";
  return hip_include_path;
}
static std::string HIPRTCCompile(const std::string &code,
                                 bool include_path = false) {
  std::vector<std::string> compile_params;
  std::vector<const char *> param_cstrings{};
  hiprtcProgram prog;
  std::string cc =
      "gfx900"; // Default target architecture (can be changed as needed)
  int major, minor;
  hipError_t e1 = hipDeviceGetAttribute(
      &major, hipDeviceAttributeComputeCapabilityMajor, 0);
  hipError_t e2 = hipDeviceGetAttribute(
      &minor, hipDeviceAttributeComputeCapabilityMinor, 0);

  if (e1 == hipSuccess && e2 == hipSuccess) {
    cc = "gfx" + std::to_string(major * 100 + minor * 10);
@@ -86,10 +91,11 @@
    compile_params.push_back(include_option);
  }

  for (const auto &string : compile_params) {
    param_cstrings.push_back(string.c_str());
  }
  HIPRTC_CALL(
      hiprtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr));
  hiprtcResult compile_res =
      hiprtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data());
@@ -110,11 +116,13 @@
  return code_out;
}
static std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule &mod) {
  std::unordered_map<std::string, runtime::FunctionInfo> fmap;
  for (auto kv : mod->functions) {
    ICHECK(kv.second->IsInstance<tir::PrimFuncNode>())
        << "Can only lower IR Module with PrimFuncs";
    auto f = Downcast<tir::PrimFunc>(kv.second);
    runtime::FunctionInfo info;
@@ -129,7 +137,7 @@
      info.arg_types.push_back(f->params[i].dtype());
    }
    if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
      for (const auto &tag : opt.value()) {
        info.launch_param_tags.push_back(tag);
      }
    }
@@ -146,7 +154,8 @@ runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
  cg.Init(output_ssa);

  for (auto kv : mod->functions) {
    ICHECK(kv.second->IsInstance<PrimFuncNode>())
        << "CodeGenTileLangHIP: Can only take PrimFunc";
    auto f = Downcast<PrimFunc>(kv.second);
    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
@@ -154,21 +163,23 @@ runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
  }

  std::string code = cg.Finish();
  if (const auto *f = Registry::Get("tvm_callback_hip_postproc")) {
    code = (*f)(code, target).operator std::string();
  }
  std::string fmt = "ptx";
  std::string ptx;
  if (const auto *f = Registry::Get("tvm_callback_hip_compile")) {
    ptx = (*f)(code, target).operator std::string();
    if (ptx[0] != '/')
      fmt = "hsaco";
  } else {
    ptx = HIPRTCCompile(code, false);
  }
  return ROCMModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code, std::string());
}

TVM_REGISTER_GLOBAL("target.build.tilelang_hip")
    .set_body_typed(BuildTileLangHIP);

} // namespace codegen
} // namespace tvm
@@ -11,13 +11,17 @@
namespace tvm {
namespace tl {

bool TargetIsCuda(Target target) {
  return target->GetTargetDeviceType() == kDLCUDA;
}
bool TargetIsRocm(Target target) {
  return target->GetTargetDeviceType() == kDLROCM;
}

int GetArchInt(Target target) {
  auto s = target->GetAttr<String>("arch");
  ICHECK(s.defined());
  const char *arch_str = s.value().c_str();
  ICHECK_EQ(arch_str[0], 's');
  ICHECK_EQ(arch_str[1], 'm');
  ICHECK_EQ(arch_str[2], '_');
@@ -25,31 +29,36 @@ int GetArchInt(Target target) {
}

bool TargetIsVolta(Target target) {
  if (!TargetIsCuda(target))
    return false;
  int arch = GetArchInt(target);
  return arch >= 70 && arch < 75;
}

bool TargetIsTuring(Target target) {
  if (!TargetIsCuda(target))
    return false;
  int arch = GetArchInt(target);
  return arch >= 75 && arch < 80;
}

bool TargetIsAmpere(Target target) {
  if (!TargetIsCuda(target))
    return false;
  int arch = GetArchInt(target);
  return arch >= 80 && arch < 90;
}

bool TargetIsHopper(Target target) {
  if (!TargetIsCuda(target))
    return false;
  int arch = GetArchInt(target);
  return arch >= 90;
}

bool TargetIsCDNA(Target target) {
  if (!TargetIsRocm(target))
    return false;
  if (target->attrs.count("mcpu")) {
    std::string mcpu = Downcast<String>(target->attrs.at("mcpu"));
    // if mcpu starts with "gfx9", it is CDNA
@@ -78,16 +87,18 @@ bool TargetHasAsyncCopy(Target target) {
  return false;
}

bool TargetHasLdmatrix(Target target) {
  if (!TargetIsCuda(target))
    return false;
  int arch = GetArchInt(target);
  return arch >= 75;
}

bool TargetHasStmatrix(Target target) {
  if (!TargetIsCuda(target))
    return false;
  int arch = GetArchInt(target);
  return arch >= 90;
}

} // namespace tl
} // namespace tvm
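
A hedged sketch of how these predicates are consumed when choosing a code path; the target string and the branch bodies are illustrative assumptions, not part of this file:

// Illustrative dispatch on queried capabilities (assumed target string).
Target target("cuda -arch=sm_90");
if (TargetHasStmatrix(target)) {
  // Hopper and newer: stmatrix-based shared-memory stores are available.
} else if (TargetHasLdmatrix(target)) {
  // Turing and newer: ldmatrix-based shared-memory loads are available.
}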
@@ -23,12 +23,12 @@ bool TargetIsTuring(Target target);
bool TargetIsAmpere(Target target);
bool TargetIsHopper(Target target);
bool TargetIsCDNA(Target target);

bool TargetHasAsyncCopy(Target target);
bool TargetHasLdmatrix(Target target);
bool TargetHasStmatrix(Target target);

} // namespace tl
} // namespace tvm

#endif // TVM_TL_TARGET_UTILS_H_
@@ -25,56 +25,57 @@ using cutlass::tfloat32_t;

// Pack two half values.
TL_DEVICE unsigned __pack_half2(const half x, const half y) {
  unsigned v0 = *((unsigned short *)&x);
  unsigned v1 = *((unsigned short *)&y);
  return (v1 << 16) | v0;
}

// Pack two half_t values.
TL_DEVICE unsigned __pack_half2(const half_t x, const half_t y) {
  unsigned v0 = *((unsigned short *)&x);
  unsigned v1 = *((unsigned short *)&y);
  return (v1 << 16) | v0;
}

// Pack two bfloat16_t values.
TL_DEVICE unsigned __pack_half2(const bfloat16_t x, const bfloat16_t y) {
  unsigned v0 = *((unsigned short *)&x);
  unsigned v1 = *((unsigned short *)&y);
  return (v1 << 16) | v0;
}

/// Helper to cast SMEM pointer to unsigned
TL_DEVICE uint32_t smem_ptr_to_uint(void const *const ptr) {
  return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
}

// AtomicAdd for FP16: forwards to the native half overload from cuda_fp16.
TL_DEVICE void atomicAdd(half_t *address, half_t val) {
  atomicAdd(reinterpret_cast<half *>(address), static_cast<half>(val));
}

// AtomicAdd for FP16, taking the value by pointer.
TL_DEVICE void atomicAdd(half_t *address, half_t *val) {
  atomicAdd(reinterpret_cast<half *>(address), static_cast<half>(*val));
}

// Vectorized AtomicAdd of two packed FP16 values via half2.
TL_DEVICE void atomicAddx2(half_t *address, half_t *val) {
  atomicAdd(reinterpret_cast<half2 *>(address),
            static_cast<half2>(*reinterpret_cast<half2 *>(val)));
}

// AtomicAdd of a float into an FP16 address, converting first.
TL_DEVICE void atomicAdd(half_t *address, float val) {
  atomicAdd(reinterpret_cast<half *>(address), __float2half(val));
}

// DP4A: four-way int8 dot product with 32-bit accumulate.
template <typename InDatatype, typename OutDatatype>
TL_DEVICE void DP4A(InDatatype *a, InDatatype *b, OutDatatype *c) {
  const int a_int = *((int *)a);
  const int b_int = *((int *)b);
  const int c_int = *((int *)c);
  *c = __dp4a(a_int, b_int, c_int);
}
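
A hedged device-side usage sketch for DP4A, assuming four int8 lanes per operand and an int32 accumulator; the kernel context and names are illustrative:

// Illustrative device code: c = dot(a, b) + c via one __dp4a instruction.
__device__ void dp4a_example() {
  int8_t a[4] = {1, 2, 3, 4};
  int8_t b[4] = {5, 6, 7, 8};
  int32_t c = 0;
  DP4A(a, b, &c); // c = 1*5 + 2*6 + 3*7 + 4*8 = 70
}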
@@ -10,10 +10,11 @@
namespace tl {

TL_DEVICE void cp_async_commit() {
  asm volatile("cp.async.commit_group;\n" ::);
}

template <int N> TL_DEVICE void cp_async_wait() {
  if constexpr (N == 0) {
    asm volatile("cp.async.wait_all;\n" ::);
  } else {
@@ -22,7 +23,7 @@ TL_DEVICE void cp_async_wait() {
}

template <int N>
TL_DEVICE void cp_async_gs(void const *const smem_addr, void *global_ptr) {
  static_assert(N == 16 || N == 8 || N == 4);
  unsigned int addr = smem_ptr_to_uint(smem_addr);
  if constexpr (N == 16) {
@@ -33,7 +34,7 @@ TL_DEVICE void cp_async_gs(void const *const smem_addr, void *global_ptr) {
        "cp.async.cg.shared.global [%0], [%1], %2;"
#endif
        ::"r"(addr),
        "l"((void *)(global_ptr)), "n"(N));
  } else {
    __asm__ __volatile__(
#if TL_ENABLE_L2_PREFETCH
@@ -42,12 +43,13 @@ TL_DEVICE void cp_async_gs(void const *const smem_addr, void *global_ptr) {
        "cp.async.ca.shared.global [%0], [%1], %2;"
#endif
        ::"r"(addr),
        "l"((void *)(global_ptr)), "n"(N));
  }
}

template <int N>
TL_DEVICE void cp_async_gs_conditional(void const *const smem_addr,
                                       void *global_ptr, bool cond) {
  static_assert(N == 16 || N == 8 || N == 4);
  int bytes = cond ? N : 0;
  unsigned int addr = smem_ptr_to_uint(smem_addr);
@@ -59,7 +61,7 @@
        "cp.async.cg.shared.global [%0], [%1], %2, %3;"
#endif
        ::"r"(addr),
        "l"((void *)(global_ptr)), "n"(N), "r"(bytes));
  } else {
    __asm__ __volatile__(
#if TL_ENABLE_L2_PREFETCH
@@ -68,8 +70,8 @@
        "cp.async.ca.shared.global [%0], [%1], %2, %3;"
#endif
        ::"r"(addr),
        "l"((void *)(global_ptr)), "n"(N), "r"(bytes));
  }
}

} // namespace tl
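
A hedged kernel-side sketch of the intended cp.async pipeline: issue the asynchronous copy, commit the group, wait, then synchronize before reading shared memory; the buffer names and sizes are illustrative:

// Illustrative device code: stage one 16-byte vector per thread.
__global__ void stage_example(const float *gmem) {
  __shared__ float smem[512];
  int i = threadIdx.x * 4; // 4 floats == 16 bytes per thread
  tl::cp_async_gs<16>(&smem[i], const_cast<float *>(gmem + i));
  tl::cp_async_commit();  // close the in-flight copy group
  tl::cp_async_wait<0>(); // wait for all outstanding groups
  __syncthreads();
  // smem is now safe to read
}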