Commit fa511857 authored by Lei Wang, committed by GitHub

[Lint] Overall Typo and Linting Fixes (#13)

* README.md fixed

* update test ci

* Lint and Typo Fix

* Clang Format Lint Fix
parent be55163f
......@@ -20,13 +20,14 @@ using namespace tir;
TIR_REGISTER_TL_OP(RegionOp, region)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
std::unique_ptr<Operator> ParseOperator(Call call, BufferMap vmap) {
auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder");
Op op = call->op.as<Op>().value();
if (op_map.count(op)) {
Operator* ptr = static_cast<Operator*>(op_map[op](call->args, vmap));
Operator *ptr = static_cast<Operator *>(op_map[op](call->args, vmap));
ICHECK(ptr != nullptr);
return std::unique_ptr<Operator>(ptr);
}
......@@ -41,7 +42,7 @@ std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap) {
return nullptr;
}
Var GetVarFromAccessPtr(const PrimExpr& expr) {
Var GetVarFromAccessPtr(const PrimExpr &expr) {
auto call = expr.as<CallNode>();
ICHECK(call);
ICHECK(call->op.same_as(builtin::tvm_access_ptr()));
......@@ -67,20 +68,27 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
bool RegionOp::IsFullRegion() const {
for (size_t i = 0; i < ranges_.size(); i++) {
if (!is_zero(ranges_[i]->min)) return false;
if (!StructuralEqual()(ranges_[i]->extent, buffer_->shape[i])) return false;
if (!is_zero(ranges_[i]->min))
return false;
if (!StructuralEqual()(ranges_[i]->extent, buffer_->shape[i]))
return false;
}
return true;
}
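As an illustration of the check above (the shape is hypothetical, not taken from this patch): for a [64, 64] buffer, ranges covering [0, 64) in both dimensions form a full region, while a range of [0, 32) in the second dimension fails the extent comparison and a range of [16, 48) in the first dimension fails the zero-min check.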
Stmt Operator::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
Stmt Operator::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
ICHECK(0) << "Not Implemented Lower method.";
return Evaluate(0);
}
Stmt Operator::Canonialize(const CanonializeArgs& T, arith::Analyzer* analyzer) const { return {}; }
Stmt Operator::Canonialize(const CanonializeArgs &T,
arith::Analyzer *analyzer) const {
return {};
}
LayoutMap Operator::InferLayout(const LayoutInferArgs& T, InferLevel level) { return {}; }
LayoutMap Operator::InferLayout(const LayoutInferArgs &T, InferLevel level) {
return {};
}
} // namespace tl
} // namespace tvm
......@@ -25,17 +25,19 @@ using namespace tir;
using AddWorkspaceCallback = std::function<PrimExpr(int, DataType)>;
using LayoutMap = Map<Buffer, Layout>;
using BufferMap = Map<Var, Buffer>;
using OpBuilderFunc = TypedPackedFunc<void*(Array<PrimExpr>, BufferMap)>;
using OpBuilderFunc = TypedPackedFunc<void *(Array<PrimExpr>, BufferMap)>;
#define TIR_REGISTER_TL_OP(Entry, OpName) \
const Op& Entry::Get() { \
static const Op& op = Op::Get("tl." #OpName); \
const Op &Entry::Get() { \
static const Op &op = Op::Get("tl." #OpName); \
return op; \
} \
TVM_REGISTER_OP("tl." #OpName) \
.set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName) \
.set_attr<OpBuilderFunc>( \
"TLOpBuilder", [](Array<PrimExpr> a, BufferMap b) { return (void*)(new Entry(a, b)); })
.set_attr<OpBuilderFunc>("TLOpBuilder", \
[](Array<PrimExpr> a, BufferMap b) { \
return (void *)(new Entry(a, b)); \
})
enum class InferLevel {
kFree = 0,
......@@ -64,30 +66,31 @@ struct CanonializeArgs {
};
class Operator {
public:
virtual Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const;
virtual Stmt Canonialize(const CanonializeArgs& T, arith::Analyzer* analyzer) const;
virtual LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level);
public:
virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
virtual Stmt Canonialize(const CanonializeArgs &T,
arith::Analyzer *analyzer) const;
virtual LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level);
virtual ~Operator() = default;
};
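To tie the interface and the registration macro together, here is a minimal sketch of adding a new tile op; MyOp and "my_op" are hypothetical names used only for illustration, mirroring the real RegionOp and ReduceOp registrations in this patch:

// Hypothetical example, not part of this commit.
class MyOp : public Operator {
public:
  MyOp(Array<PrimExpr> args, BufferMap vmap);
  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
  static const Op &Get();
};

// The macro expands to MyOp::Get() and registers "tl.my_op" together with
// a TLOpBuilder that constructs a MyOp from the call arguments.
TIR_REGISTER_TL_OP(MyOp, my_op)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));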
class RegionOp : public Operator {
public:
public:
RegionOp(Array<PrimExpr> args, BufferMap vmap);
static const Op& Get();
static const Op &Get();
const Buffer& GetBuffer() const { return buffer_; }
const Array<Range>& GetRanges() const { return ranges_; }
const Buffer &GetBuffer() const { return buffer_; }
const Array<Range> &GetRanges() const { return ranges_; }
int GetAccessMask() const { return access_mask_; }
bool IsFullRegion() const;
private:
private:
Buffer buffer_;
Array<Range> ranges_;
int access_mask_;
};
Var GetVarFromAccessPtr(const PrimExpr& expr);
Var GetVarFromAccessPtr(const PrimExpr &expr);
std::unique_ptr<Operator> ParseOperator(Call call, BufferMap vmap);
std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap);
......
......@@ -39,21 +39,22 @@ using namespace tir;
namespace attr {
/*! \brief Mark how the loop is vectorized. */
constexpr const char *coalesced_width = "coalesced_width";
}
} // namespace attr
class IfBufferRemapLoopGenerator : public StmtExprMutator {
public:
public:
static For run(Stmt stmt, Map<Buffer, Buffer> buffer_remap,
Map<Buffer, Layout> layout_map) {
IfBufferRemapLoopGenerator generator(buffer_remap, layout_map);
return Downcast<For>(generator(std::move(stmt)));
}
private:
IfBufferRemapLoopGenerator(Map<Buffer, Buffer> buffer_remap, Map<Buffer, Layout> layout_map)
private:
IfBufferRemapLoopGenerator(Map<Buffer, Buffer> buffer_remap,
Map<Buffer, Layout> layout_map)
: buffer_remap_(buffer_remap), layout_map_(layout_map) {}
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
if (buffer_remap_.count(load->buffer)) {
......@@ -65,7 +66,7 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {
return load;
}
Stmt VisitStmt_(const BufferStoreNode* op) final {
Stmt VisitStmt_(const BufferStoreNode *op) final {
auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
if (buffer_remap_.count(store->buffer)) {
auto new_indices = layout_map_[store->buffer]->Forward(store->indices);
......@@ -79,18 +80,20 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {
Map<Buffer, Layout> layout_map_;
};
void ParallelLoopNestVisitor::VisitStmt_(const ForNode* op) {
void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
ICHECK(op->kind == ForKind::kParallel);
p->loop_vars_.push_back(IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar));
p->loop_vars_.push_back(
IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar));
p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
StmtExprVisitor::VisitStmt_(op);
}
void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode* op) {
void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
if (op->buffer.scope() == "local.fragment") {
if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
<< op->buffer << ": " << op->indices << " and " << p->indice_map_.at(op->buffer);
<< op->buffer << ": " << op->indices << " and "
<< p->indice_map_.at(op->buffer);
} else {
p->indice_map_.Set(op->buffer, op->indices);
}
......@@ -99,11 +102,12 @@ void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode* op) {
StmtExprVisitor::VisitStmt_(op);
}
void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode* op) {
void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
if (op->buffer.scope() == "local.fragment") {
if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
<< op->buffer << ": " << op->indices << " and " << p->indice_map_.at(op->buffer);
<< op->buffer << ": " << op->indices << " and "
<< p->indice_map_.at(op->buffer);
} else {
p->indice_map_.Set(op->buffer, op->indices);
}
......@@ -113,18 +117,20 @@ void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode* op) {
ParallelOp::ParallelOp(For root) : root_(root), V(this) { V.VisitStmt(root); }
bool ParallelOp::IsCommonAccessIndice(const Buffer& buffer) const {
auto common_indice = loop_vars_.Map([](const auto& iv) { return iv->var; });
bool ParallelOp::IsCommonAccessIndice(const Buffer &buffer) const {
auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; });
return StructuralEqual()(indice_map_[buffer], common_indice);
}
LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
if (loop_layout_.defined()) return {};
if (level == InferLevel::kStrict) return {};
LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
if (loop_layout_.defined())
return {};
if (level == InferLevel::kStrict)
return {};
// Step 1: try to infer loop's partition from a source fragment
Buffer source_buffer, read_source_buffer;
for (const auto& [buffer, _] : indice_map_) {
for (const auto &[buffer, _] : indice_map_) {
if (T.layout_map.count(buffer)) {
auto frag = T.layout_map[buffer].as<Fragment>().value();
if (buffer_is_write_.count(buffer))
......@@ -133,14 +139,16 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
read_source_buffer = buffer;
}
}
auto compute_loop_layout_from_buffer = [&](const Buffer& buffer) {
auto compute_loop_layout_from_buffer = [&](const Buffer &buffer) {
Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
if (IsCommonAccessIndice(buffer)) {
return src_layout;
} else {
Var rep;
auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep, IterVarType::kDataPar);
PrimExpr loop_var_to_thread = src_layout->ForwardThread(indice_map_[buffer], rep);
auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep,
IterVarType::kDataPar);
PrimExpr loop_var_to_thread =
src_layout->ForwardThread(indice_map_[buffer], rep);
return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter);
}
};
......@@ -150,12 +158,14 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
if (read_source_buffer.defined()) {
loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
// The loop doesn't need to be replicated.
if (!is_one(loop_layout_->ReplicateExtent())) loop_layout_ = loop_layout_->DeReplicate();
if (!is_one(loop_layout_->ReplicateExtent()))
loop_layout_ = loop_layout_->DeReplicate();
// if it still has replication, add a condition
if (!is_one(loop_layout_->ReplicateExtent())) {
auto inv = loop_layout_->Inverse();
Array<PrimExpr> fwd;
for (size_t i = 0; i < loop_layout_->OutputDim(); i++) fwd.push_back(0);
for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
fwd.push_back(0);
fwd.push_back(InputPlaceholder(0));
auto rep = inv->Forward(fwd).back();
AddPredicate(EQ(rep, 0));
......@@ -163,17 +173,19 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
} else {
// The vectorize size must be aware of the buffer_remap,
// as the pass will post-process the layout.
auto maybe_remapped_root_ = IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
auto maybe_remapped_root_ =
IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
int vector_size = GetVectorizeSize(maybe_remapped_root_);
// Check if coalesced_width is defined
if (auto coalesced_width = root_->annotations.Get(tl::attr::coalesced_width)) {
if (const auto* imm = coalesced_width.as<IntImmNode>()) {
if (auto coalesced_width =
root_->annotations.Get(tl::attr::coalesced_width)) {
if (const auto *imm = coalesced_width.as<IntImmNode>()) {
int expected = imm->value;
// Verify that vector_size is divisible by expected
if (vector_size % expected != 0) {
LOG(FATAL) << "Vector size " << vector_size << " is not divisible by coalesced width "
<< expected;
LOG(FATAL) << "Vector size " << vector_size
<< " is not divisible by coalesced width " << expected;
}
vector_size = expected;
} else {
......@@ -184,31 +196,37 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
loop_layout_ = PlanLoopPartition(root_, T.block_size, vector_size);
}
PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();
if (!analyzer_.CanProveEqual(loop_thread_extent, static_cast<int>(T.block_size)))
if (!analyzer_.CanProveEqual(loop_thread_extent,
static_cast<int>(T.block_size)))
AddPredicate(LT(InputPlaceholder(0), loop_thread_extent));
} else {
return {};
}
// Step 2: Check that the loop's partition can correctly align with all source fragments
for (const auto& [buffer, _] : indice_map_) {
// Step 2: Check that the loop's partition can correctly align with all source
// fragments
for (const auto &[buffer, _] : indice_map_) {
if (T.layout_map.count(buffer)) {
auto fragment = T.layout_map[buffer].as<Fragment>().value();
// TODO: Add thread checks for replicated cases
// need to wildcard match the rhs with lhs
if (!is_one(loop_layout_->ReplicateExtent()) || !is_one(fragment->ReplicateExtent()))
if (!is_one(loop_layout_->ReplicateExtent()) ||
!is_one(fragment->ReplicateExtent()))
continue;
auto vars = loop_vars_.Map([](const IterVar& iv) { return PrimExpr(iv->var); });
auto vars =
loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
auto lhs = loop_layout_->ForwardThread(vars, NullOpt);
auto rhs = fragment->ForwardThread(indice_map_[buffer], NullOpt);
auto diff = analyzer_.Simplify(lhs - rhs);
ICHECK(is_zero(diff)) << "Layout infer conflict for " << buffer << " " << source_buffer
ICHECK(is_zero(diff))
<< "Layout infer conflict for " << buffer << " " << source_buffer
<< "\nLHS = " << lhs << "\nRHS = " << rhs;
}
}
// Step 3: Infer the other fragments' layouts from the loop's partition
LayoutMap results;
for (const auto& [buffer, _] : indice_map_) {
if (!T.layout_map.count(buffer)) results.Set(buffer, CompleteBufferFragment(buffer));
for (const auto &[buffer, _] : indice_map_) {
if (!T.layout_map.count(buffer))
results.Set(buffer, CompleteBufferFragment(buffer));
}
return results;
}
......@@ -221,18 +239,20 @@ Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const {
}
}
Fragment ParallelOp::CompleteBufferFragment(const Buffer& buffer) {
Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) {
ICHECK(loop_layout_.defined());
if (IsCommonAccessIndice(buffer)) return loop_layout_;
if (IsCommonAccessIndice(buffer))
return loop_layout_;
PrimExpr rep_b =
MakeFlattenedExpression(DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
PrimExpr rep_b = MakeFlattenedExpression(
DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
auto bijective_indice = indice_map_[buffer];
bijective_indice.push_back(rep_b);
Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();
PrimExpr indice_rep_extent = ind_inv->InputShape().back(); // this is the size of rep_b
PrimExpr indice_rep_extent =
ind_inv->InputShape().back(); // this is the size of rep_b
PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
......@@ -242,7 +262,8 @@ Fragment ParallelOp::CompleteBufferFragment(const Buffer& buffer) {
}
fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
PrimExpr thd_b = loop_layout_->ForwardThread(
ind_inv->Forward(fwd), FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
ind_inv->Forward(fwd),
FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent, NullOpt)
->CondenseReplicateVar();
......
......@@ -23,30 +23,30 @@ using namespace tir;
class ParallelOp;
class ParallelLoopNestVisitor : public StmtExprVisitor {
private:
ParallelLoopNestVisitor(ParallelOp* op) : p(op){};
void VisitStmt_(const ForNode* op) final;
void VisitStmt_(const BufferStoreNode* op) final;
void VisitExpr_(const BufferLoadNode* op) final;
private:
ParallelLoopNestVisitor(ParallelOp *op) : p(op){};
void VisitStmt_(const ForNode *op) final;
void VisitStmt_(const BufferStoreNode *op) final;
void VisitExpr_(const BufferLoadNode *op) final;
ParallelOp* p;
ParallelOp *p;
friend class ParallelOp;
};
class ParallelOp : public Operator {
public:
public:
ParallelOp(For root);
LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level) final;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
Fragment GetLoopLayout() const { return loop_layout_; }
For GetRoot() const { return root_; }
Map<Buffer, Array<PrimExpr>> GetIndiceMap() const { return indice_map_; }
Optional<PrimExpr> GetPredicate(Var thread_var) const;
private:
Fragment CompleteBufferFragment(const Buffer& buffer);
bool IsCommonAccessIndice(const Buffer& buffer) const;
private:
Fragment CompleteBufferFragment(const Buffer &buffer);
bool IsCommonAccessIndice(const Buffer &buffer) const;
void AddPredicate(PrimExpr expr) {
predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
}
......
......@@ -54,7 +54,7 @@ PrimExpr ReduceOp::MakeInitValue() const {
}
}
PrimExpr ReduceOp::MakeReduce(const PrimExpr& a, const PrimExpr& b) const {
PrimExpr ReduceOp::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
PrimExpr lhs = a, rhs = b;
if (lhs->dtype != rhs->dtype) {
rhs = Cast(lhs->dtype, rhs);
......@@ -90,8 +90,9 @@ std::string ReduceOp::MakeCodegenReducer() const {
}
}
Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
ICHECK(this->src.scope() == "local.fragment" && this->dst.scope() == "local.fragment")
Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
ICHECK(this->src.scope() == "local.fragment" &&
this->dst.scope() == "local.fragment")
<< "Reduce for shared memory not implemented.";
auto src_buffer = T.buffer_remap[this->src];
auto dst_buffer = T.buffer_remap[this->dst];
......@@ -101,20 +102,24 @@ Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
Array<IterVar> dst_vars;
for (size_t i = 0; i < dst_layout->InputDim(); i++) {
Var var = Var(std::string{char('i' + i)});
dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var, IterVarType::kDataPar));
dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var,
IterVarType::kDataPar));
}
Array<IterVar> src_vars = dst_vars;
src_vars.insert(src_vars.begin() + this->dim, {Range(0, src_layout->InputShape()[this->dim]),
Var("rv"), IterVarType::kDataPar});
Array<PrimExpr> src_indices =
src_layout->Forward(src_vars.Map([](const auto& iv) { return PrimExpr(iv->var); }));
Array<PrimExpr> dst_indices =
dst_layout->Forward(dst_vars.Map([](const auto& iv) { return PrimExpr(iv->var); }));
src_vars.insert(src_vars.begin() + this->dim,
{Range(0, src_layout->InputShape()[this->dim]), Var("rv"),
IterVarType::kDataPar});
Array<PrimExpr> src_indices = src_layout->Forward(
src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
Array<PrimExpr> dst_indices = dst_layout->Forward(
dst_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
Array<Stmt> stmts;
// make reduce-init stmt
if (this->clear) stmts.push_back(BufferStore(dst_buffer, this->MakeInitValue(), dst_indices));
if (this->clear)
stmts.push_back(
BufferStore(dst_buffer, this->MakeInitValue(), dst_indices));
// make thread-local reduce
Array<PrimExpr> src_indice_compressed;
......@@ -122,45 +127,50 @@ Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
for (size_t i = 0; i < src_layout->OutputDim(); i++) {
PrimExpr expr;
IterVar var;
std::tie(expr, var) =
CompressIterator(src_indices[i], src_vars, src_vars[this->dim]->var, analyzer);
std::tie(expr, var) = CompressIterator(src_indices[i], src_vars,
src_vars[this->dim]->var, analyzer);
src_indice_compressed.push_back(expr);
src_var_compressed.push_back(var);
}
Stmt reduce_local = BufferStore(dst_buffer,
Stmt reduce_local = BufferStore(
dst_buffer,
this->MakeReduce(BufferLoad(dst_buffer, dst_indices),
BufferLoad(src_buffer, src_indice_compressed)),
dst_indices);
for (int i = src_layout->OutputDim() - 1; i >= 0; i--) {
reduce_local =
For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent, ForKind::kUnrolled,
reduce_local, NullOpt, {{tir::attr::pragma_unroll_explicit, Bool(false)}});
For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent,
ForKind::kUnrolled, reduce_local, NullOpt,
{{tir::attr::pragma_unroll_explicit, Bool(false)}});
}
stmts.push_back(reduce_local);
// make inter-thread reduce
PrimExpr src_thread =
src_layout->ForwardThread(src_vars.Map([](const auto& iv) { return PrimExpr(iv->var); }), {});
auto iter_sum = arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer);
for (const auto& iter_split : iter_sum->args) {
PrimExpr src_thread = src_layout->ForwardThread(
src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }), {});
auto iter_sum =
arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer);
for (const auto &iter_split : iter_sum->args) {
auto mark = iter_split->source->source.as<Var>();
ICHECK(mark.defined());
if (mark.value().same_as(src_vars[this->dim]->var)) {
auto scale = as_const_int(iter_split->scale);
auto extent = as_const_int(iter_split->extent);
ICHECK(scale != nullptr && extent != nullptr);
if (*extent == 1) continue;
if (*extent == 1)
continue;
int reducing_threads = (*extent) * (*scale);
std::stringstream ss;
ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", " << reducing_threads << ", "
<< (*scale) << ">::run";
Array<PrimExpr> thread_reduce_args = {StringImm(ss.str()),
BufferLoad(dst_buffer, dst_indices)};
ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
<< reducing_threads << ", " << (*scale) << ">::run";
Array<PrimExpr> thread_reduce_args = {
StringImm(ss.str()), BufferLoad(dst_buffer, dst_indices)};
if (reducing_threads >= 32) {
PrimExpr workspace = T.AddWorkspace(T.block_size, dst_buffer->dtype);
thread_reduce_args.push_back(workspace);
}
auto call = Call(dst_buffer->dtype, builtin::call_extern(), thread_reduce_args);
auto call =
Call(dst_buffer->dtype, builtin::call_extern(), thread_reduce_args);
stmts.push_back(BufferStore(dst_buffer, call, dst_indices));
}
}
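For orientation, the inter-thread step above emits an extern call whose generated source looks roughly like the line below; REDUCER stands for whatever string MakeCodegenReducer() returns, and 32 / 1 are example values of reducing_threads and scale (with 32 threads the workspace argument is included):

// Sketch of the emitted CUDA source, not literal output of this patch.
dst[dst_idx] = tl::AllReduce<REDUCER, 32, 1>::run(dst[dst_idx], workspace);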
......@@ -170,15 +180,17 @@ Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
// make the outer spatial loop
Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0];
for (int i = dst_layout->InputDim() - 1; i >= 0; i--) {
body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent, ForKind::kParallel, body);
body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent,
ForKind::kParallel, body);
}
body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer, dst_layout);
return body;
}
LayoutMap ReduceOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
if (level >= InferLevel::kStrict) return {};
LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
if (level >= InferLevel::kStrict)
return {};
if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
T.layout_map.count(src) && !T.layout_map.count(dst)) {
auto src_layout = T.layout_map[src].as<Fragment>().value();
......@@ -197,10 +209,11 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
fwd.push_back(InputPlaceholder(i - 1));
}
}
auto thd =
src_layout->ForwardThread(fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
auto thd = src_layout->ForwardThread(
fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
Fragment dst_layout =
Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, NullOpt)->CondenseReplicateVar();
Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, NullOpt)
->CondenseReplicateVar();
return {{dst, dst_layout}};
}
return {};
......@@ -208,7 +221,8 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
TIR_REGISTER_TL_OP(ReduceOp, reduce)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm
\ No newline at end of file
......@@ -18,13 +18,13 @@ namespace tl {
using namespace tir;
class ReduceOp : public Operator {
public:
public:
ReduceOp(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level) final;
static const Op& Get();
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
static const Op &Get();
private:
private:
tir::Buffer src, dst;
int dim;
enum class ReduceType {
......@@ -36,7 +36,7 @@ class ReduceOp : public Operator {
bool clear;
PrimExpr MakeInitValue() const;
PrimExpr MakeReduce(const PrimExpr& a, const PrimExpr& b) const;
PrimExpr MakeReduce(const PrimExpr &a, const PrimExpr &b) const;
std::string MakeCodegenReducer() const;
};
......
......@@ -17,12 +17,12 @@ namespace tl {
using namespace runtime;
template <typename T>
static std::string ArrayToStr(const T* ptr, size_t n) {
template <typename T> static std::string ArrayToStr(const T *ptr, size_t n) {
std::stringstream ss;
ss << "[";
for (size_t i = 0; i < n; i++) {
if (i > 0) ss << ", ";
if (i > 0)
ss << ", ";
ss << ptr[i];
}
ss << "]";
......@@ -30,10 +30,10 @@ static std::string ArrayToStr(const T* ptr, size_t n) {
}
struct TensorMapArgs {
CUtensorMap* map;
CUtensorMap *map;
CUtensorMapDataType type;
cuuint32_t tensorRank;
void* globalAddress;
void *globalAddress;
cuuint64_t globalDim[5], globalStride[5];
cuuint32_t boxDim[5], elementStrides[5];
CUtensorMapInterleave interleave;
......@@ -45,8 +45,9 @@ struct TensorMapArgs {
TensorMapArgs T;
int idx = 0;
ICHECK(args.num_args >= 8);
T.map = reinterpret_cast<CUtensorMap*>(static_cast<void*>(args[idx++]));
T.type = static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++]));
T.map = reinterpret_cast<CUtensorMap *>(static_cast<void *>(args[idx++]));
T.type =
static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++]));
T.tensorRank = static_cast<cuuint32_t>(static_cast<int64_t>(args[idx++]));
T.globalAddress = args[idx++];
ICHECK(T.tensorRank >= 1 && T.tensorRank <= 5);
......@@ -63,10 +64,14 @@ struct TensorMapArgs {
for (size_t i = 0; i < T.tensorRank; i++) {
T.elementStrides[i] = static_cast<cuuint64_t>(args[idx++]);
}
T.interleave = static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++]));
T.swizzle = static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
T.l2Promotion = static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++]));
T.oobFill = static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++]));
T.interleave =
static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++]));
T.swizzle =
static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
T.l2Promotion =
static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++]));
T.oobFill =
static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++]));
return T;
}
......@@ -79,7 +84,8 @@ struct TensorMapArgs {
<< "globalDim " << ArrayToStr(globalDim, tensorRank) << std::endl
<< "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl
<< "boxDim " << ArrayToStr(boxDim, tensorRank) << std::endl
<< "elementStrides " << ArrayToStr(elementStrides, tensorRank) << std::endl
<< "elementStrides " << ArrayToStr(elementStrides, tensorRank)
<< std::endl
<< "interleave " << interleave << std::endl
<< "swizzle " << swizzle << std::endl
<< "l2Promotion " << l2Promotion << std::endl
......@@ -89,23 +95,26 @@ struct TensorMapArgs {
};
// Register the tensor-map creation routines as global packed functions.
TVM_REGISTER_GLOBAL(tvm_tensormap_create_tiled).set_body([](TVMArgs args, TVMRetValue* ret) {
TVM_REGISTER_GLOBAL(tvm_tensormap_create_tiled)
.set_body([](TVMArgs args, TVMRetValue *ret) {
TensorMapArgs T = TensorMapArgs::Extract(args);
CUresult result = cuTensorMapEncodeTiled(
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim, T.globalStride + 1, T.boxDim,
T.elementStrides, T.interleave, T.swizzle, T.l2Promotion, T.oobFill);
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave,
T.swizzle, T.l2Promotion, T.oobFill);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to initialize the TMA descriptor " << result << std::endl
LOG_FATAL << "Failed to initialize the TMA descriptor " << result
<< std::endl
<< T.ToDebugString();
}
*ret = static_cast<int>(result);
});
});
struct TensorMapIm2ColArgs {
CUtensorMap* map;
CUtensorMap *map;
CUtensorMapDataType type;
cuuint32_t tensorRank;
void* globalAddress;
void *globalAddress;
cuuint64_t globalDim[5], globalStride[5];
cuuint32_t elementStrides[5];
int pixelBoxLowerCorner[3], pixelBoxUpperCorner[3];
......@@ -119,8 +128,9 @@ struct TensorMapIm2ColArgs {
TensorMapIm2ColArgs T;
int idx = 0;
ICHECK(args.num_args >= 8);
T.map = reinterpret_cast<CUtensorMap*>(static_cast<void*>(args[idx++]));
T.type = static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++]));
T.map = reinterpret_cast<CUtensorMap *>(static_cast<void *>(args[idx++]));
T.type =
static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++]));
T.tensorRank = static_cast<cuuint32_t>(static_cast<int64_t>(args[idx++]));
T.globalAddress = args[idx++];
ICHECK(T.tensorRank >= 3 && T.tensorRank <= 5);
......@@ -142,10 +152,14 @@ struct TensorMapIm2ColArgs {
}
T.smem_box_pixel = static_cast<cuuint64_t>(args[idx++]);
T.smem_box_channel = static_cast<cuuint64_t>(args[idx++]);
T.interleave = static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++]));
T.swizzle = static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
T.l2Promotion = static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++]));
T.oobFill = static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++]));
T.interleave =
static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++]));
T.swizzle =
static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
T.l2Promotion =
static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++]));
T.oobFill =
static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++]));
return T;
}
......@@ -159,9 +173,12 @@ struct TensorMapIm2ColArgs {
<< "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl
<< "smem_box_pixel " << smem_box_pixel << std::endl
<< "smem_box_channel " << smem_box_channel << std::endl
<< "pixelBoxLowerCorner " << ArrayToStr(pixelBoxLowerCorner, tensorRank - 2) << std::endl
<< "pixelBoxUpperCorner " << ArrayToStr(pixelBoxUpperCorner, tensorRank - 2) << std::endl
<< "elementStrides " << ArrayToStr(elementStrides, tensorRank) << std::endl
<< "pixelBoxLowerCorner "
<< ArrayToStr(pixelBoxLowerCorner, tensorRank - 2) << std::endl
<< "pixelBoxUpperCorner "
<< ArrayToStr(pixelBoxUpperCorner, tensorRank - 2) << std::endl
<< "elementStrides " << ArrayToStr(elementStrides, tensorRank)
<< std::endl
<< "interleave " << interleave << std::endl
<< "swizzle " << swizzle << std::endl
<< "l2Promotion " << l2Promotion << std::endl
......@@ -170,18 +187,21 @@ struct TensorMapIm2ColArgs {
}
};
TVM_REGISTER_GLOBAL(tvm_tensormap_create_im2col).set_body([](TVMArgs args, TVMRetValue* ret) {
TVM_REGISTER_GLOBAL(tvm_tensormap_create_im2col)
.set_body([](TVMArgs args, TVMRetValue *ret) {
TensorMapIm2ColArgs T = TensorMapIm2ColArgs::Extract(args);
CUresult result = cuTensorMapEncodeIm2col(
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim, T.globalStride + 1,
T.pixelBoxLowerCorner, T.pixelBoxUpperCorner, T.smem_box_channel, T.smem_box_pixel,
T.elementStrides, T.interleave, T.swizzle, T.l2Promotion, T.oobFill);
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
T.globalStride + 1, T.pixelBoxLowerCorner, T.pixelBoxUpperCorner,
T.smem_box_channel, T.smem_box_pixel, T.elementStrides, T.interleave,
T.swizzle, T.l2Promotion, T.oobFill);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to initialize the TMA descriptor " << result << std::endl
LOG_FATAL << "Failed to initialize the TMA descriptor " << result
<< std::endl
<< T.ToDebugString();
}
*ret = static_cast<int>(result);
});
});
} // namespace tl
} // namespace tvm
......@@ -13,8 +13,10 @@
namespace tvm {
namespace tl {
constexpr const char* tvm_tensormap_create_tiled = "__tvm_tensormap_create_tiled";
constexpr const char* tvm_tensormap_create_im2col = "__tvm_tensormap_create_im2col";
constexpr const char *tvm_tensormap_create_tiled =
"__tvm_tensormap_create_tiled";
constexpr const char *tvm_tensormap_create_im2col =
"__tvm_tensormap_create_im2col";
} // namespace tl
} // namespace tvm
......
......@@ -6,9 +6,9 @@
*/
#include "codegen_cuda.h"
#include <tvm/tir/index_map.h>
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
#include <cmath>
......@@ -23,41 +23,51 @@
namespace tvm {
namespace codegen {
CodeGenTileLangCUDA::CodeGenTileLangCUDA() { restrict_keyword_ = "__restrict__"; }
CodeGenTileLangCUDA::CodeGenTileLangCUDA() {
restrict_keyword_ = "__restrict__";
}
void CodeGenTileLangCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ "; }
void CodeGenTileLangCUDA::PrintFuncPrefix(std::ostream &os) {
os << "extern \"C\" __global__ ";
}
class LaunchConfigExtractor : public tir::StmtVisitor {
private:
void VisitStmt_(const AttrStmtNode* op) final {
private:
void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->var->name_hint == "threadIdx.x" || iv->thread_tag == "threadIdx.x") {
if (iv->var->name_hint == "threadIdx.x" ||
iv->thread_tag == "threadIdx.x") {
threadIdx_x_ext = op->value;
} else if (iv->var->name_hint == "threadIdx.y" || iv->thread_tag == "threadIdx.y") {
} else if (iv->var->name_hint == "threadIdx.y" ||
iv->thread_tag == "threadIdx.y") {
threadIdx_y_ext = op->value;
} else if (iv->var->name_hint == "threadIdx.z" || iv->thread_tag == "threadIdx.z") {
} else if (iv->var->name_hint == "threadIdx.z" ||
iv->thread_tag == "threadIdx.z") {
threadIdx_z_ext = op->value;
}
}
StmtVisitor::VisitStmt_(op);
}
public:
public:
PrimExpr threadIdx_x_ext = Integer(1);
PrimExpr threadIdx_y_ext = Integer(1);
PrimExpr threadIdx_z_ext = Integer(1);
};
void CodeGenTileLangCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) {
void CodeGenTileLangCUDA::PrintExtraAttrs(const PrimFunc &f, std::ostream &os) {
LaunchConfigExtractor extractor;
extractor(f->body);
arith::Analyzer analyzer;
PrimExpr threadIdx_ext = analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext *
PrimExpr threadIdx_ext =
analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext *
extractor.threadIdx_z_ext);
if (const IntImmNode* const threadIdx_ext_int = threadIdx_ext.as<IntImmNode>()) {
if (const IntImmNode *const threadIdx_ext_int =
threadIdx_ext.as<IntImmNode>()) {
if (threadIdx_ext_int->value == 1) {
// unable to extract the number of threads per block, hence directly return
// unable to extract the number of threads per block, hence directly
// return
return;
}
stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
......@@ -77,19 +87,20 @@ std::string CodeGenTileLangCUDA::Finish() {
return CodeGenC::Finish();
}
void CodeGenTileLangCUDA::VisitStmt_(const tir::ForNode* op) {
void CodeGenTileLangCUDA::VisitStmt_(const tir::ForNode *op) {
if (op->kind == tir::ForKind::kUnrolled) {
PrintIndent();
stream << "#pragma unroll\n";
}
std::string extent = PrintExpr(arith::Analyzer().Simplify(op->extent + op->min));
std::string extent =
PrintExpr(arith::Analyzer().Simplify(op->extent + op->min));
PrintIndent();
std::string vid = AllocVarID(op->loop_var.get());
std::string start = PrintExpr(op->min);
stream << "for (";
PrintType(op->loop_var.dtype(), stream);
stream << ' ' << vid << " = " << start << "; " << vid << " < " << extent << "; ++" << vid
<< ") {\n";
stream << ' ' << vid << " = " << start << "; " << vid << " < " << extent
<< "; ++" << vid << ") {\n";
int for_scope = BeginScope();
PrintStmt(op->body);
this->EndScope(for_scope);
......@@ -97,12 +108,13 @@ void CodeGenTileLangCUDA::VisitStmt_(const tir::ForNode* op) {
stream << "}\n";
}
void CodeGenTileLangCUDA::BindThreadIndex(const IterVar& iv) {
void CodeGenTileLangCUDA::BindThreadIndex(const IterVar &iv) {
ICHECK(!var_idmap_.count(iv->var.get()));
var_idmap_[iv->var.get()] = CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype());
var_idmap_[iv->var.get()] =
CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype());
}
void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
ICHECK(t.is_scalar()) << "do not yet support vector types";
......@@ -153,7 +165,8 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*
// f8.v1 is emitted as *(float2*)(&(ul4.x)).x
// f8.v2 is emitted as *(float2*)(&(ul4.x)).y
//
ICHECK_EQ(lanes % 2, 0) << "only support even lane for float type with lanes > 4";
ICHECK_EQ(lanes % 2, 0)
<< "only support even lane for float type with lanes > 4";
os << "ulonglong" << lanes / 2;
} else {
fail = true;
......@@ -166,8 +179,10 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*
fail = true;
break;
}
if (!fail && (t.is_scalar() || t.bits() == 16)) return;
if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32)) return;
if (!fail && (t.is_scalar() || t.bits() == 16))
return;
if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32))
return;
if (!fail && (lanes >= 2 && lanes <= 4)) {
os << lanes;
return;
......@@ -181,18 +196,21 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*
} else {
fail = true;
}
if (!fail) return;
if (!fail)
return;
} else if (t.is_float8()) {
if (t.is_scalar()) {
os << "unsigned char"; // __nv_fp8_storage_t is an alias of unsigned char
} else if (lanes == 2) {
os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of unsigned short
os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of
// unsigned short
} else if (lanes == 4) {
os << "unsigned int"; // __nv_fp8x4_storage_t is an alias of unsigned int
} else {
fail = true;
}
if (!fail) return;
if (!fail)
return;
} else if (t == DataType::Bool()) {
os << "bool";
return;
......@@ -288,7 +306,8 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*
// s4.z is emitted as *(short2*)(&(i2.y)).x
// s4.w is emitted as *(short2*)(&(i2.y)).y
//
ICHECK_EQ(t.lanes() % 2, 0) << "only support even lane for short type with lanes > 4";
ICHECK_EQ(t.lanes() % 2, 0)
<< "only support even lane for short type with lanes > 4";
os << "int" << t.lanes() / 2;
} else {
fail = true;
......@@ -311,7 +330,8 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*
// i8.v1 is emitted as *(int2*)(&(l4.x)).x
// i8.v2 is emitted as *(int2*)(&(l4.x)).y
//
ICHECK_EQ(lanes % 2, 0) << "only support even lane for int32 type with lanes > 4";
ICHECK_EQ(lanes % 2, 0)
<< "only support even lane for int32 type with lanes > 4";
os << "longlong" << lanes / 2;
} else {
fail = true;
......@@ -348,8 +368,9 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type";
}
void CodeGenTileLangCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
std::ostream& os) { // NOLINT(*)
void CodeGenTileLangCUDA::PrintVecBinaryOp(const std::string &op, DataType t,
PrimExpr lhs, PrimExpr rhs,
std::ostream &os) { // NOLINT(*)
// Declare the result.
std::string sret = name_supply_->FreshName("_");
this->PrintIndent();
......@@ -383,15 +404,18 @@ void CodeGenTileLangCUDA::PrintVecBinaryOp(const std::string& op, DataType t, Pr
os << sret;
}
void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
std::ostream& os) { // NOLINT(*)
void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string &vec, DataType t,
int i,
std::ostream &os) { // NOLINT(*)
if (t.is_scalar()) {
os << vec;
return;
}
static const char access[] = {'x', 'y', 'z', 'w'};
ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4));
ICHECK(i >= 0 && i < (t.bits() == 8 ? 16
: (t.bits() == 16 || t.bits() == 32) ? 8
: 4));
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
std::string type_name = t.is_int() ? "char" : "unsigned char";
if (t.lanes() == 2 || t.lanes() == 3) {
......@@ -401,9 +425,11 @@ void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string& vec, DataType t, i
os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))";
}
} else if (t.is_float16()) {
os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2];
} else if (t.is_bfloat16()) {
os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2];
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
......@@ -422,20 +448,24 @@ void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string& vec, DataType t, i
}
}
ICHECK(!type_name.empty());
os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
<< ")))->" << access[i % 2];
} else {
os << vec << "." << access[i];
}
}
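As a concrete reading of the float16 branch above (v and the lane index are example values, not captured output): loading lane i = 3 of a half vector v prints ((half2*)(&(v.y)))->y, and the bfloat16 branch prints the same shape with nv_bfloat162 in place of half2.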
void CodeGenTileLangCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
const std::string& value) {
void CodeGenTileLangCUDA::PrintVecElemStore(const std::string &vec, DataType t,
int i, const std::string &value) {
this->PrintIndent();
static const char access[] = {'x', 'y', 'z', 'w'};
ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4));
ICHECK(i >= 0 && i < (t.bits() == 8 ? 16
: (t.bits() == 16 || t.bits() == 32) ? 8
: 4));
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (t.lanes() == 2 || t.lanes() == 3) {
stream << vec << '.' << access[i % t.lanes()] << "=" << "(" << value << ");\n";
stream << vec << '.' << access[i % t.lanes()] << "="
<< "(" << value << ");\n";
} else {
std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
stream << ac << "=";
......@@ -446,11 +476,11 @@ void CodeGenTileLangCUDA::PrintVecElemStore(const std::string& vec, DataType t,
stream << "(" << value << " << " << i % 4 * 8 << ");\n";
}
} else if (t.is_float16()) {
stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = "
<< value << ";\n";
stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2] << " = " << value << ";\n";
} else if (t.is_bfloat16()) {
stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]
<< " = " << value << ";\n";
stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2] << " = " << value << ";\n";
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
......@@ -469,15 +499,15 @@ void CodeGenTileLangCUDA::PrintVecElemStore(const std::string& vec, DataType t,
}
}
ICHECK(!type_name.empty());
stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2] << " = " << value << ";\n";
stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
<< ")))->" << access[i % 2] << " = " << value << ";\n";
} else {
stream << vec << "." << access[i] << " = " << value << ";\n";
}
}
void CodeGenTileLangCUDA::PrintStorageSync(const CallNode* op) {
const std::string& sync = op->args[0].as<StringImmNode>()->value;
void CodeGenTileLangCUDA::PrintStorageSync(const CallNode *op) {
const std::string &sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") {
// DO nothing.
} else if (sync == "shared" || sync == "shared.dyn") {
......@@ -486,8 +516,10 @@ void CodeGenTileLangCUDA::PrintStorageSync(const CallNode* op) {
}
}
void CodeGenTileLangCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
ICHECK_NE(scope, "global") << "Cannot allocate global memory when targeting CUDA. You must pass "
void CodeGenTileLangCUDA::PrintStorageScope(const std::string &scope,
std::ostream &os) { // NOLINT(*)
ICHECK_NE(scope, "global")
<< "Cannot allocate global memory when targeting CUDA. You must pass "
"all global arrays as input instead";
if (scope == "shared") {
os << "__shared__ ";
......@@ -496,13 +528,16 @@ void CodeGenTileLangCUDA::PrintStorageScope(const std::string& scope, std::ostre
}
}
std::string CodeGenTileLangCUDA::CastFromTo(std::string value, DataType from, DataType target) {
if (from == target) return value;
std::string CodeGenTileLangCUDA::CastFromTo(std::string value, DataType from,
DataType target) {
if (from == target)
return value;
std::ostringstream os;
os << "((";
this->PrintType(target, os);
os << ")";
if (from.is_float16() && (target.is_int() || target.is_uint()) && target.bits() == 8) {
if (from.is_float16() && (target.is_int() || target.is_uint()) &&
target.bits() == 8) {
os << "(";
if (target.is_uint()) {
os << "u";
......@@ -513,13 +548,14 @@ std::string CodeGenTileLangCUDA::CastFromTo(std::string value, DataType from, Da
return os.str();
}
void CodeGenTileLangCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
DataType from_ty = op->value.dtype();
DataType target_ty = op->dtype;
ICHECK_EQ(target_ty.lanes(), from_ty.lanes());
// Emit simple C-style type conversion.
if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os);
if (from_ty.is_scalar())
return CodeGenC::VisitExpr_(op, os);
// We could emit make_float4 like calls, but the emitted code looks
// too compact to read. Emit this as vectorized unary ops.
......@@ -542,8 +578,10 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
os << sret;
}
void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
bool skip_first_arg, std::ostream& os) { // NOLINT(*)
void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol,
const Array<PrimExpr> &args,
bool skip_first_arg,
std::ostream &os) { // NOLINT(*)
DataType ret_dtype = GetRuntimeDataType(ret_type);
if (ret_dtype.is_vector()) {
//
......@@ -583,7 +621,8 @@ void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol, c
std::ostringstream scall;
scall << global_symbol << "(";
for (size_t j = 0; j < sargs.size(); ++j) {
if (j > 0) scall << ", ";
if (j > 0)
scall << ", ";
PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall);
}
scall << ")";
......@@ -592,13 +631,16 @@ void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol, c
}
os << sret;
} else {
CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, os);
CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg,
os);
}
}
// Print a reference expression to a buffer.
std::string CodeGenTileLangCUDA::GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) {
const VarNode* buffer_var = buffer->data.get();
std::string CodeGenTileLangCUDA::GetBufferRef(DataType t,
const BufferNode *buffer,
PrimExpr index) {
const VarNode *buffer_var = buffer->data.get();
std::ostringstream os;
std::string vid = GetVarID(buffer_var);
std::string scope;
......@@ -654,12 +696,13 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t, const BufferNode* buff
return os.str();
}
void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
auto print_extern_call_stmt = [&](std::string name, size_t offset = 0) {
this->PrintIndent();
this->stream << name << "(";
for (size_t i = offset; i < op->args.size(); i++) {
if (i > offset) this->stream << ", ";
if (i > offset)
this->stream << ", ";
this->stream << this->PrintExpr(op->args[i]);
}
this->stream << ");\n";
......@@ -670,16 +713,18 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string src = this->PrintExpr(op->args[2]);
std::string src_offset = this->PrintExpr(op->args[3]);
std::string size = this->PrintExpr(op->args[4]);
// use size of argument list to indicate whether or not to use predicated cp.async
// use size of argument list to indicate whether or not to use predicated
// cp.async
if (op->args.size() == 5) {
this->PrintIndent();
this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+" << dst_offset << ", " << src
<< "+" << src_offset << ");\n";
this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+"
<< dst_offset << ", " << src << "+" << src_offset << ");\n";
} else {
std::string condition = this->PrintExpr(op->args[5]);
this->PrintIndent();
this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst << "+" << dst_offset
<< ", " << src << "+" << src_offset << ", " << condition << ");\n";
this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst
<< "+" << dst_offset << ", " << src << "+" << src_offset
<< ", " << condition << ");\n";
}
} else if (op->op.same_as(builtin::ptx_commit_group())) {
print_extern_call_stmt("tl::cp_async_commit");
......@@ -691,7 +736,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
this->PrintIndent();
int barrier_count = Downcast<IntImm>(op->args[0])->value;
std::string barrier_name = "_mbarrier";
this->stream << "__shared__ uint64_t " << barrier_name << "[" << barrier_count << "];\n";
this->stream << "__shared__ uint64_t " << barrier_name << "["
<< barrier_count << "];\n";
} else if (op->op.same_as(tl::GetMBarrierOp())) {
std::string barrier_name = "_mbarrier";
std::string barrier_id = this->PrintExpr(op->args[0]);
......@@ -720,13 +766,15 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value;
std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
if (trans == 1) func_name += "_trans";
if (trans == 1)
func_name += "_trans";
print_extern_call_stmt(func_name, 2);
} else if (op->op.same_as(tl::STMatrixOp())) {
int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value;
std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num);
if (trans == 1) func_name += "_trans";
if (trans == 1)
func_name += "_trans";
print_extern_call_stmt(func_name, 2);
} else if (op->op.same_as(tl::FenceProxyAsyncOp())) {
print_extern_call_stmt("tl::fence_proxy_async");
......@@ -734,15 +782,16 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
this->PrintIndent();
int nreg = Downcast<IntImm>(op->args[0])->value;
int is_inc = Downcast<IntImm>(op->args[1])->value;
std::string func_name = is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc";
std::string func_name =
is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc";
this->stream << func_name << "<" << std::to_string(nreg) << ">();\n";
} else if (op->op.same_as(tl::WaitWgmma())) {
this->PrintIndent();
int num_mma = Downcast<IntImm>(op->args[0])->value;
this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n";
} else if (op->op.same_as(tl::PackB16Op())) {
os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", " << this->PrintExpr(op->args[1])
<< ")";
os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", "
<< this->PrintExpr(op->args[1]) << ")";
} else if (op->op.same_as(builtin::tvm_fill_fragment())) {
need_mma_h_ = true;
ICHECK_EQ(op->args.size(), 6U);
......@@ -776,7 +825,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
this->PrintExpr(op->args[4], os);
os << "], ";
this->PrintExpr(op->args[6], os);
if (const StringImmNode* str = op->args[7].as<StringImmNode>()) {
if (const StringImmNode *str = op->args[7].as<StringImmNode>()) {
os << ", nvcuda::wmma::mem_" << str->value;
} else {
LOG(FATAL) << "Invalid parameters";
......@@ -831,10 +880,11 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string c_ref = this->PrintExpr(op->args[10]);
std::string c_bias = this->PrintExpr(op->args[11]);
bool saturate = Downcast<Bool>(op->args[12])->value;
std::string bit_op = op->args.size() > 13 ? Downcast<StringImm>(op->args[13])->value : "";
std::string asm_code =
PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, b_ref,
b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate);
std::string bit_op =
op->args.size() > 13 ? Downcast<StringImm>(op->args[13])->value : "";
std::string asm_code = PrintMMAAssembly(
shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias,
b_ref, b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate);
this->stream << asm_code;
} else if (op->op.same_as(builtin::ptx_mma_sp())) {
......@@ -872,8 +922,9 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string sparse_selector = this->PrintExpr(op->args[14]);
bool saturate = Downcast<Bool>(op->args[15])->value;
std::string asm_code = PrintMMAAssembly(
shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, b_ref, b_offset,
c_ref, c_offset, metadata, metadata_offset, sparse_selector, "", true, saturate);
shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset,
b_ref, b_offset, c_ref, c_offset, metadata, metadata_offset,
sparse_selector, "", true, saturate);
this->stream << asm_code;
} else if (op->op.same_as(builtin::ptx_ldmatrix())) {
// arg 0: whether the matrix is loaded in column major format or not.
......@@ -882,7 +933,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
// arg 3: pointer to local buffer.
// arg 4: The offset of the element to store in the local buffer.
// arg 5: pointer to the shared memory buffer to load.
// arg 6: The offset of the start element of the row to load in shared memory.
// arg 6: The offset of the start element of the row to load in shared
// memory.
ICHECK_EQ(op->args.size(), 7U);
bool trans = Downcast<Bool>(op->args[0])->value;
int num = Downcast<Integer>(op->args[1])->value;
......@@ -891,20 +943,23 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string local_elem_offset = this->PrintExpr(op->args[4]);
std::string smem_ptr = this->PrintExpr(op->args[5]);
if (trans && op->dtype.bits() == 8) {
// Since ldmatrix assumes that a matrix element is 16 bit, it cannot properly transpose an
// int8 matrix.
// Since ldmatrix assumes that a matrix element is 16 bit, it cannot
// properly transpose an int8 matrix.
std::string smem_stride = this->PrintExpr(op->args[6]);
ICHECK(num == 4);
os << "for (int i = 0; i < 16; ++i) {\n";
os << local_ptr << "[" + local_elem_offset + " + i] = " << smem_ptr
<< "[(i % 8) / 4 * " + smem_stride + " * 16 + (threadIdx.x % 4) * 4 * " + smem_stride +
"+ (i % 4) * " + smem_stride + " + threadIdx.x / 4 + (i / 8) * 8];\n";
<< "[(i % 8) / 4 * " + smem_stride +
" * 16 + (threadIdx.x % 4) * 4 * " + smem_stride +
"+ (i % 4) * " + smem_stride +
" + threadIdx.x / 4 + (i / 8) * 8];\n";
os << "}\n";
} else {
std::string smem_elem_offset = this->PrintExpr(op->args[6]);
need_cast_smem_ptr_to_int_ = true;
this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
smem_ptr, smem_elem_offset);
this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr,
local_elem_offset, smem_ptr,
smem_elem_offset);
}
} else if (op->op.same_as(builtin::mma_store())) {
int m = Downcast<Integer>(op->args[0])->value;
......@@ -914,17 +969,19 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string src_offset = this->PrintExpr(op->args[4]);
PrimExpr stride = op->args[5];
ICHECK(m == 16 && n == 16) << "Only m == 16 && n == 16 case supported for now";
ICHECK(m == 16 && n == 16)
<< "Only m == 16 && n == 16 case supported for now";
// Each thread in a warp holds a certain number of elements of an MMA output.
// For example, if we compute a 16x16 tile using MMA, each thread holds 8 elements
// in its registers. So conceptually, a warp memory is organized as a 32x8 block.
// A map from a 16x16 tile to a 32x8 block of memory is specified by the index map below.
// Each thread in a warp holds a certain number of elements of an MMA
// output. For example, if we compute a 16x16 tile using MMA, each thread
// holds 8 elements in its registers. So conceptually, a warp memory is
// organized as a 32x8 block. A map from a 16x16 tile to a 32x8 block of
// memory is specified by the index map below.
// To store the 32x8 output back to a 16x16 tile in shared or global memory, we invert this map
// to determine the output location for each of the 8 elements.
// To store the 32x8 output back to a 16x16 tile in shared or global memory,
// we invert this map to determine the output location for each of the 8
// elements.
const auto* index_map_func =
const auto *index_map_func =
runtime::Registry::Get("tir.index_map.shared_16x16_to_mma_32x8_layout");
IndexMap index_map;
......@@ -932,10 +989,10 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
Var i, j;
// The index map is defined as follows:
index_map = IndexMap({i, j}, {
4 * FloorMod(i, 8) + FloorDiv(FloorMod(j, 8), 2), 4 * FloorDiv(j, 8) + FloorDiv(i, 8) * 2 + FloorMod(j, 2)
});
} else{
index_map = IndexMap(
{i, j}, {4 * FloorMod(i, 8) + FloorDiv(FloorMod(j, 8), 2),
4 * FloorDiv(j, 8) + FloorDiv(i, 8) * 2 + FloorMod(j, 2)});
} else {
index_map = IndexMap::FromFunc(2, *index_map_func);
}
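For readers following the layout logic, here is a plain-C++ restatement (editor's sketch; the helper name is made up) of the fallback map spelled out in the branch above, together with one worked element.
// Maps element (i, j) of the 16x16 accumulator tile to a (lane, local_id)
// slot of the conceptual 32x8 per-warp register block.
inline void Shared16x16ToMma32x8Sketch(int i, int j, int *lane, int *local_id) {
  *lane = 4 * (i % 8) + (j % 8) / 2;
  *local_id = 4 * (j / 8) + (i / 8) * 2 + (j % 2);
}
// Example: element (9, 13) lands in lane 6 with local_id 7.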
......@@ -944,20 +1001,21 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
index_map.Inverse({Range(0, m), Range(0, n)}, &analyzer);
auto indices_16x16 = inverse_index_map->final_indices;
// "//" and "%" in the index map are translated to FloorDiv/Mod, but the plain Div/Mod are fine.
// FloorDiv/Mod are supposed to be lowered before they reach codegen, so manually replace them
// to the plain ones here.
// "//" and "%" in the index map are translated to FloorDiv/Mod, but the
// plain Div/Mod are fine. FloorDiv/Mod are supposed to be lowered before
// they reach codegen, so manually replace them to the plain ones here.
class LowerFloorDivMod : public ExprMutator {
public:
PrimExpr VisitExpr_(const FloorDivNode* op) {
PrimExpr VisitExpr_(const FloorDivNode *op) {
return tir::Div(this->VisitExpr(op->a), this->VisitExpr(op->b));
}
PrimExpr VisitExpr_(const FloorModNode* op) {
PrimExpr VisitExpr_(const FloorModNode *op) {
return tir::Mod(this->VisitExpr(op->a), this->VisitExpr(op->b));
}
};
auto dst_ind = LowerFloorDivMod()(indices_16x16[0] * stride + indices_16x16[1]);
auto dst_ind =
LowerFloorDivMod()(indices_16x16[0] * stride + indices_16x16[1]);
var_idmap_[inverse_index_map->initial_indices[0].get()] = "threadIdx.x";
var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id";
......@@ -967,8 +1025,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
<< " = "
<< "*((uint *)&" << src << "[" << src_offset << " + local_id]);\n";
os << "}\n";
}
else {
} else {
os << "for (int local_id = 0; local_id < 8; ++local_id) {\n";
os << dst << "[" + this->PrintExpr(dst_ind) + "]"
<< " = " << src << "[" << src_offset << " + local_id];\n";
......@@ -990,12 +1047,14 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string src_offset = this->PrintExpr(op->args[3]);
std::string size = this->PrintExpr(op->args[4]);
need_cast_smem_ptr_to_int_ = true;
// use size of argument list to indicate whether or not to use predicated cp.async
// use size of argument list to indicate whether or not to use predicated
// cp.async
if (op->args.size() == 5) {
this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size);
this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset,
size);
} else {
this->stream << PrintPredicatedCpAsyncAssembly(dst, dst_offset, src, src_offset, size,
this->PrintExpr(op->args[5]));
this->stream << PrintPredicatedCpAsyncAssembly(
dst, dst_offset, src, src_offset, size, this->PrintExpr(op->args[5]));
}
} else if (op->op.same_as(builtin::ptx_cp_async_bulk())) {
need_cast_smem_ptr_to_int_ = true;
......@@ -1006,44 +1065,52 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string size = this->PrintExpr(op->args[4]);
int barrier_id = Downcast<IntImm>(op->args[5])->value;
CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]";
this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, size, barrier);
std::string barrier =
barrier_name_ + "[" + std::to_string(barrier_id) + "]";
this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, size,
barrier);
} else if (op->op.same_as(builtin::ptx_commit_group())) {
this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n";
} else if (op->op.same_as(builtin::ptx_wait_group())) {
int n = Downcast<IntImm>(op->args[0])->value;
this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n << ";\");\n\n";
this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n
<< ";\");\n\n";
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
need_cast_smem_ptr_to_int_ = true;
int barrier_id = Downcast<IntImm>(op->args[0])->value;
CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]";
std::string barrier =
barrier_name_ + "[" + std::to_string(barrier_id) + "]";
this->stream << PrintCpAsyncBarrierAsm(barrier);
} else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
need_cast_smem_ptr_to_int_ = true;
int barrier_id = Downcast<IntImm>(op->args[0])->value;
CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]";
std::string barrier =
barrier_name_ + "[" + std::to_string(barrier_id) + "]";
std::string thread_count = this->PrintExpr(op->args[1]);
this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count);
} else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
need_cast_smem_ptr_to_int_ = true;
int barrier_id = Downcast<IntImm>(op->args[0])->value;
CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]";
std::string barrier =
barrier_name_ + "[" + std::to_string(barrier_id) + "]";
this->stream << PrintArriveBarrierAsm(barrier);
} else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
need_cast_smem_ptr_to_int_ = true;
int barrier_id = Downcast<IntImm>(op->args[0])->value;
CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]";
std::string barrier =
barrier_name_ + "[" + std::to_string(barrier_id) + "]";
std::string byte_count = this->PrintExpr(op->args[1]);
this->stream << PrintArriveBarrierExpectTxAsm(barrier, byte_count);
} else if (op->op.same_as(builtin::ptx_wait_barrier())) {
need_cast_smem_ptr_to_int_ = true;
int barrier_id = Downcast<IntImm>(op->args[0])->value;
CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]";
std::string barrier =
barrier_name_ + "[" + std::to_string(barrier_id) + "]";
this->stream << PrintWaitBarrierAsm(barrier);
} else if (op->op.same_as(builtin::create_barriers())) {
CHECK_EQ(barrier_count_, -1);
......@@ -1052,13 +1119,15 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
CHECK_EQ(barrier_alignment_bytes_ % sizeof(uint64_t), 0);
int barrier_alignment_count = barrier_alignment_bytes_ / sizeof(uint64_t);
if (barrier_count % barrier_alignment_count != 0) {
barrier_count = ((barrier_count / barrier_alignment_count) + 1) * barrier_alignment_count;
barrier_count = ((barrier_count / barrier_alignment_count) + 1) *
barrier_alignment_count;
}
barrier_count_ = barrier_count;
this->stream << "__shared__ __align__(" << barrier_alignment_bytes_ << ") uint64_t "
<< barrier_name_ << "[" << barrier_count << "];\n";
this->stream << "for (int i = 0; i < " << barrier_count << "; ++i) { " << barrier_name_
<< "[i] = 0; }\n";
this->stream << "__shared__ __align__(" << barrier_alignment_bytes_
<< ") uint64_t " << barrier_name_ << "[" << barrier_count
<< "];\n";
this->stream << "for (int i = 0; i < " << barrier_count << "; ++i) { "
<< barrier_name_ << "[i] = 0; }\n";
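A worked instance of the rounding above (editor's sketch; the constexpr helper is hypothetical and only mirrors the arithmetic).
// 16 alignment bytes / 8 bytes per uint64_t = units of 2 barriers, so a
// requested count of 5 pads to 6 and the emitted declaration becomes
//   __shared__ __align__(16) uint64_t <barrier_name>[6];
constexpr int RoundUpBarriersSketch(int count, int align_bytes) {
  int unit = align_bytes / 8; // 8 == sizeof(uint64_t)
  return (count % unit == 0) ? count : (count / unit + 1) * unit;
}
static_assert(RoundUpBarriersSketch(5, 16) == 6, "5 barriers pad up to 6");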
} else if (op->op.same_as(builtin::ptx_ldg32())) {
/*
asm volatile (
......@@ -1075,7 +1144,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string reg = this->PrintExpr(op->args[0]);
// get guard
std::string guard = this->PrintExpr(op->args[1]);
const BufferLoadNode* addr_buffer = op->args[2].as<BufferLoadNode>();
const BufferLoadNode *addr_buffer = op->args[2].as<BufferLoadNode>();
std::string global_addr = this->PrintExpr(addr_buffer->indices[0]);
std::string global_buffer = this->PrintExpr(addr_buffer->buffer->data);
std::string local_addr = this->PrintExpr(op->args[3]);
......@@ -1087,26 +1156,27 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
// stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ;
stream << ": \"=f\"(" << reg << "[" << local_addr << "]"
<< ")\n";
stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr << ")), \"r\"((int)"
<< guard << ")\n";
stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr
<< ")), \"r\"((int)" << guard << ")\n";
stream << ");\n";
} else {
CodeGenC::VisitExpr_(op, os);
}
}
void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode* op) {
void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode *op) {
if (op->attr_key == tir::attr::fragment_shape) {
const VarNode* buffer = op->node.as<VarNode>();
const StringImmNode* shape_str = op->value.as<StringImmNode>();
const VarNode *buffer = op->node.as<VarNode>();
const StringImmNode *shape_str = op->value.as<StringImmNode>();
fragment_shapes[buffer] = shape_str->value;
} else if (op->attr_key == tir::attr::fragment_layout) {
const VarNode* buffer = op->node.as<VarNode>();
const StringImmNode* layout_str = op->value.as<StringImmNode>();
const VarNode *buffer = op->node.as<VarNode>();
const StringImmNode *layout_str = op->value.as<StringImmNode>();
fragment_layouts[buffer] = layout_str->value;
} else if (op->attr_key == tir::attr::async_commit_queue_scope) {
const IntImmNode* queue_id = op->value.as<IntImmNode>();
ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0.";
const IntImmNode *queue_id = op->value.as<IntImmNode>();
ICHECK(queue_id && queue_id->value == 0)
<< "For CUDA, the index of an async queue must be 0.";
this->VisitStmt(op->body);
auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {});
this->VisitExpr(commit_group, this->stream);
......@@ -1114,9 +1184,11 @@ void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode* op) {
} else if (op->attr_key == tir::attr::async_wait_queue_scope) {
auto wait_attrs = GetAsyncWaitAttributes(op);
auto queue_id = wait_attrs.first.as<IntImmNode>();
ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0.";
ICHECK(queue_id && queue_id->value == 0)
<< "For CUDA, the index of an async queue must be 0.";
auto wait_cnt = wait_attrs.second;
auto wait_group = Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
auto wait_group =
Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
this->VisitExpr(wait_group, this->stream);
auto inner = op->body.as<AttrStmtNode>();
ICHECK(inner);
......@@ -1124,7 +1196,7 @@ void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode* op) {
return;
} else if (op->attr_key == "threadblock_swizzle_pattern") {
this->PrintIndent();
const StringImmNode* pattern = op->value.as<StringImmNode>();
const StringImmNode *pattern = op->value.as<StringImmNode>();
ICHECK(pattern);
this->stream << "const dim3 blockIdx = " << pattern->value << "();\n";
this->VisitStmt(op->body);
......@@ -1133,28 +1205,28 @@ void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode* op) {
CodeGenC::VisitStmt_(op);
}
void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode* op) {
void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
ICHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());
this->PrintIndent();
std::string scope = GetPtrStorageScope(op->buffer_var);
const VarNode* buffer = op->buffer_var.as<VarNode>();
const VarNode *buffer = op->buffer_var.as<VarNode>();
if (scope.find("wmma.") == 0) {
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) ||
op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) ||
op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1) ||
op->dtype == DataType::BFloat(16))
ICHECK(op->dtype == DataType::Float(16) ||
op->dtype == DataType::Int(8) || op->dtype == DataType::UInt(8) ||
op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
op->dtype == DataType::Int(1) || op->dtype == DataType::BFloat(16))
<< "Matrix_a and matrix_b only support half or char or unsigned char "
<< "or uint4 or int4 or int1 type for now";
} else {
ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) ||
op->dtype == DataType::Int(32))
ICHECK(op->dtype == DataType::Float(16) ||
op->dtype == DataType::Float(32) || op->dtype == DataType::Int(32))
<< "Accumulator only support half, float and int type for now";
}
PrintWmmaScope(scope, op->dtype, buffer, stream);
} else{
} else {
PrintStorageScope(scope, stream);
PrintType(op->dtype, stream);
}
......@@ -1163,7 +1235,8 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode* op) {
stream << ' ' << vid << "[];\n";
} else {
size_t constant_size = op->ConstantAllocationSize();
ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
ICHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now";
if (scope.find("wmma.") == 0) {
constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
}
......@@ -1179,7 +1252,7 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode* op) {
this->PrintStmt(op->body);
}
void CodeGenTileLangCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) {
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
CHECK_LE(lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed.";
os << "(make_";
......@@ -1188,16 +1261,19 @@ void CodeGenTileLangCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
for (int i = 0; i < lanes; i++) {
os << "(" << PrintExpr(op->base) << ")"
<< "+(" << PrintExpr(op->stride) << "*" << i << ")";
if (i != lanes - 1) os << ", ";
if (i != lanes - 1)
os << ", ";
}
os << "))";
}
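To make the printed form concrete (editor's illustration; the wrapper function is hypothetical): a 4-lane int ramp with base b and stride s is emitted as a make_int4 expression.
__device__ void RampEmissionSketch(int b, int s) {
  // What the loop above prints for lanes == 4 and an int32 element type:
  int4 v = make_int4((b) + (s * 0), (b) + (s * 1), (b) + (s * 2), (b) + (s * 3));
  (void)v;
}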
void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
std::ostream &os) { // NOLINT(*)
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && lanes == 4) {
if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 &&
lanes == 4) {
// make_int8x4
const int64_t* p = as_const_int(op->value);
const int64_t *p = as_const_int(op->value);
ICHECK(p);
int64_t v = *p & 0xFF;
v = (v << 24) | (v << 16) | (v << 8) | v;
......@@ -1215,7 +1291,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os)
PrintType(op->dtype, os);
os << '(';
for (int i = 0; i < lanes / 2; ++i) {
if (i != 0) os << ", ";
if (i != 0)
os << ", ";
os << "__pack_half2(" << v << ", " << v << ")";
}
os << ')';
......@@ -1228,18 +1305,21 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os)
PrintType(op->dtype, os);
os << '(';
for (int i = 0; i < lanes / 2; ++i) {
if (i != 0) os << ", ";
if (i != 0)
os << ", ";
os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
}
os << ')';
return;
}
if (op->dtype.is_float() && op->dtype.bits() == 32 && op->dtype.lanes() == 8) {
if (op->dtype.is_float() && op->dtype.bits() == 32 &&
op->dtype.lanes() == 8) {
std::string v = PrintExpr(op->value);
os << "make_ulonglong4(";
for (int i = 0; i < 4; ++i) {
if (i != 0) os << ", ";
if (i != 0)
os << ", ";
os << "*(unsigned long long*)&make_float2(" << v << ", " << v << ")";
}
os << ')';
......@@ -1248,7 +1328,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os)
if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) {
bool fail = false;
const int64_t* p = as_const_int(op->value);
const int64_t *p = as_const_int(op->value);
ICHECK(p);
int64_t v = *p & 0xF;
......@@ -1260,7 +1340,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os)
os << "(int16_t)" << v;
}
} else {
v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v;
v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) |
(v << 4) | v;
if (lanes == 8) {
if (op->dtype.is_uint()) {
os << "(uint)" << v;
......@@ -1272,7 +1353,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os)
PrintType(op->dtype, os);
os << '(';
for (int i = 0; i < lanes / 8; ++i) {
if (i != 0) os << ", ";
if (i != 0)
os << ", ";
if (op->dtype.is_uint()) {
os << "(uint)" << v;
} else {
......@@ -1295,13 +1377,15 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os)
PrintType(op->dtype, os);
os << '(';
for (int i = 0; i < lanes; ++i) {
if (i != 0) os << ", ";
if (i != 0)
os << ", ";
os << v;
}
os << ')';
}
inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLangCUDA* p) { // NOLINT(*)
inline void PrintConst(const FloatImmNode *op, std::ostream &os,
CodeGenTileLangCUDA *p) { // NOLINT(*)
// Type code is kBFloat
if (op->dtype.is_bfloat16()) {
os << "bfloat16_t";
......@@ -1322,7 +1406,8 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLang
temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN");
} else {
temp << std::scientific << op->value;
if (op->dtype.bits() == 32) temp << 'f';
if (op->dtype.bits() == 32)
temp << 'f';
}
p->MarkConst(temp.str());
os << temp.str();
......@@ -1340,16 +1425,18 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLang
}
}
void CodeGenTileLangCUDA::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
void CodeGenTileLangCUDA::VisitExpr_(const FloatImmNode *op,
std::ostream &os) { // NOLINT(*)
PrintConst(op, os, this);
}
void CodeGenTileLangCUDA::PrintWmmaScope(const std::string& scope, DataType t,
const VarNode* variable, std::ostream& os) {
void CodeGenTileLangCUDA::PrintWmmaScope(const std::string &scope, DataType t,
const VarNode *variable,
std::ostream &os) {
std::stringstream type;
PrintType(t, type);
ICHECK(fragment_shapes.count(variable)) << "Cannot find shape of the wmma fragment "
<< variable->name_hint;
ICHECK(fragment_shapes.count(variable))
<< "Cannot find shape of the wmma fragment " << variable->name_hint;
std::string shape_str = fragment_shapes.at(variable);
if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) {
type.str(std::string());
......@@ -1372,23 +1459,24 @@ void CodeGenTileLangCUDA::PrintWmmaScope(const std::string& scope, DataType t,
if (scope == "wmma.matrix_a") {
std::string layout_str = fragment_layouts[variable];
ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_a";
os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, " << shape_str << ", " << type.str()
<< ", nvcuda::wmma::" << layout_str << ">";
os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, " << shape_str << ", "
<< type.str() << ", nvcuda::wmma::" << layout_str << ">";
} else if (scope == "wmma.matrix_b") {
std::string layout_str = fragment_layouts[variable];
ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_b";
os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, " << shape_str << ", " << type.str()
<< ", nvcuda::wmma::" << layout_str << ">";
os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, " << shape_str << ", "
<< type.str() << ", nvcuda::wmma::" << layout_str << ">";
} else if (scope == "wmma.accumulator") {
os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, " << shape_str << ", " << type.str()
<< ">";
os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, " << shape_str
<< ", " << type.str() << ">";
}
}
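For orientation (editor's sketch, assuming <mma.h> is available, a "16, 16, 16" shape string, a half element type, and a row_major layout attribute): the matrix_a branch above prints a fragment type like this.
// Illustrative only; not literally emitted by this exact call.
using WmmaFragmentASketch =
    nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half,
                           nvcuda::wmma::row_major>;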
int32_t CodeGenTileLangCUDA::GetWmmaFragmentSize(const std::string& scope, const VarNode* variable,
int32_t CodeGenTileLangCUDA::GetWmmaFragmentSize(const std::string &scope,
const VarNode *variable,
int32_t size) {
ICHECK(fragment_shapes.count(variable)) << "Cannot find shape of the wmma fragment "
<< variable->name_hint;
ICHECK(fragment_shapes.count(variable))
<< "Cannot find shape of the wmma fragment " << variable->name_hint;
std::string shape_str = fragment_shapes.at(variable);
std::pair<int32_t, int32_t> dim = GetWmmaFragmentDimSize(shape_str, scope);
if (dim.first * dim.second != 0)
......@@ -1397,12 +1485,14 @@ int32_t CodeGenTileLangCUDA::GetWmmaFragmentSize(const std::string& scope, const
return 0;
}
void CodeGenTileLangCUDA::HandleVolatileLoads(const std::string& value, const BufferLoadNode* op,
std::ostream& os) {
void CodeGenTileLangCUDA::HandleVolatileLoads(const std::string &value,
const BufferLoadNode *op,
std::ostream &os) {
// Cast away volatile qualifier for fp16 types. That is, only loads and
// stores are volatile. The loaded objects are not marked as volatile.
//
if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer->data.get())) {
if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) &&
IsVolatile(op->buffer->data.get())) {
os << "(";
PrintType(op->dtype, os);
os << ")(" << value << ")";
......@@ -1411,15 +1501,17 @@ void CodeGenTileLangCUDA::HandleVolatileLoads(const std::string& value, const Bu
}
}
void CodeGenTileLangCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& value,
std::ostream& os) {
void CodeGenTileLangCUDA::PrintVecElemLoadExpr(DataType t, int i,
const std::string &value,
std::ostream &os) {
ICHECK_GT(t.lanes(), 1);
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (!(t.lanes() == 2 || t.lanes() == 3)) {
if (i != 0) {
os << "|";
}
os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))";
os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8
<< "))";
return;
}
}
......@@ -1476,7 +1568,7 @@ void CodeGenTileLangCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::str
return;
}
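One concrete instance of the expression assembled by PrintVecElemLoadExpr above (editor's sketch; the function is hypothetical and simply mirrors the printed form): packing four int8 lanes that all print as v.
__device__ int PackInt8x4Sketch(int v) {
  // Lane i contributes ((0x000000ff << i*8) & (v << i*8)); the lanes are
  // OR-ed together, exactly as the generated expression does.
  return ((0x000000ff << 0) & (v << 0)) | ((0x000000ff << 8) & (v << 8)) |
         ((0x000000ff << 16) & (v << 16)) | ((0x000000ff << 24) & (v << 24));
}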
void CodeGenTileLangCUDA::AddFunction(const PrimFunc& f) {
void CodeGenTileLangCUDA::AddFunction(const PrimFunc &f) {
// clear previous generated state.
this->InitFuncState(f);
// reserve keywords
......@@ -1495,10 +1587,11 @@ void CodeGenTileLangCUDA::AddFunction(const PrimFunc& f) {
for (size_t i = 0; i < f->params.size(); ++i) {
tir::Var v = f->params[i];
std::string vid = AllocVarID(v.get());
if (i != 0) stream << ", ";
if (i != 0)
stream << ", ";
if (v.dtype().is_handle()) {
// workaround for grid constant parameters.
if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
if (ptr->storage_scope == "grid_constant") {
stream << "__grid_constant__ const ";
CodeGenC::PrintType(ptr->element_type, stream);
......@@ -1513,8 +1606,8 @@ void CodeGenTileLangCUDA::AddFunction(const PrimFunc& f) {
}
CodeGenC::PrintType(GetType(v), stream);
if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
RegisterHandleType(v.get(), prim->dtype);
}
}
......
......@@ -21,50 +21,58 @@ namespace tvm {
namespace codegen {
class CodeGenTileLangCUDA final : public CodeGenC {
public:
public:
CodeGenTileLangCUDA();
std::string Finish();
// override behavior
void PrintFuncPrefix(std::ostream& os) final;
void PrintExtraAttrs(const PrimFunc& f, std::ostream& os) final;
void VisitStmt_(const ForNode* op) final;
void PrintStorageSync(const CallNode* op) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
std::ostream& os) final; // NOLINT(*)
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void PrintVecElemLoad(const std::string& vec, DataType t, int i,
std::ostream& os) final; // NOLINT(*)
void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final;
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) final;
std::string CastFromTo(std::string value, DataType from, DataType target) final;
void PrintFuncPrefix(std::ostream &os) final;
void PrintExtraAttrs(const PrimFunc &f, std::ostream &os) final;
void VisitStmt_(const ForNode *op) final;
void PrintStorageSync(const CallNode *op) final;
void PrintStorageScope(const std::string &scope,
std::ostream &os) final; // NOLINT(*)
void PrintVecBinaryOp(const std::string &op, DataType t, PrimExpr lhs,
PrimExpr rhs,
std::ostream &os) final; // NOLINT(*)
void PrintType(DataType t, std::ostream &os) final; // NOLINT(*)
void PrintVecElemLoad(const std::string &vec, DataType t, int i,
std::ostream &os) final; // NOLINT(*)
void PrintVecElemStore(const std::string &vec, DataType t, int i,
const std::string &value) final;
void BindThreadIndex(const IterVar &iv) final; // NOLINT(*)
void PrintVecElemLoadExpr(DataType t, int i, const std::string &value,
std::ostream &os) final;
std::string CastFromTo(std::string value, DataType from,
DataType target) final;
// overload visitor
void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;
void VisitExpr_(const CallNode* op, std::ostream& os) final;
void VisitExpr_(const CastNode* op, std::ostream& os) final;
void VisitStmt_(const AllocateNode* op) final;
void VisitStmt_(const AttrStmtNode* op) final;
void VisitExpr_(const RampNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode *op, std::ostream &os) final;
void VisitExpr_(const CallNode *op, std::ostream &os) final;
void VisitExpr_(const CastNode *op, std::ostream &os) final;
void VisitStmt_(const AllocateNode *op) final;
void VisitStmt_(const AttrStmtNode *op) final;
// Override this as a workaround for __grid_constant__ parameters
void AddFunction(const PrimFunc& f);
void AddFunction(const PrimFunc &f);
protected:
virtual std::string GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) final;
void PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
bool skip_first_arg, std::ostream& os) final; // NOLINT(*)
protected:
virtual std::string GetBufferRef(DataType t, const BufferNode *buffer,
PrimExpr index) final;
void PrintCallExtern(Type ret_type, String global_symbol,
const Array<PrimExpr> &args, bool skip_first_arg,
std::ostream &os) final; // NOLINT(*)
private:
private:
// Handle volatile loads
void HandleVolatileLoads(const std::string& value, const BufferLoadNode* op,
std::ostream& os) final;
void HandleVolatileLoads(const std::string &value, const BufferLoadNode *op,
std::ostream &os) final;
// Whether scope such as "__shared__" or "__constant__" is part of type.
bool IsScopePartOfType() const final { return false; }
friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLangCUDA* p);
friend void PrintConst(const FloatImmNode *op, std::ostream &os,
CodeGenTileLangCUDA *p);
// The size of the barrier array in shared memory
int barrier_count_ = -1;
// whether need mma.h
......@@ -77,12 +85,14 @@ class CodeGenTileLangCUDA final : public CodeGenC {
// Set to 16 to maintain minimum alignment requirements for async bulk copy
const int barrier_alignment_bytes_ = 16;
std::unordered_map<const VarNode*, std::string> fragment_shapes;
std::unordered_map<const VarNode*, std::string> fragment_layouts;
friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLangCUDA* p);
void PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable,
std::ostream& os);
int32_t GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, int32_t size);
std::unordered_map<const VarNode *, std::string> fragment_shapes;
std::unordered_map<const VarNode *, std::string> fragment_layouts;
friend void PrintConst(const FloatImmNode *op, std::ostream &os,
CodeGenTileLangCUDA *p);
void PrintWmmaScope(const std::string &scope, DataType t,
const VarNode *variable, std::ostream &os);
int32_t GetWmmaFragmentSize(const std::string &scope, const VarNode *variable,
int32_t size);
};
} // namespace codegen
......
......@@ -6,9 +6,9 @@
*/
#include "codegen_hip.h"
#include <tvm/tir/index_map.h>
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
#include <cmath>
......@@ -28,12 +28,13 @@ namespace codegen {
* \note should use std::format instead when codebase is ported to C++20.
*/
class Replacer {
public:
void register_rule(const std::string& pattern, const std::string& replacement) {
public:
void register_rule(const std::string &pattern,
const std::string &replacement) {
_rules.emplace_back(pattern, replacement);
}
std::string rewrite(std::string str) {
for (auto&& rule : _rules) {
for (auto &&rule : _rules) {
auto [pattern, replacement] = rule;
size_t len = pattern.size();
size_t new_len = replacement.size();
......@@ -47,46 +48,53 @@ class Replacer {
}
void empty_rules() { _rules.clear(); }
private:
private:
std::vector<std::pair<std::string, std::string>> _rules;
};
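A minimal usage sketch of Replacer (editor's illustration; the rule strings are hypothetical and only mirror the placeholder style of the mfma template later in this file).
inline std::string ReplacerUsageSketch(const std::string &templ) {
  Replacer r;
  // register_rule records a literal pattern -> replacement pair;
  // rewrite() then substitutes every occurrence in the input string.
  r.register_rule("{C_dytpe}", "float4");
  r.register_rule("{mfma_buildin}", "__builtin_amdgcn_mfma_f32_16x16x16f16");
  return r.rewrite(templ);
}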
CodeGenTileLangHIP::CodeGenTileLangHIP() { restrict_keyword_ = "__restrict__"; }
void CodeGenTileLangHIP::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ "; }
void CodeGenTileLangHIP::PrintFuncPrefix(std::ostream &os) {
os << "extern \"C\" __global__ ";
}
class LaunchConfigExtractor : public tir::StmtVisitor {
private:
void VisitStmt_(const AttrStmtNode* op) final {
private:
void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->var->name_hint == "threadIdx.x" || iv->thread_tag == "threadIdx.x") {
if (iv->var->name_hint == "threadIdx.x" ||
iv->thread_tag == "threadIdx.x") {
threadIdx_x_ext = op->value;
} else if (iv->var->name_hint == "threadIdx.y" || iv->thread_tag == "threadIdx.y") {
} else if (iv->var->name_hint == "threadIdx.y" ||
iv->thread_tag == "threadIdx.y") {
threadIdx_y_ext = op->value;
} else if (iv->var->name_hint == "threadIdx.z" || iv->thread_tag == "threadIdx.z") {
} else if (iv->var->name_hint == "threadIdx.z" ||
iv->thread_tag == "threadIdx.z") {
threadIdx_z_ext = op->value;
}
}
StmtVisitor::VisitStmt_(op);
}
public:
public:
PrimExpr threadIdx_x_ext = Integer(1);
PrimExpr threadIdx_y_ext = Integer(1);
PrimExpr threadIdx_z_ext = Integer(1);
};
void CodeGenTileLangHIP::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) {
void CodeGenTileLangHIP::PrintExtraAttrs(const PrimFunc &f, std::ostream &os) {
LaunchConfigExtractor extractor;
extractor(f->body);
arith::Analyzer analyzer;
PrimExpr threadIdx_ext = analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext *
PrimExpr threadIdx_ext =
analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext *
extractor.threadIdx_z_ext);
if (const IntImmNode* const threadIdx_ext_int = threadIdx_ext.as<IntImmNode>()) {
if (const IntImmNode *const threadIdx_ext_int =
threadIdx_ext.as<IntImmNode>()) {
if (threadIdx_ext_int->value == 1) {
// unable to extract the number of threads per block, hence directly return
// unable to extract the number of threads per block, hence directly
// return
return;
}
stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
......@@ -108,19 +116,20 @@ std::string CodeGenTileLangHIP::Finish() {
return CodeGenC::Finish();
}
void CodeGenTileLangHIP::VisitStmt_(const tir::ForNode* op) {
void CodeGenTileLangHIP::VisitStmt_(const tir::ForNode *op) {
if (op->kind == tir::ForKind::kUnrolled) {
PrintIndent();
stream << "#pragma unroll\n";
}
std::string extent = PrintExpr(arith::Analyzer().Simplify(op->extent + op->min));
std::string extent =
PrintExpr(arith::Analyzer().Simplify(op->extent + op->min));
PrintIndent();
std::string vid = AllocVarID(op->loop_var.get());
std::string start = PrintExpr(op->min);
stream << "for (";
PrintType(op->loop_var.dtype(), stream);
stream << ' ' << vid << " = " << start << "; " << vid << " < " << extent << "; ++" << vid
<< ") {\n";
stream << ' ' << vid << " = " << start << "; " << vid << " < " << extent
<< "; ++" << vid << ") {\n";
int for_scope = BeginScope();
PrintStmt(op->body);
this->EndScope(for_scope);
......@@ -128,12 +137,13 @@ void CodeGenTileLangHIP::VisitStmt_(const tir::ForNode* op) {
stream << "}\n";
}
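Editor's sketch of the emitted HIP source for an unrolled loop with min 0, extent 16 and a loop variable that AllocVarID happens to name i (all illustrative).
// Approximate output of the printer above:
#pragma unroll
for (int i = 0; i < 16; ++i) {
  // ... loop body statements ...
}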
void CodeGenTileLangHIP::BindThreadIndex(const IterVar& iv) {
void CodeGenTileLangHIP::BindThreadIndex(const IterVar &iv) {
ICHECK(!var_idmap_.count(iv->var.get()));
var_idmap_[iv->var.get()] = CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype());
var_idmap_[iv->var.get()] =
CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype());
}
void CodeGenTileLangHIP::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
void CodeGenTileLangHIP::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
ICHECK(t.is_scalar()) << "do not yet support vector types";
......@@ -184,7 +194,8 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
// f8.v1 is emitted as *(float2*)(&(ul4.x)).x
// f8.v2 is emitted as *(float2*)(&(ul4.x)).y
//
ICHECK_EQ(lanes % 2, 0) << "only support even lane for float type with lanes > 4";
ICHECK_EQ(lanes % 2, 0)
<< "only support even lane for float type with lanes > 4";
os << "ulonglong" << lanes / 2;
} else {
fail = true;
......@@ -197,8 +208,10 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
fail = true;
break;
}
if (!fail && (t.is_scalar() || t.bits() == 16)) return;
if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32)) return;
if (!fail && (t.is_scalar() || t.bits() == 16))
return;
if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32))
return;
if (!fail && (lanes >= 2 && lanes <= 4)) {
os << lanes;
return;
......@@ -212,18 +225,21 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
} else {
fail = true;
}
if (!fail) return;
if (!fail)
return;
} else if (t.is_float8()) {
if (t.is_scalar()) {
os << "unsigned char"; // __nv_fp8_storage_t is an alias of unsigned char
} else if (lanes == 2) {
os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of unsigned short
os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of
// unsigned short
} else if (lanes == 4) {
os << "unsigned int"; // __nv_fp8x4_storage_t is an alias of unsigned int
} else {
fail = true;
}
if (!fail) return;
if (!fail)
return;
} else if (t == DataType::Bool()) {
os << "bool";
return;
......@@ -319,7 +335,8 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
// s4.z is emitted as *(short2*)(&(i2.y)).x
// s4.w is emitted as *(short2*)(&(i2.y)).y
//
ICHECK_EQ(t.lanes() % 2, 0) << "only support even lane for short type with lanes > 4";
ICHECK_EQ(t.lanes() % 2, 0)
<< "only support even lane for short type with lanes > 4";
os << "int" << t.lanes() / 2;
} else {
fail = true;
......@@ -342,7 +359,8 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
// i8.v1 is emitted as *(int2*)(&(l4.x)).x
// i8.v2 is emitted as *(int2*)(&(l4.x)).y
//
ICHECK_EQ(lanes % 2, 0) << "only support even lane for int32 type with lanes > 4";
ICHECK_EQ(lanes % 2, 0)
<< "only support even lane for int32 type with lanes > 4";
os << "longlong" << lanes / 2;
} else {
fail = true;
......@@ -379,8 +397,9 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type";
}
void CodeGenTileLangHIP::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
std::ostream& os) { // NOLINT(*)
void CodeGenTileLangHIP::PrintVecBinaryOp(const std::string &op, DataType t,
PrimExpr lhs, PrimExpr rhs,
std::ostream &os) { // NOLINT(*)
// Declare the result.
std::string sret = name_supply_->FreshName("_");
this->PrintIndent();
......@@ -414,15 +433,18 @@ void CodeGenTileLangHIP::PrintVecBinaryOp(const std::string& op, DataType t, Pri
os << sret;
}
void CodeGenTileLangHIP::PrintVecElemLoad(const std::string& vec, DataType t, int i,
std::ostream& os) { // NOLINT(*)
void CodeGenTileLangHIP::PrintVecElemLoad(const std::string &vec, DataType t,
int i,
std::ostream &os) { // NOLINT(*)
if (t.is_scalar()) {
os << vec;
return;
}
static const char access[] = {'x', 'y', 'z', 'w'};
ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4));
ICHECK(i >= 0 && i < (t.bits() == 8 ? 16
: (t.bits() == 16 || t.bits() == 32) ? 8
: 4));
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
std::string type_name = t.is_int() ? "char" : "unsigned char";
if (t.lanes() == 2 || t.lanes() == 3) {
......@@ -432,9 +454,11 @@ void CodeGenTileLangHIP::PrintVecElemLoad(const std::string& vec, DataType t, in
os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))";
}
} else if (t.is_float16()) {
os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2];
} else if (t.is_bfloat16()) {
os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2];
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
......@@ -453,20 +477,24 @@ void CodeGenTileLangHIP::PrintVecElemLoad(const std::string& vec, DataType t, in
}
}
ICHECK(!type_name.empty());
os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
<< ")))->" << access[i % 2];
} else {
os << vec << "." << access[i];
}
}
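One concrete case of the element expressions printed above (editor's sketch; the helper is hypothetical): lane 1 of an int8x4 vector that is packed into a plain int.
__device__ char VecElemLoadInt8x4Sketch(int v) {
  // For 8-bit lanes with 4 elements the whole vector lives in an int and an
  // element is recovered by shifting its byte down, as printed above.
  return ((char)(v >> 8));
}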
void CodeGenTileLangHIP::PrintVecElemStore(const std::string& vec, DataType t, int i,
const std::string& value) {
void CodeGenTileLangHIP::PrintVecElemStore(const std::string &vec, DataType t,
int i, const std::string &value) {
this->PrintIndent();
static const char access[] = {'x', 'y', 'z', 'w'};
ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4));
ICHECK(i >= 0 && i < (t.bits() == 8 ? 16
: (t.bits() == 16 || t.bits() == 32) ? 8
: 4));
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (t.lanes() == 2 || t.lanes() == 3) {
stream << vec << '.' << access[i % t.lanes()] << "=" << "(" << value << ");\n";
stream << vec << '.' << access[i % t.lanes()] << "="
<< "(" << value << ");\n";
} else {
std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
stream << ac << "=";
......@@ -477,11 +505,11 @@ void CodeGenTileLangHIP::PrintVecElemStore(const std::string& vec, DataType t, i
stream << "(" << value << " << " << i % 4 * 8 << ");\n";
}
} else if (t.is_float16()) {
stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = "
<< value << ";\n";
stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2] << " = " << value << ";\n";
} else if (t.is_bfloat16()) {
stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]
<< " = " << value << ";\n";
stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2] << " = " << value << ";\n";
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
......@@ -500,15 +528,15 @@ void CodeGenTileLangHIP::PrintVecElemStore(const std::string& vec, DataType t, i
}
}
ICHECK(!type_name.empty());
stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2] << " = " << value << ";\n";
stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
<< ")))->" << access[i % 2] << " = " << value << ";\n";
} else {
stream << vec << "." << access[i] << " = " << value << ";\n";
}
}
void CodeGenTileLangHIP::PrintStorageSync(const CallNode* op) {
const std::string& sync = op->args[0].as<StringImmNode>()->value;
void CodeGenTileLangHIP::PrintStorageSync(const CallNode *op) {
const std::string &sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") {
// DO nothing.
} else if (sync == "shared" || sync == "shared.dyn") {
......@@ -517,8 +545,10 @@ void CodeGenTileLangHIP::PrintStorageSync(const CallNode* op) {
}
}
void CodeGenTileLangHIP::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
ICHECK_NE(scope, "global") << "Cannot allocate global memory when targeting CUDA. You must pass "
void CodeGenTileLangHIP::PrintStorageScope(const std::string &scope,
std::ostream &os) { // NOLINT(*)
ICHECK_NE(scope, "global")
<< "Cannot allocate global memory when targeting CUDA. You must pass "
"all global arrays as input instead";
if (scope == "shared") {
os << "__shared__ ";
......@@ -527,13 +557,16 @@ void CodeGenTileLangHIP::PrintStorageScope(const std::string& scope, std::ostrea
}
}
std::string CodeGenTileLangHIP::CastFromTo(std::string value, DataType from, DataType target) {
if (from == target) return value;
std::string CodeGenTileLangHIP::CastFromTo(std::string value, DataType from,
DataType target) {
if (from == target)
return value;
std::ostringstream os;
os << "((";
this->PrintType(target, os);
os << ")";
if (from.is_float16() && (target.is_int() || target.is_uint()) && target.bits() == 8) {
if (from.is_float16() && (target.is_int() || target.is_uint()) &&
target.bits() == 8) {
os << "(";
if (target.is_uint()) {
os << "u";
......@@ -544,13 +577,14 @@ std::string CodeGenTileLangHIP::CastFromTo(std::string value, DataType from, Dat
return os.str();
}
void CodeGenTileLangHIP::VisitExpr_(const CastNode* op, std::ostream& os) {
void CodeGenTileLangHIP::VisitExpr_(const CastNode *op, std::ostream &os) {
DataType from_ty = op->value.dtype();
DataType target_ty = op->dtype;
ICHECK_EQ(target_ty.lanes(), from_ty.lanes());
// Emit simple C-style type conversion.
if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os);
if (from_ty.is_scalar())
return CodeGenC::VisitExpr_(op, os);
// We could emit make_float4 like calls, but the emitted code looks
// too compact to read. Emit this as vectorized unary ops.
......@@ -573,8 +607,10 @@ void CodeGenTileLangHIP::VisitExpr_(const CastNode* op, std::ostream& os) {
os << sret;
}
void CodeGenTileLangHIP::PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
bool skip_first_arg, std::ostream& os) { // NOLINT(*)
void CodeGenTileLangHIP::PrintCallExtern(Type ret_type, String global_symbol,
const Array<PrimExpr> &args,
bool skip_first_arg,
std::ostream &os) { // NOLINT(*)
DataType ret_dtype = GetRuntimeDataType(ret_type);
if (ret_dtype.is_vector()) {
//
......@@ -614,7 +650,8 @@ void CodeGenTileLangHIP::PrintCallExtern(Type ret_type, String global_symbol, co
std::ostringstream scall;
scall << global_symbol << "(";
for (size_t j = 0; j < sargs.size(); ++j) {
if (j > 0) scall << ", ";
if (j > 0)
scall << ", ";
PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall);
}
scall << ")";
......@@ -623,13 +660,16 @@ void CodeGenTileLangHIP::PrintCallExtern(Type ret_type, String global_symbol, co
}
os << sret;
} else {
CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, os);
CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg,
os);
}
}
// Print a reference expression to a buffer.
std::string CodeGenTileLangHIP::GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) {
const VarNode* buffer_var = buffer->data.get();
std::string CodeGenTileLangHIP::GetBufferRef(DataType t,
const BufferNode *buffer,
PrimExpr index) {
const VarNode *buffer_var = buffer->data.get();
std::ostringstream os;
std::string vid = GetVarID(buffer_var);
std::string scope;
......@@ -685,12 +725,13 @@ std::string CodeGenTileLangHIP::GetBufferRef(DataType t, const BufferNode* buffe
return os.str();
}
void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
auto print_extern_call_stmt = [&](std::string name, size_t offset = 0) {
this->PrintIndent();
this->stream << name << "(";
for (size_t i = offset; i < op->args.size(); i++) {
if (i > offset) this->stream << ", ";
if (i > offset)
this->stream << ", ";
this->stream << this->PrintExpr(op->args[i]);
}
this->stream << ");\n";
......@@ -701,16 +742,18 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string src = this->PrintExpr(op->args[2]);
std::string src_offset = this->PrintExpr(op->args[3]);
std::string size = this->PrintExpr(op->args[4]);
// use size of argument list to indicate whether or not to use predicated cp.async
// use size of argument list to indicate whether or not to use predicated
// cp.async
if (op->args.size() == 5) {
this->PrintIndent();
this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+" << dst_offset << ", " << src
<< "+" << src_offset << ");\n";
this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+"
<< dst_offset << ", " << src << "+" << src_offset << ");\n";
} else {
std::string condition = this->PrintExpr(op->args[5]);
this->PrintIndent();
this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst << "+" << dst_offset
<< ", " << src << "+" << src_offset << ", " << condition << ");\n";
this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst
<< "+" << dst_offset << ", " << src << "+" << src_offset
<< ", " << condition << ");\n";
}
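Editor's sketch of the generated calls for the two branches above (identifiers are placeholders for whatever the printed expressions are).
// Approximate emitted HIP source:
//   tl::cp_async_gs<16>(smem_dst + dst_off, gmem_src + src_off);
//   tl::cp_async_gs_conditional<16>(smem_dst + dst_off, gmem_src + src_off, pred);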
} else if (op->op.same_as(builtin::ptx_commit_group())) {
print_extern_call_stmt("tl::cp_async_commit");
......@@ -722,7 +765,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
this->PrintIndent();
int barrier_count = Downcast<IntImm>(op->args[0])->value;
std::string barrier_name = "_mbarrier";
this->stream << "__shared__ uint64_t " << barrier_name << "[" << barrier_count << "];\n";
this->stream << "__shared__ uint64_t " << barrier_name << "["
<< barrier_count << "];\n";
} else if (op->op.same_as(tl::GetMBarrierOp())) {
std::string barrier_name = "_mbarrier";
std::string barrier_id = this->PrintExpr(op->args[0]);
......@@ -751,13 +795,15 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value;
std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
if (trans == 1) func_name += "_trans";
if (trans == 1)
func_name += "_trans";
print_extern_call_stmt(func_name, 2);
} else if (op->op.same_as(tl::STMatrixOp())) {
int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value;
std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num);
if (trans == 1) func_name += "_trans";
if (trans == 1)
func_name += "_trans";
print_extern_call_stmt(func_name, 2);
} else if (op->op.same_as(tl::FenceProxyAsyncOp())) {
print_extern_call_stmt("tl::fence_proxy_async");
......@@ -765,15 +811,16 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
this->PrintIndent();
int nreg = Downcast<IntImm>(op->args[0])->value;
int is_inc = Downcast<IntImm>(op->args[1])->value;
std::string func_name = is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc";
std::string func_name =
is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc";
this->stream << func_name << "<" << std::to_string(nreg) << ">();\n";
} else if (op->op.same_as(tl::WaitWgmma())) {
this->PrintIndent();
int num_mma = Downcast<IntImm>(op->args[0])->value;
this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n";
} else if (op->op.same_as(tl::PackB16Op())) {
os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", " << this->PrintExpr(op->args[1])
<< ")";
os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", "
<< this->PrintExpr(op->args[1]) << ")";
} else if (op->op.same_as(builtin::tvm_fill_fragment())) {
need_mma_h_ = true;
ICHECK_EQ(op->args.size(), 6U);
......@@ -807,7 +854,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
this->PrintExpr(op->args[4], os);
os << "], ";
this->PrintExpr(op->args[6], os);
if (const StringImmNode* str = op->args[7].as<StringImmNode>()) {
if (const StringImmNode *str = op->args[7].as<StringImmNode>()) {
os << ", nvcuda::wmma::mem_" << str->value;
} else {
LOG(FATAL) << "Invalid parameters";
......@@ -833,7 +880,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
this->PrintExpr(op->args[i * 2 + 1], os);
os << "]" << ((i < 3) ? ", " : ")");
}
}else if (op->op.same_as(builtin::tvm_mfma())) {
} else if (op->op.same_as(builtin::tvm_mfma())) {
// arg 0: prefix: {otype}_16x16x16{itype}
// arg 1: A layout: row/col
// arg 2: B layout: row/col
......@@ -847,7 +894,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
// arg 10: C accumulator
// arg 11: C accumulator index
ICHECK(op->args.size() == 12U) << "Invalid number of arguments for tvm_mfma";
ICHECK(op->args.size() == 12U)
<< "Invalid number of arguments for tvm_mfma";
std::string prefix = Downcast<StringImm>(op->args[0])->value;
std::string A_layout = Downcast<StringImm>(op->args[1])->value;
std::string B_layout = Downcast<StringImm>(op->args[2])->value;
......@@ -860,7 +908,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string b_bias = this->PrintExpr(op->args[9]);
std::string c_ref = this->PrintExpr(op->args[10]);
std::string c_bias = this->PrintExpr(op->args[11]);
ICHECK(A_layout == "row" || B_layout == "row") << "Matrix core only supports row major";
ICHECK(A_layout == "row" || B_layout == "row")
<< "Matrix core only supports row major";
// map for dtype -> float32x4 -> float4
std::unordered_map<std::string, std::string> dtype_map = {
{"int8", "char"},
......@@ -873,8 +922,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
{"float16x4", "float16x4"},
{"bfloat16x4", "bfloat16x4"},
{"float32x4", "float32x4"},
{"float32x16", "float32x16"}
};
{"float32x16", "float32x16"}};
std::string call_mfma_code = R"({
*((({C_dytpe}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dytpe}*){a_ref}) + {a_bias}),
*((({B_dytpe}*){b_ref}) + {b_bias}),
......@@ -898,10 +946,11 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
}
}
void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode* op) {
void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode *op) {
if (op->attr_key == tir::attr::async_commit_queue_scope) {
const IntImmNode* queue_id = op->value.as<IntImmNode>();
ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0.";
const IntImmNode *queue_id = op->value.as<IntImmNode>();
ICHECK(queue_id && queue_id->value == 0)
<< "For CUDA, the index of an async queue must be 0.";
this->VisitStmt(op->body);
auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {});
this->VisitExpr(commit_group, this->stream);
......@@ -909,9 +958,11 @@ void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode* op) {
} else if (op->attr_key == tir::attr::async_wait_queue_scope) {
auto wait_attrs = GetAsyncWaitAttributes(op);
auto queue_id = wait_attrs.first.as<IntImmNode>();
ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0.";
ICHECK(queue_id && queue_id->value == 0)
<< "For CUDA, the index of an async queue must be 0.";
auto wait_cnt = wait_attrs.second;
auto wait_group = Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
auto wait_group =
Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
this->VisitExpr(wait_group, this->stream);
auto inner = op->body.as<AttrStmtNode>();
ICHECK(inner);
......@@ -919,7 +970,7 @@ void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode* op) {
return;
} else if (op->attr_key == "threadblock_swizzle_pattern") {
this->PrintIndent();
const StringImmNode* pattern = op->value.as<StringImmNode>();
const StringImmNode *pattern = op->value.as<StringImmNode>();
ICHECK(pattern);
this->stream << "const dim3 blockIdx = " << pattern->value << "();\n";
this->VisitStmt(op->body);
......@@ -928,7 +979,7 @@ void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode* op) {
CodeGenC::VisitStmt_(op);
}
void CodeGenTileLangHIP::VisitStmt_(const AllocateNode* op) {
void CodeGenTileLangHIP::VisitStmt_(const AllocateNode *op) {
ICHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());
......@@ -941,7 +992,8 @@ void CodeGenTileLangHIP::VisitStmt_(const AllocateNode* op) {
stream << ' ' << vid << "[];\n";
} else {
size_t constant_size = op->ConstantAllocationSize();
ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
ICHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now";
if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
op->dtype == DataType::Int(1)) &&
......@@ -955,7 +1007,7 @@ void CodeGenTileLangHIP::VisitStmt_(const AllocateNode* op) {
this->PrintStmt(op->body);
}
void CodeGenTileLangHIP::VisitExpr_(const RampNode* op, std::ostream& os) {
void CodeGenTileLangHIP::VisitExpr_(const RampNode *op, std::ostream &os) {
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
CHECK_LE(lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed.";
os << "(make_";
......@@ -964,16 +1016,19 @@ void CodeGenTileLangHIP::VisitExpr_(const RampNode* op, std::ostream& os) {
for (int i = 0; i < lanes; i++) {
os << "(" << PrintExpr(op->base) << ")"
<< "+(" << PrintExpr(op->stride) << "*" << i << ")";
if (i != lanes - 1) os << ", ";
if (i != lanes - 1)
os << ", ";
}
os << "))";
}
void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode *op,
std::ostream &os) { // NOLINT(*)
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && lanes == 4) {
if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 &&
lanes == 4) {
// make_int8x4
const int64_t* p = as_const_int(op->value);
const int64_t *p = as_const_int(op->value);
ICHECK(p);
int64_t v = *p & 0xFF;
v = (v << 24) | (v << 16) | (v << 8) | v;
......@@ -991,7 +1046,8 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) {
PrintType(op->dtype, os);
os << '(';
for (int i = 0; i < lanes / 2; ++i) {
if (i != 0) os << ", ";
if (i != 0)
os << ", ";
os << "__pack_half2(" << v << ", " << v << ")";
}
os << ')';
......@@ -1004,18 +1060,21 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) {
PrintType(op->dtype, os);
os << '(';
for (int i = 0; i < lanes / 2; ++i) {
if (i != 0) os << ", ";
if (i != 0)
os << ", ";
os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
}
os << ')';
return;
}
if (op->dtype.is_float() && op->dtype.bits() == 32 && op->dtype.lanes() == 8) {
if (op->dtype.is_float() && op->dtype.bits() == 32 &&
op->dtype.lanes() == 8) {
std::string v = PrintExpr(op->value);
os << "make_ulonglong4(";
for (int i = 0; i < 4; ++i) {
if (i != 0) os << ", ";
if (i != 0)
os << ", ";
os << "*(unsigned long long*)&make_float2(" << v << ", " << v << ")";
}
os << ')';
......@@ -1024,7 +1083,7 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) {
if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) {
bool fail = false;
const int64_t* p = as_const_int(op->value);
const int64_t *p = as_const_int(op->value);
ICHECK(p);
int64_t v = *p & 0xF;
......@@ -1036,7 +1095,8 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) {
os << "(int16_t)" << v;
}
} else {
v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v;
v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) |
(v << 4) | v;
if (lanes == 8) {
if (op->dtype.is_uint()) {
os << "(uint)" << v;
......@@ -1048,7 +1108,8 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) {
PrintType(op->dtype, os);
os << '(';
for (int i = 0; i < lanes / 8; ++i) {
if (i != 0) os << ", ";
if (i != 0)
os << ", ";
if (op->dtype.is_uint()) {
os << "(uint)" << v;
} else {
......@@ -1071,13 +1132,15 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) {
PrintType(op->dtype, os);
os << '(';
for (int i = 0; i < lanes; ++i) {
if (i != 0) os << ", ";
if (i != 0)
os << ", ";
os << v;
}
os << ')';
}
inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLangHIP* p) { // NOLINT(*)
inline void PrintConst(const FloatImmNode *op, std::ostream &os,
CodeGenTileLangHIP *p) { // NOLINT(*)
// Type code is kBFloat
if (op->dtype.is_bfloat16()) {
os << "bfloat16_t";
......@@ -1098,7 +1161,8 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLang
temp << ((op->dtype.bits() == 32) ? "HIPRT_NAN_F" : "HIPRT_NAN");
} else {
temp << std::scientific << op->value;
if (op->dtype.bits() == 32) temp << 'f';
if (op->dtype.bits() == 32)
temp << 'f';
}
p->MarkConst(temp.str());
os << temp.str();
......@@ -1116,16 +1180,19 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLang
}
}
void CodeGenTileLangHIP::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
void CodeGenTileLangHIP::VisitExpr_(const FloatImmNode *op,
std::ostream &os) { // NOLINT(*)
PrintConst(op, os, this);
}
void CodeGenTileLangHIP::HandleVolatileLoads(const std::string& value, const BufferLoadNode* op,
std::ostream& os) {
void CodeGenTileLangHIP::HandleVolatileLoads(const std::string &value,
const BufferLoadNode *op,
std::ostream &os) {
// Cast away volatile qualifier for fp16 types. That is, only loads and
// stores are volatile. The loaded objects are not marked as volatile.
//
if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer->data.get())) {
if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) &&
IsVolatile(op->buffer->data.get())) {
os << "(";
PrintType(op->dtype, os);
os << ")(" << value << ")";
......@@ -1134,15 +1201,17 @@ void CodeGenTileLangHIP::HandleVolatileLoads(const std::string& value, const Buf
}
}
void CodeGenTileLangHIP::PrintVecElemLoadExpr(DataType t, int i, const std::string& value,
std::ostream& os) {
void CodeGenTileLangHIP::PrintVecElemLoadExpr(DataType t, int i,
const std::string &value,
std::ostream &os) {
ICHECK_GT(t.lanes(), 1);
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (!(t.lanes() == 2 || t.lanes() == 3)) {
if (i != 0) {
os << "|";
}
os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))";
os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8
<< "))";
return;
}
}
......@@ -1199,7 +1268,7 @@ void CodeGenTileLangHIP::PrintVecElemLoadExpr(DataType t, int i, const std::stri
return;
}
void CodeGenTileLangHIP::AddFunction(const PrimFunc& f) {
void CodeGenTileLangHIP::AddFunction(const PrimFunc &f) {
  // Clear previously generated state.
this->InitFuncState(f);
// reserve keywords
......@@ -1218,10 +1287,11 @@ void CodeGenTileLangHIP::AddFunction(const PrimFunc& f) {
for (size_t i = 0; i < f->params.size(); ++i) {
tir::Var v = f->params[i];
std::string vid = AllocVarID(v.get());
if (i != 0) stream << ", ";
if (i != 0)
stream << ", ";
if (v.dtype().is_handle()) {
      // Workaround for grid constant parameters.
if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
if (ptr->storage_scope == "grid_constant") {
stream << "__grid_constant__ const ";
CodeGenC::PrintType(ptr->element_type, stream);
......@@ -1236,8 +1306,8 @@ void CodeGenTileLangHIP::AddFunction(const PrimFunc& f) {
}
CodeGenC::PrintType(GetType(v), stream);
if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
RegisterHandleType(v.get(), prim->dtype);
}
}
......
......@@ -21,50 +21,58 @@ namespace tvm {
namespace codegen {
class CodeGenTileLangHIP final : public CodeGenC {
public:
public:
CodeGenTileLangHIP();
std::string Finish();
// override behavior
void PrintFuncPrefix(std::ostream& os) final;
void PrintExtraAttrs(const PrimFunc& f, std::ostream& os) final;
void VisitStmt_(const ForNode* op) final;
void PrintStorageSync(const CallNode* op) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
std::ostream& os) final; // NOLINT(*)
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void PrintVecElemLoad(const std::string& vec, DataType t, int i,
std::ostream& os) final; // NOLINT(*)
void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final;
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) final;
std::string CastFromTo(std::string value, DataType from, DataType target) final;
void PrintFuncPrefix(std::ostream &os) final;
void PrintExtraAttrs(const PrimFunc &f, std::ostream &os) final;
void VisitStmt_(const ForNode *op) final;
void PrintStorageSync(const CallNode *op) final;
void PrintStorageScope(const std::string &scope,
std::ostream &os) final; // NOLINT(*)
void PrintVecBinaryOp(const std::string &op, DataType t, PrimExpr lhs,
PrimExpr rhs,
std::ostream &os) final; // NOLINT(*)
void PrintType(DataType t, std::ostream &os) final; // NOLINT(*)
void PrintVecElemLoad(const std::string &vec, DataType t, int i,
std::ostream &os) final; // NOLINT(*)
void PrintVecElemStore(const std::string &vec, DataType t, int i,
const std::string &value) final;
void BindThreadIndex(const IterVar &iv) final; // NOLINT(*)
void PrintVecElemLoadExpr(DataType t, int i, const std::string &value,
std::ostream &os) final;
std::string CastFromTo(std::string value, DataType from,
DataType target) final;
// overload visitor
void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;
void VisitExpr_(const CallNode* op, std::ostream& os) final;
void VisitExpr_(const CastNode* op, std::ostream& os) final;
void VisitStmt_(const AllocateNode* op) final;
void VisitStmt_(const AttrStmtNode* op) final;
void VisitExpr_(const RampNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode *op, std::ostream &os) final;
void VisitExpr_(const CallNode *op, std::ostream &os) final;
void VisitExpr_(const CastNode *op, std::ostream &os) final;
void VisitStmt_(const AllocateNode *op) final;
void VisitStmt_(const AttrStmtNode *op) final;
  // Override this as a workaround for the __grid_constant__ parameter
void AddFunction(const PrimFunc& f);
void AddFunction(const PrimFunc &f);
protected:
virtual std::string GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) final;
void PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
bool skip_first_arg, std::ostream& os) final; // NOLINT(*)
protected:
virtual std::string GetBufferRef(DataType t, const BufferNode *buffer,
PrimExpr index) final;
void PrintCallExtern(Type ret_type, String global_symbol,
const Array<PrimExpr> &args, bool skip_first_arg,
std::ostream &os) final; // NOLINT(*)
private:
private:
// Handle volatile loads
void HandleVolatileLoads(const std::string& value, const BufferLoadNode* op,
std::ostream& os) final;
void HandleVolatileLoads(const std::string &value, const BufferLoadNode *op,
std::ostream &os) final;
// Whether scope such as "__shared__" or "__constant__" is part of type.
bool IsScopePartOfType() const final { return false; }
friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLangHIP* p);
friend void PrintConst(const FloatImmNode *op, std::ostream &os,
CodeGenTileLangHIP *p);
// whether need math_constants.h
bool need_math_constants_h_{false};
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#include "runtime/cuda/cuda_module.h"
#include "codegen_cuda.h"
#include "runtime/cuda/cuda_module.h"
namespace tvm {
namespace codegen {
static std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(const IRModule& mod) {
static std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule &mod) {
std::unordered_map<std::string, runtime::FunctionInfo> fmap;
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<tir::PrimFuncNode>()) << "Can only lower IR Module with PrimFuncs";
ICHECK(kv.second->IsInstance<tir::PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs";
auto f = Downcast<tir::PrimFunc>(kv.second);
runtime::FunctionInfo info;
......@@ -26,7 +28,7 @@ static std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
info.arg_types.push_back(f->params[i].dtype());
}
if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
for (const auto& tag : opt.value()) {
for (const auto &tag : opt.value()) {
info.launch_param_tags.push_back(tag);
}
}
......@@ -43,7 +45,8 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
cg.Init(output_ssa);
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenTileLangCUDA: Can only take PrimFunc";
ICHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenTileLangCUDA: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
......@@ -51,14 +54,15 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
}
std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
if (const auto *f = Registry::Get("tvm_callback_cuda_postproc")) {
code = (*f)(code, target).operator std::string();
}
std::string fmt = "ptx";
std::string ptx;
if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) {
if (const auto *f = Registry::Get("tvm_callback_cuda_compile")) {
ptx = (*f)(code, target).operator std::string();
if (ptx[0] != '/') fmt = "cubin";
if (ptx[0] != '/')
fmt = "cubin";
} else {
ICHECK(0);
}
......@@ -72,7 +76,8 @@ String BuildTLDebug(IRModule mod, Target target) {
cg.Init(output_ssa);
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenTileLangCUDA: Can only take PrimFunc";
ICHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenTileLangCUDA: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
......@@ -80,14 +85,16 @@ String BuildTLDebug(IRModule mod, Target target) {
}
std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
if (const auto *f = Registry::Get("tvm_callback_cuda_postproc")) {
code = (*f)(code, target).operator std::string();
}
return String(code);
}
TVM_REGISTER_GLOBAL("target.build.tilelang_cuda").set_body_typed(BuildTileLangCUDA);
TVM_REGISTER_GLOBAL("target.build.tl_debug_codegen").set_body_typed(BuildTLDebug);
TVM_REGISTER_GLOBAL("target.build.tilelang_cuda")
.set_body_typed(BuildTileLangCUDA);
TVM_REGISTER_GLOBAL("target.build.tl_debug_codegen")
.set_body_typed(BuildTLDebug);
} // namespace codegen
} // namespace tvm
......@@ -8,13 +8,12 @@
#include <hip/hip_runtime.h>
#include <hip/hiprtc.h>
#include "runtime/rocm/rocm_module.h"
#include "codegen_hip.h"
#include "runtime/rocm/rocm_module.h"
namespace tvm {
namespace codegen {
#define HIPRTC_CALL(x) \
\
{ \
......@@ -24,7 +23,8 @@ namespace codegen {
if (result != HIPRTC_SUCCESS) { \
\
LOG(FATAL) \
<< "HiprtcError: " #x " failed with error: " << hiprtcGetErrorString(result); \
<< "HiprtcError: " #x " failed with error: " \
<< hiprtcGetErrorString(result); \
\
\
} \
......@@ -39,7 +39,7 @@ static std::string FindHIPIncludePath() {
const std::string delimiter = "/";
#endif
std::string hip_include_path;
const char* hip_path_env = std::getenv("HIP_PATH");
const char *hip_path_env = std::getenv("HIP_PATH");
if (hip_path_env != nullptr) {
hip_include_path += hip_path_env;
hip_include_path += delimiter + "include";
......@@ -58,19 +58,24 @@ static std::string FindHIPIncludePath() {
}
#endif
LOG(FATAL) << "Cannot find HIP include path."
<< "HIP_PATH is not set or ROCm is not installed in the default installation path."
<< "HIP_PATH is not set or ROCm is not installed in the default "
"installation path."
<< "In other than linux, it is necessary to set HIP_PATH.";
return hip_include_path;
}
static std::string HIPRTCCompile(const std::string& code, bool include_path = false) {
static std::string HIPRTCCompile(const std::string &code,
bool include_path = false) {
std::vector<std::string> compile_params;
std::vector<const char*> param_cstrings{};
std::vector<const char *> param_cstrings{};
hiprtcProgram prog;
std::string cc = "gfx900"; // Default target architecture (can be changed as needed)
std::string cc =
"gfx900"; // Default target architecture (can be changed as needed)
int major, minor;
hipError_t e1 = hipDeviceGetAttribute(&major, hipDeviceAttributeComputeCapabilityMajor, 0);
hipError_t e2 = hipDeviceGetAttribute(&minor, hipDeviceAttributeComputeCapabilityMinor, 0);
hipError_t e1 = hipDeviceGetAttribute(
&major, hipDeviceAttributeComputeCapabilityMajor, 0);
hipError_t e2 = hipDeviceGetAttribute(
&minor, hipDeviceAttributeComputeCapabilityMinor, 0);
if (e1 == hipSuccess && e2 == hipSuccess) {
cc = "gfx" + std::to_string(major * 100 + minor * 10);
......@@ -86,10 +91,11 @@ static std::string HIPRTCCompile(const std::string& code, bool include_path = fa
compile_params.push_back(include_option);
}
for (const auto& string : compile_params) {
for (const auto &string : compile_params) {
param_cstrings.push_back(string.c_str());
}
HIPRTC_CALL(hiprtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr));
HIPRTC_CALL(
hiprtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr));
hiprtcResult compile_res =
hiprtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data());
......@@ -110,11 +116,13 @@ static std::string HIPRTCCompile(const std::string& code, bool include_path = fa
return code_out;
}
static std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(const IRModule& mod) {
static std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule &mod) {
std::unordered_map<std::string, runtime::FunctionInfo> fmap;
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<tir::PrimFuncNode>()) << "Can only lower IR Module with PrimFuncs";
ICHECK(kv.second->IsInstance<tir::PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs";
auto f = Downcast<tir::PrimFunc>(kv.second);
runtime::FunctionInfo info;
......@@ -129,7 +137,7 @@ static std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
info.arg_types.push_back(f->params[i].dtype());
}
if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
for (const auto& tag : opt.value()) {
for (const auto &tag : opt.value()) {
info.launch_param_tags.push_back(tag);
}
}
......@@ -146,7 +154,8 @@ runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
cg.Init(output_ssa);
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenTileLangHIP: Can only take PrimFunc";
ICHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenTileLangHIP: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
......@@ -154,21 +163,23 @@ runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
}
std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_hip_postproc")) {
if (const auto *f = Registry::Get("tvm_callback_hip_postproc")) {
code = (*f)(code, target).operator std::string();
}
std::string fmt = "ptx";
std::string ptx;
if (const auto* f = Registry::Get("tvm_callback_hip_compile")) {
if (const auto *f = Registry::Get("tvm_callback_hip_compile")) {
ptx = (*f)(code, target).operator std::string();
if (ptx[0] != '/') fmt = "hsaco";
if (ptx[0] != '/')
fmt = "hsaco";
} else {
ptx = HIPRTCCompile(code, false);
}
return ROCMModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code, std::string());
}
TVM_REGISTER_GLOBAL("target.build.tilelang_hip").set_body_typed(BuildTileLangHIP);
TVM_REGISTER_GLOBAL("target.build.tilelang_hip")
.set_body_typed(BuildTileLangHIP);
} // namespace codegen
} // namespace tvm
......@@ -11,13 +11,17 @@
namespace tvm {
namespace tl {
bool TargetIsCuda(Target target) { return target->GetTargetDeviceType() == kDLCUDA; }
bool TargetIsRocm(Target target) { return target->GetTargetDeviceType() == kDLROCM; }
bool TargetIsCuda(Target target) {
return target->GetTargetDeviceType() == kDLCUDA;
}
bool TargetIsRocm(Target target) {
return target->GetTargetDeviceType() == kDLROCM;
}
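// Parse the numeric part of a CUDA arch string such as "sm_80".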
int GetArchInt(Target target) {
auto s = target->GetAttr<String>("arch");
ICHECK(s.defined());
const char* arch_str = s.value().c_str();
const char *arch_str = s.value().c_str();
ICHECK_EQ(arch_str[0], 's');
ICHECK_EQ(arch_str[1], 'm');
ICHECK_EQ(arch_str[2], '_');
......@@ -25,31 +29,36 @@ int GetArchInt(Target target) {
}
bool TargetIsVolta(Target target) {
if (!TargetIsCuda(target)) return false;
if (!TargetIsCuda(target))
return false;
int arch = GetArchInt(target);
return arch >= 70 && arch < 75;
}
bool TargetIsTuring(Target target) {
if (!TargetIsCuda(target)) return false;
if (!TargetIsCuda(target))
return false;
int arch = GetArchInt(target);
return arch >= 75 && arch < 80;
}
bool TargetIsAmpere(Target target) {
if (!TargetIsCuda(target)) return false;
if (!TargetIsCuda(target))
return false;
int arch = GetArchInt(target);
return arch >= 80 && arch < 90;
}
bool TargetIsHopper(Target target) {
if (!TargetIsCuda(target)) return false;
if (!TargetIsCuda(target))
return false;
int arch = GetArchInt(target);
return arch >= 90;
}
bool TargetIsCDNA(Target target) {
if (!TargetIsRocm(target)) return false;
if (!TargetIsRocm(target))
return false;
if (target->attrs.count("mcpu")) {
std::string mcpu = Downcast<String>(target->attrs.at("mcpu"));
    // If mcpu starts with "gfx9", the target is CDNA
......@@ -78,13 +87,15 @@ bool TargetHasAsyncCopy(Target target) {
return false;
}
bool TargetHasLdmatrix(Target target) {
if (!TargetIsCuda(target)) return false;
if (!TargetIsCuda(target))
return false;
int arch = GetArchInt(target);
return arch >= 75;
}
bool TargetHasStmatrix(Target target) {
if (!TargetIsCuda(target)) return false;
if (!TargetIsCuda(target))
return false;
int arch = GetArchInt(target);
return arch >= 90;
}
......
......@@ -25,56 +25,57 @@ using cutlass::tfloat32_t;
// Pack two half values.
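// The result holds x in the low 16 bits and y in the high 16 bits.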
TL_DEVICE unsigned __pack_half2(const half x, const half y) {
unsigned v0 = *((unsigned short*)&x);
unsigned v1 = *((unsigned short*)&y);
unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}
// Pack two half_t values.
TL_DEVICE unsigned __pack_half2(const half_t x, const half_t y) {
unsigned v0 = *((unsigned short*)&x);
unsigned v1 = *((unsigned short*)&y);
unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}
// Pack two bfloat16_t values.
TL_DEVICE unsigned __pack_half2(const bfloat16_t x, const bfloat16_t y) {
unsigned v0 = *((unsigned short*)&x);
unsigned v1 = *((unsigned short*)&y);
unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}
/// Helper to cast SMEM pointer to unsigned
TL_DEVICE uint32_t smem_ptr_to_uint(void const* const ptr) {
TL_DEVICE uint32_t smem_ptr_to_uint(void const *const ptr) {
return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
}
// AtomicAdd Functions for FP16
TL_DEVICE void atomicAdd(half_t* address, half_t val) {
TL_DEVICE void atomicAdd(half_t *address, half_t val) {
// Use atomicCAS with built-in cuda_fp16 support
atomicAdd(reinterpret_cast<half*>(address), static_cast<half>(val));
atomicAdd(reinterpret_cast<half *>(address), static_cast<half>(val));
}
// AtomicAdd Functions for FP16
TL_DEVICE void atomicAdd(half_t* address, half_t* val) {
atomicAdd(reinterpret_cast<half*>(address), static_cast<half>(*val));
TL_DEVICE void atomicAdd(half_t *address, half_t *val) {
atomicAdd(reinterpret_cast<half *>(address), static_cast<half>(*val));
}
// AtomicAdd Functions for FP16
TL_DEVICE void atomicAddx2(half_t* address, half_t* val) {
atomicAdd(reinterpret_cast<half2*>(address), static_cast<half2>(*reinterpret_cast<half2*>(val)));
TL_DEVICE void atomicAddx2(half_t *address, half_t *val) {
atomicAdd(reinterpret_cast<half2 *>(address),
static_cast<half2>(*reinterpret_cast<half2 *>(val)));
}
TL_DEVICE void atomicAdd(half_t* address, float val) {
TL_DEVICE void atomicAdd(half_t *address, float val) {
  // Use the built-in cuda_fp16 atomicAdd overload.
atomicAdd(reinterpret_cast<half*>(address), __float2half(val));
atomicAdd(reinterpret_cast<half *>(address), __float2half(val));
}
// DP4A
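// __dp4a multiplies the four packed 8-bit lanes of a and b pairwise, sums the
// products, and adds the accumulator c, all in one instruction.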
template<typename InDatatype, typename OutDatatype>
TL_DEVICE void DP4A(InDatatype* a, InDatatype* b, OutDatatype* c) {
const int a_int = *((int*)a);
const int b_int = *((int*)b);
const int c_int = *((int*)c);
template <typename InDatatype, typename OutDatatype>
TL_DEVICE void DP4A(InDatatype *a, InDatatype *b, OutDatatype *c) {
const int a_int = *((int *)a);
const int b_int = *((int *)b);
const int c_int = *((int *)c);
*c = __dp4a(a_int, b_int, c_int);
}
......@@ -10,10 +10,11 @@
namespace tl {
TL_DEVICE void cp_async_commit() { asm volatile("cp.async.commit_group;\n" ::); }
TL_DEVICE void cp_async_commit() {
asm volatile("cp.async.commit_group;\n" ::);
}
template <int N>
TL_DEVICE void cp_async_wait() {
template <int N> TL_DEVICE void cp_async_wait() {
if constexpr (N == 0) {
asm volatile("cp.async.wait_all;\n" ::);
} else {
......@@ -22,7 +23,7 @@ TL_DEVICE void cp_async_wait() {
}
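// Issue one cp.async copy of N bytes (N must be 16, 8, or 4) from global to
// shared memory; 16-byte copies use the .cg (L1-bypass) variant, smaller ones
// the .ca variant.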
template <int N>
TL_DEVICE void cp_async_gs(void const* const smem_addr, void* global_ptr) {
TL_DEVICE void cp_async_gs(void const *const smem_addr, void *global_ptr) {
static_assert(N == 16 || N == 8 || N == 4);
unsigned int addr = smem_ptr_to_uint(smem_addr);
if constexpr (N == 16) {
......@@ -33,7 +34,7 @@ TL_DEVICE void cp_async_gs(void const* const smem_addr, void* global_ptr) {
"cp.async.cg.shared.global [%0], [%1], %2;"
#endif
::"r"(addr),
"l"((void*)(global_ptr)), "n"(N));
"l"((void *)(global_ptr)), "n"(N));
} else {
__asm__ __volatile__(
#if TL_ENABLE_L2_PREFETCH
......@@ -42,12 +43,13 @@ TL_DEVICE void cp_async_gs(void const* const smem_addr, void* global_ptr) {
"cp.async.ca.shared.global [%0], [%1], %2;"
#endif
::"r"(addr),
"l"((void*)(global_ptr)), "n"(N));
"l"((void *)(global_ptr)), "n"(N));
}
}
template <int N>
TL_DEVICE void cp_async_gs_conditional(void const* const smem_addr, void* global_ptr, bool cond) {
TL_DEVICE void cp_async_gs_conditional(void const *const smem_addr,
void *global_ptr, bool cond) {
static_assert(N == 16 || N == 8 || N == 4);
int bytes = cond ? N : 0;
unsigned int addr = smem_ptr_to_uint(smem_addr);
......@@ -59,7 +61,7 @@ TL_DEVICE void cp_async_gs_conditional(void const* const smem_addr, void* global
"cp.async.cg.shared.global [%0], [%1], %2, %3;"
#endif
::"r"(addr),
"l"((void*)(global_ptr)), "n"(N), "r"(bytes));
"l"((void *)(global_ptr)), "n"(N), "r"(bytes));
} else {
__asm__ __volatile__(
#if TL_ENABLE_L2_PREFETCH
......@@ -68,7 +70,7 @@ TL_DEVICE void cp_async_gs_conditional(void const* const smem_addr, void* global
"cp.async.ca.shared.global [%0], [%1], %2, %3;"
#endif
::"r"(addr),
"l"((void*)(global_ptr)), "n"(N), "r"(bytes));
"l"((void *)(global_ptr)), "n"(N), "r"(bytes));
}
}
......
......@@ -8,165 +8,185 @@
namespace tl {
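// The tma_load / tma_store overloads below wrap the cp.async.bulk.tensor PTX
// instructions; each additional int32_t coordinate selects one more tensor
// dimension (1D through 5D), and loads complete through an mbarrier
// transaction.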
TL_DEVICE void tma_load(const CUtensorMap& descriptor, uint64_t& smem_mbar,
void const* const smem_ptr, int32_t const& crd0) {
TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
void const *const smem_ptr, int32_t const &crd0) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile(
"cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes"
asm volatile("cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::"
"complete_tx::bytes"
" [%0], [%1, {%3}], [%2];"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0)
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0)
: "memory");
}
TL_DEVICE void tma_load(const CUtensorMap& descriptor, uint64_t& smem_mbar,
void const* const smem_ptr, int32_t const& crd0, int32_t const& crd1) {
TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
void const *const smem_ptr, int32_t const &crd0,
int32_t const &crd1) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile(
"cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes"
asm volatile("cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::"
"complete_tx::bytes"
" [%0], [%1, {%3, %4}], [%2];"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), "r"(crd1)
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0), "r"(crd1)
: "memory");
}
TL_DEVICE void tma_load(const CUtensorMap& descriptor, uint64_t& smem_mbar,
void const* const smem_ptr, int32_t const& crd0, int32_t const& crd1,
int32_t const& crd2) {
TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
void const *const smem_ptr, int32_t const &crd0,
int32_t const &crd1, int32_t const &crd2) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile(
"cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes"
asm volatile("cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::"
"complete_tx::bytes"
" [%0], [%1, {%3, %4, %5}], [%2];"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), "r"(crd1), "r"(crd2)
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0), "r"(crd1), "r"(crd2)
: "memory");
}
TL_DEVICE void tma_load(const CUtensorMap& descriptor, uint64_t& smem_mbar,
void const* const smem_ptr, int32_t const& crd0, int32_t const& crd1,
int32_t const& crd2, int32_t const& crd3) {
TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
void const *const smem_ptr, int32_t const &crd0,
int32_t const &crd1, int32_t const &crd2,
int32_t const &crd3) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile(
"cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes"
asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::"
"complete_tx::bytes"
" [%0], [%1, {%3, %4, %5, %6}], [%2];"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), "r"(crd1), "r"(crd2),
"r"(crd3)
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3)
: "memory");
}
TL_DEVICE void tma_load(const CUtensorMap& descriptor, uint64_t& smem_mbar,
void const* const smem_ptr, int32_t const& crd0, int32_t const& crd1,
int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) {
TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
void const *const smem_ptr, int32_t const &crd0,
int32_t const &crd1, int32_t const &crd2,
int32_t const &crd3, int32_t const &crd4) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile(
"cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes"
asm volatile("cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::"
"complete_tx::bytes"
" [%0], [%1, {%3, %4, %5, %6, %7}], [%2];"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), "r"(crd1), "r"(crd2),
"r"(crd3), "r"(crd4)
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4)
: "memory");
}
TL_DEVICE void tma_load_im2col(const CUtensorMap& descriptor, uint64_t& smem_mbar,
void const* const smem_ptr, int32_t const& coord_c,
int32_t const& coord_w, int32_t const& coord_h,
int32_t const& coord_n, uint16_t const& offset_w,
uint16_t const& offset_h) {
TL_DEVICE void tma_load_im2col(const CUtensorMap &descriptor,
uint64_t &smem_mbar, void const *const smem_ptr,
int32_t const &coord_c, int32_t const &coord_w,
int32_t const &coord_h, int32_t const &coord_n,
uint16_t const &offset_w,
uint16_t const &offset_h) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_mbar = smem_ptr_to_uint(&smem_mbar);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile(
"cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes"
asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier:"
":complete_tx::bytes"
" [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8};"
:
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(coord_c), "r"(coord_w),
"r"(coord_h), "r"(coord_n), "h"(offset_w), "h"(offset_h)
: "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
"r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n),
"h"(offset_w), "h"(offset_h)
: "memory");
}
TL_DEVICE void tma_store(const CUtensorMap& descriptor, void const* const smem_ptr,
int32_t const& crd0) {
TL_DEVICE void tma_store(const CUtensorMap &descriptor,
void const *const smem_ptr, int32_t const &crd0) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [%0, {%2}], [%1];"
asm volatile(
"cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [%0, {%2}], [%1];"
:
: "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0)
: "memory");
}
TL_DEVICE void tma_store(const CUtensorMap& descriptor, void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1) {
TL_DEVICE void tma_store(const CUtensorMap &descriptor,
void const *const smem_ptr, int32_t const &crd0,
int32_t const &crd1) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%2, %3}], [%1];"
asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, "
"{%2, %3}], [%1];"
:
: "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1)
: "memory");
}
TL_DEVICE void tma_store(const CUtensorMap& descriptor, void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) {
TL_DEVICE void tma_store(const CUtensorMap &descriptor,
void const *const smem_ptr, int32_t const &crd0,
int32_t const &crd1, int32_t const &crd2) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [%0, {%2, %3, %4}], [%1];"
asm volatile("cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [%0, "
"{%2, %3, %4}], [%1];"
:
: "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1), "r"(crd2)
: "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1),
"r"(crd2)
: "memory");
}
TL_DEVICE void tma_store(const CUtensorMap& descriptor, void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2,
int32_t const& crd3) {
TL_DEVICE void tma_store(const CUtensorMap &descriptor,
void const *const smem_ptr, int32_t const &crd0,
int32_t const &crd1, int32_t const &crd2,
int32_t const &crd3) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5}], [%1];"
asm volatile("cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [%0, "
"{%2, %3, %4, %5}], [%1];"
:
: "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3)
: "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1),
"r"(crd2), "r"(crd3)
: "memory");
}
TL_DEVICE void tma_store(const CUtensorMap& descriptor, void const* const smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2,
int32_t const& crd3, int32_t const& crd4) {
TL_DEVICE void tma_store(const CUtensorMap &descriptor,
void const *const smem_ptr, int32_t const &crd0,
int32_t const &crd1, int32_t const &crd2,
int32_t const &crd3, int32_t const &crd4) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile(
"cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5, %6}], [%1];"
asm volatile("cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [%0, "
"{%2, %3, %4, %5, %6}], [%1];"
:
: "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4)
: "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1),
"r"(crd2), "r"(crd3), "r"(crd4)
: "memory");
}
TL_DEVICE void prefetch_tma_descriptor(const CUtensorMap& descriptor) {
TL_DEVICE void prefetch_tma_descriptor(const CUtensorMap &descriptor) {
uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
asm volatile("prefetch.tensormap [%0];" : : "l"(gmem_int_desc) : "memory");
}
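// The mbarrier_* helpers below wrap PTX mbarrier operations: init sets the
// expected arrival count, wait spins on a phase bit, arrive signals arrival,
// and expect_tx registers how many bytes an asynchronous (TMA) copy will
// deposit before the barrier completes.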
TL_DEVICE void mbarrier_init(uint64_t& smem_barrier, uint32_t arrive_count) {
TL_DEVICE void mbarrier_init(uint64_t &smem_barrier, uint32_t arrive_count) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.init.shared.b64 [%1], %0;" : : "r"(arrive_count), "r"(smem_int_ptr));
asm volatile("mbarrier.init.shared.b64 [%1], %0;"
:
: "r"(arrive_count), "r"(smem_int_ptr));
}
TL_DEVICE void mbarrier_wait(uint64_t& smem_barrier, int phase_bit) {
TL_DEVICE void mbarrier_wait(uint64_t &smem_barrier, int phase_bit) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile(
"{\n"
asm volatile("{\n"
".reg .pred P1;\n"
"LAB_WAIT:\n"
"mbarrier.try_wait.parity.shared.b64 P1, [%0], %1;\n"
......@@ -175,37 +195,42 @@ TL_DEVICE void mbarrier_wait(uint64_t& smem_barrier, int phase_bit) {
"r"(phase_bit));
}
TL_DEVICE void mbarrier_arrive(uint64_t& smem_barrier) {
TL_DEVICE void mbarrier_arrive(uint64_t &smem_barrier) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.arrive.shared.b64 _, [%0];" : : "r"(smem_int_ptr));
}
TL_DEVICE void mbarrier_expect_tx(uint64_t& smem_barrier, uint32_t transaction_bytes) {
TL_DEVICE void mbarrier_expect_tx(uint64_t &smem_barrier,
uint32_t transaction_bytes) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.expect_tx.shared.b64 [%1], %0;"
:
: "r"(transaction_bytes), "r"(smem_int_ptr));
}
TL_DEVICE void mbarrier_arrive_expect_tx(uint64_t& smem_barrier, uint32_t transaction_bytes) {
TL_DEVICE void mbarrier_arrive_expect_tx(uint64_t &smem_barrier,
uint32_t transaction_bytes) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%1], %0;"
:
: "r"(transaction_bytes), "r"(smem_int_ptr));
}
TL_DEVICE void mbarrier_cp_async_arrive(uint64_t& smem_barrier) {
TL_DEVICE void mbarrier_cp_async_arrive(uint64_t &smem_barrier) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
asm volatile("cp.async.mbarrier.arrive.shared.b64 [%0];" : : "r"(smem_int_ptr));
asm volatile("cp.async.mbarrier.arrive.shared.b64 [%0];"
:
: "r"(smem_int_ptr));
}
TL_DEVICE void fence_proxy_async() { asm volatile("fence.proxy.async.shared::cta;" : :); }
TL_DEVICE void fence_proxy_async() {
asm volatile("fence.proxy.async.shared::cta;" : :);
}
TL_DEVICE void syncthreads_partial(uint64_t& smem_barrier) {
TL_DEVICE void syncthreads_partial(uint64_t &smem_barrier) {
uint32_t smem_int_ptr = smem_ptr_to_uint(&smem_barrier);
uint64_t state;
asm volatile(
"{\n"
asm volatile("{\n"
".reg .pred P1;\n"
"mbarrier.arrive.shared.b64 %1, [%0];\n"
"LAB_WAIT:\n"
......@@ -216,14 +241,12 @@ TL_DEVICE void syncthreads_partial(uint64_t& smem_barrier) {
: "r"(smem_int_ptr), "l"(state));
}
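// setmaxnreg.inc raises and setmaxnreg.dec lowers the per-thread register
// budget of the calling warpgroup to RegCount registers (a Hopper-generation
// feature).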
template<uint32_t RegCount>
TL_DEVICE void warpgroup_reg_alloc(){
asm volatile( "setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount) );
template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_alloc() {
asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}
template<uint32_t RegCount>
TL_DEVICE void warpgroup_reg_dealloc(){
asm volatile( "setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount) );
template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_dealloc() {
asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}
} // namespace tl
\ No newline at end of file