Commit 549416f7 authored by LeiWang1999
Browse files

Merge branch 'main' of https://github.com/microsoft/TileLang into main

parents 4d63633a 7fad4e88
...@@ -18,18 +18,18 @@ namespace tl { ...@@ -18,18 +18,18 @@ namespace tl {
using namespace tir; using namespace tir;
class Gemm : public Operator { class Gemm : public Operator {
public: public:
Gemm(Array<PrimExpr> args, BufferMap vmap); Gemm(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const final; Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level) final; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
static const Op& Get(); static const Op &Get();
enum class GemmWarpPolicy { enum class GemmWarpPolicy {
kSquare = 0, kSquare = 0,
kFullRow = 1, kFullRow = 1,
kFullCol = 2, kFullCol = 2,
} policy; } policy;
private: private:
std::pair<int, int> ComputeWarpPartition(int num_warps, Target target) const; std::pair<int, int> ComputeWarpPartition(int num_warps, Target target) const;
Array<PrimExpr> call_args; Array<PrimExpr> call_args;
...@@ -38,11 +38,11 @@ class Gemm : public Operator { ...@@ -38,11 +38,11 @@ class Gemm : public Operator {
int M, N, K; int M, N, K;
// For k_pack, please refer to bitblas/tl/mfma_macro_generator.py::k_pack // For k_pack, please refer to bitblas/tl/mfma_macro_generator.py::k_pack
// It will only be enabled under CDNA MFMA instructions // It will only be enabled under CDNA MFMA instructions
int kPack = 1; int kPack = 1;
bool completed_ = false; bool completed_ = false;
}; };
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
#endif // TVM_TL_OP_GEMM_H_ #endif // TVM_TL_OP_GEMM_H_
\ No newline at end of file \ No newline at end of file
...@@ -20,13 +20,14 @@ using namespace tir; ...@@ -20,13 +20,14 @@ using namespace tir;
TIR_REGISTER_TL_OP(RegionOp, region) TIR_REGISTER_TL_OP(RegionOp, region)
.set_num_inputs(-1) .set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure)); .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
std::unique_ptr<Operator> ParseOperator(Call call, BufferMap vmap) { std::unique_ptr<Operator> ParseOperator(Call call, BufferMap vmap) {
auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder"); auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder");
Op op = call->op.as<Op>().value(); Op op = call->op.as<Op>().value();
if (op_map.count(op)) { if (op_map.count(op)) {
Operator* ptr = static_cast<Operator*>(op_map[op](call->args, vmap)); Operator *ptr = static_cast<Operator *>(op_map[op](call->args, vmap));
ICHECK(ptr != nullptr); ICHECK(ptr != nullptr);
return std::unique_ptr<Operator>(ptr); return std::unique_ptr<Operator>(ptr);
} }
...@@ -41,7 +42,7 @@ std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap) { ...@@ -41,7 +42,7 @@ std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap) {
return nullptr; return nullptr;
} }
Var GetVarFromAccessPtr(const PrimExpr& expr) { Var GetVarFromAccessPtr(const PrimExpr &expr) {
auto call = expr.as<CallNode>(); auto call = expr.as<CallNode>();
ICHECK(call); ICHECK(call);
ICHECK(call->op.same_as(builtin::tvm_access_ptr())); ICHECK(call->op.same_as(builtin::tvm_access_ptr()));
...@@ -67,20 +68,27 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) { ...@@ -67,20 +68,27 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
bool RegionOp::IsFullRegion() const { bool RegionOp::IsFullRegion() const {
for (size_t i = 0; i < ranges_.size(); i++) { for (size_t i = 0; i < ranges_.size(); i++) {
if (!is_zero(ranges_[i]->min)) return false; if (!is_zero(ranges_[i]->min))
if (!StructuralEqual()(ranges_[i]->extent, buffer_->shape[i])) return false; return false;
if (!StructuralEqual()(ranges_[i]->extent, buffer_->shape[i]))
return false;
} }
return true; return true;
} }
Stmt Operator::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const { Stmt Operator::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
ICHECK(0) << "Not Implemented Lower method."; ICHECK(0) << "Not Implemented Lower method.";
return Evaluate(0); return Evaluate(0);
} }
Stmt Operator::Canonialize(const CanonializeArgs& T, arith::Analyzer* analyzer) const { return {}; } Stmt Operator::Canonialize(const CanonializeArgs &T,
arith::Analyzer *analyzer) const {
return {};
}
LayoutMap Operator::InferLayout(const LayoutInferArgs& T, InferLevel level) { return {}; } LayoutMap Operator::InferLayout(const LayoutInferArgs &T, InferLevel level) {
return {};
}
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -25,17 +25,19 @@ using namespace tir; ...@@ -25,17 +25,19 @@ using namespace tir;
using AddWorkspaceCallback = std::function<PrimExpr(int, DataType)>; using AddWorkspaceCallback = std::function<PrimExpr(int, DataType)>;
using LayoutMap = Map<Buffer, Layout>; using LayoutMap = Map<Buffer, Layout>;
using BufferMap = Map<Var, Buffer>; using BufferMap = Map<Var, Buffer>;
using OpBuilderFunc = TypedPackedFunc<void*(Array<PrimExpr>, BufferMap)>; using OpBuilderFunc = TypedPackedFunc<void *(Array<PrimExpr>, BufferMap)>;
#define TIR_REGISTER_TL_OP(Entry, OpName) \ #define TIR_REGISTER_TL_OP(Entry, OpName) \
const Op& Entry::Get() { \ const Op &Entry::Get() { \
static const Op& op = Op::Get("tl." #OpName); \ static const Op &op = Op::Get("tl." #OpName); \
return op; \ return op; \
} \ } \
TVM_REGISTER_OP("tl." #OpName) \ TVM_REGISTER_OP("tl." #OpName) \
.set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName) \ .set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName) \
.set_attr<OpBuilderFunc>( \ .set_attr<OpBuilderFunc>("TLOpBuilder", \
"TLOpBuilder", [](Array<PrimExpr> a, BufferMap b) { return (void*)(new Entry(a, b)); }) [](Array<PrimExpr> a, BufferMap b) { \
return (void *)(new Entry(a, b)); \
})
enum class InferLevel { enum class InferLevel {
kFree = 0, kFree = 0,
...@@ -64,35 +66,36 @@ struct CanonializeArgs { ...@@ -64,35 +66,36 @@ struct CanonializeArgs {
}; };
class Operator { class Operator {
public: public:
virtual Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const; virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
virtual Stmt Canonialize(const CanonializeArgs& T, arith::Analyzer* analyzer) const; virtual Stmt Canonialize(const CanonializeArgs &T,
virtual LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level); arith::Analyzer *analyzer) const;
virtual LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level);
virtual ~Operator() = default; virtual ~Operator() = default;
}; };
class RegionOp : public Operator { class RegionOp : public Operator {
public: public:
RegionOp(Array<PrimExpr> args, BufferMap vmap); RegionOp(Array<PrimExpr> args, BufferMap vmap);
static const Op& Get(); static const Op &Get();
const Buffer& GetBuffer() const { return buffer_; } const Buffer &GetBuffer() const { return buffer_; }
const Array<Range>& GetRanges() const { return ranges_; } const Array<Range> &GetRanges() const { return ranges_; }
int GetAccessMask() const { return access_mask_; } int GetAccessMask() const { return access_mask_; }
bool IsFullRegion() const; bool IsFullRegion() const;
private: private:
Buffer buffer_; Buffer buffer_;
Array<Range> ranges_; Array<Range> ranges_;
int access_mask_; int access_mask_;
}; };
Var GetVarFromAccessPtr(const PrimExpr& expr); Var GetVarFromAccessPtr(const PrimExpr &expr);
std::unique_ptr<Operator> ParseOperator(Call call, BufferMap vmap); std::unique_ptr<Operator> ParseOperator(Call call, BufferMap vmap);
std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap); std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap);
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
#endif // TVM_TL_OP_OP_H_ #endif // TVM_TL_OP_OP_H_
...@@ -39,21 +39,22 @@ using namespace tir; ...@@ -39,21 +39,22 @@ using namespace tir;
namespace attr { namespace attr {
/*! \brief Mark how the loop is vectorized. */ /*! \brief Mark how the loop is vectorized. */
constexpr const char *coalesced_width = "coalesced_width"; constexpr const char *coalesced_width = "coalesced_width";
} } // namespace attr
class IfBufferRemapLoopGenerator : public StmtExprMutator { class IfBufferRemapLoopGenerator : public StmtExprMutator {
public: public:
static For run(Stmt stmt, Map<Buffer, Buffer> buffer_remap, static For run(Stmt stmt, Map<Buffer, Buffer> buffer_remap,
Map<Buffer, Layout> layout_map) { Map<Buffer, Layout> layout_map) {
IfBufferRemapLoopGenerator generator(buffer_remap, layout_map); IfBufferRemapLoopGenerator generator(buffer_remap, layout_map);
return Downcast<For>(generator(std::move(stmt))); return Downcast<For>(generator(std::move(stmt)));
} }
private: private:
IfBufferRemapLoopGenerator(Map<Buffer, Buffer> buffer_remap, Map<Buffer, Layout> layout_map) IfBufferRemapLoopGenerator(Map<Buffer, Buffer> buffer_remap,
Map<Buffer, Layout> layout_map)
: buffer_remap_(buffer_remap), layout_map_(layout_map) {} : buffer_remap_(buffer_remap), layout_map_(layout_map) {}
PrimExpr VisitExpr_(const BufferLoadNode* op) final { PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
if (buffer_remap_.count(load->buffer)) { if (buffer_remap_.count(load->buffer)) {
...@@ -65,7 +66,7 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator { ...@@ -65,7 +66,7 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {
return load; return load;
} }
Stmt VisitStmt_(const BufferStoreNode* op) final { Stmt VisitStmt_(const BufferStoreNode *op) final {
auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
if (buffer_remap_.count(store->buffer)) { if (buffer_remap_.count(store->buffer)) {
auto new_indices = layout_map_[store->buffer]->Forward(store->indices); auto new_indices = layout_map_[store->buffer]->Forward(store->indices);
...@@ -79,18 +80,20 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator { ...@@ -79,18 +80,20 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {
Map<Buffer, Layout> layout_map_; Map<Buffer, Layout> layout_map_;
}; };
void ParallelLoopNestVisitor::VisitStmt_(const ForNode* op) { void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
ICHECK(op->kind == ForKind::kParallel); ICHECK(op->kind == ForKind::kParallel);
p->loop_vars_.push_back(IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar)); p->loop_vars_.push_back(
IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar));
p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
StmtExprVisitor::VisitStmt_(op); StmtExprVisitor::VisitStmt_(op);
} }
void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode* op) { void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
if (op->buffer.scope() == "local.fragment") { if (op->buffer.scope() == "local.fragment") {
if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) { if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices)) ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
<< op->buffer << ": " << op->indices << " and " << p->indice_map_.at(op->buffer); << op->buffer << ": " << op->indices << " and "
<< p->indice_map_.at(op->buffer);
} else { } else {
p->indice_map_.Set(op->buffer, op->indices); p->indice_map_.Set(op->buffer, op->indices);
} }
...@@ -99,11 +102,12 @@ void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode* op) { ...@@ -99,11 +102,12 @@ void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode* op) {
StmtExprVisitor::VisitStmt_(op); StmtExprVisitor::VisitStmt_(op);
} }
void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode* op) { void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
if (op->buffer.scope() == "local.fragment") { if (op->buffer.scope() == "local.fragment") {
if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) { if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices)) ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
<< op->buffer << ": " << op->indices << " and " << p->indice_map_.at(op->buffer); << op->buffer << ": " << op->indices << " and "
<< p->indice_map_.at(op->buffer);
} else { } else {
p->indice_map_.Set(op->buffer, op->indices); p->indice_map_.Set(op->buffer, op->indices);
} }
...@@ -113,18 +117,20 @@ void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode* op) { ...@@ -113,18 +117,20 @@ void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode* op) {
ParallelOp::ParallelOp(For root) : root_(root), V(this) { V.VisitStmt(root); } ParallelOp::ParallelOp(For root) : root_(root), V(this) { V.VisitStmt(root); }
bool ParallelOp::IsCommonAccessIndice(const Buffer& buffer) const { bool ParallelOp::IsCommonAccessIndice(const Buffer &buffer) const {
auto common_indice = loop_vars_.Map([](const auto& iv) { return iv->var; }); auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; });
return StructuralEqual()(indice_map_[buffer], common_indice); return StructuralEqual()(indice_map_[buffer], common_indice);
} }
LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) { LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
if (loop_layout_.defined()) return {}; if (loop_layout_.defined())
if (level == InferLevel::kStrict) return {}; return {};
if (level == InferLevel::kStrict)
return {};
// Step 1: try to infer loop's partition from a source fragment // Step 1: try to infer loop's partition from a source fragment
Buffer source_buffer, read_source_buffer; Buffer source_buffer, read_source_buffer;
for (const auto& [buffer, _] : indice_map_) { for (const auto &[buffer, _] : indice_map_) {
if (T.layout_map.count(buffer)) { if (T.layout_map.count(buffer)) {
auto frag = T.layout_map[buffer].as<Fragment>().value(); auto frag = T.layout_map[buffer].as<Fragment>().value();
if (buffer_is_write_.count(buffer)) if (buffer_is_write_.count(buffer))
...@@ -133,14 +139,16 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) { ...@@ -133,14 +139,16 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
read_source_buffer = buffer; read_source_buffer = buffer;
} }
} }
auto compute_loop_layout_from_buffer = [&](const Buffer& buffer) { auto compute_loop_layout_from_buffer = [&](const Buffer &buffer) {
Fragment src_layout = T.layout_map[buffer].as<Fragment>().value(); Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
if (IsCommonAccessIndice(buffer)) { if (IsCommonAccessIndice(buffer)) {
return src_layout; return src_layout;
} else { } else {
Var rep; Var rep;
auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep, IterVarType::kDataPar); auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep,
PrimExpr loop_var_to_thread = src_layout->ForwardThread(indice_map_[buffer], rep); IterVarType::kDataPar);
PrimExpr loop_var_to_thread =
src_layout->ForwardThread(indice_map_[buffer], rep);
return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter); return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter);
} }
}; };
...@@ -150,12 +158,14 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) { ...@@ -150,12 +158,14 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
if (read_source_buffer.defined()) { if (read_source_buffer.defined()) {
loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer); loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
// The loop doesn't need to be replicated. // The loop doesn't need to be replicated.
if (!is_one(loop_layout_->ReplicateExtent())) loop_layout_ = loop_layout_->DeReplicate(); if (!is_one(loop_layout_->ReplicateExtent()))
loop_layout_ = loop_layout_->DeReplicate();
// If it still has replication, add a condition // If it still has replication, add a condition
if (!is_one(loop_layout_->ReplicateExtent())) { if (!is_one(loop_layout_->ReplicateExtent())) {
auto inv = loop_layout_->Inverse(); auto inv = loop_layout_->Inverse();
Array<PrimExpr> fwd; Array<PrimExpr> fwd;
for (size_t i = 0; i < loop_layout_->OutputDim(); i++) fwd.push_back(0); for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
fwd.push_back(0);
fwd.push_back(InputPlaceholder(0)); fwd.push_back(InputPlaceholder(0));
auto rep = inv->Forward(fwd).back(); auto rep = inv->Forward(fwd).back();
AddPredicate(EQ(rep, 0)); AddPredicate(EQ(rep, 0));
...@@ -163,17 +173,19 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) { ...@@ -163,17 +173,19 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
} else { } else {
// The vectorize size must be aware of the buffer_remap, // The vectorize size must be aware of the buffer_remap,
// as the pass will post-process the layout // as the pass will post-process the layout
auto maybe_remapped_root_ = IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map); auto maybe_remapped_root_ =
IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
int vector_size = GetVectorizeSize(maybe_remapped_root_); int vector_size = GetVectorizeSize(maybe_remapped_root_);
// Check if coalesced_width is defined // Check if coalesced_width is defined
if (auto coalesced_width = root_->annotations.Get(tl::attr::coalesced_width)) { if (auto coalesced_width =
if (const auto* imm = coalesced_width.as<IntImmNode>()) { root_->annotations.Get(tl::attr::coalesced_width)) {
if (const auto *imm = coalesced_width.as<IntImmNode>()) {
int expected = imm->value; int expected = imm->value;
// Verify that vector_size is divisible by expected // Verify that vector_size is divisible by expected
if (vector_size % expected != 0) { if (vector_size % expected != 0) {
LOG(FATAL) << "Vector size " << vector_size << " is not divisible by coalesced width " LOG(FATAL) << "Vector size " << vector_size
<< expected; << " is not divisible by coalesced width " << expected;
} }
vector_size = expected; vector_size = expected;
} else { } else {
...@@ -184,31 +196,37 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) { ...@@ -184,31 +196,37 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
loop_layout_ = PlanLoopPartition(root_, T.block_size, vector_size); loop_layout_ = PlanLoopPartition(root_, T.block_size, vector_size);
} }
PrimExpr loop_thread_extent = loop_layout_->ThreadExtent(); PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();
if (!analyzer_.CanProveEqual(loop_thread_extent, static_cast<int>(T.block_size))) if (!analyzer_.CanProveEqual(loop_thread_extent,
static_cast<int>(T.block_size)))
AddPredicate(LT(InputPlaceholder(0), loop_thread_extent)); AddPredicate(LT(InputPlaceholder(0), loop_thread_extent));
} else { } else {
return {}; return {};
} }
// Step 2: Check that the loop's partition can correctly align with all source fragments // Step 2: Check that the loop's partition can correctly align with all source
for (const auto& [buffer, _] : indice_map_) { // fragments
for (const auto &[buffer, _] : indice_map_) {
if (T.layout_map.count(buffer)) { if (T.layout_map.count(buffer)) {
auto fragment = T.layout_map[buffer].as<Fragment>().value(); auto fragment = T.layout_map[buffer].as<Fragment>().value();
// TODO: Add thread checks for replicated cases // TODO: Add thread checks for replicated cases
// need to wildcard match the rhs with lhs // need to wildcard match the rhs with lhs
if (!is_one(loop_layout_->ReplicateExtent()) || !is_one(fragment->ReplicateExtent())) if (!is_one(loop_layout_->ReplicateExtent()) ||
!is_one(fragment->ReplicateExtent()))
continue; continue;
auto vars = loop_vars_.Map([](const IterVar& iv) { return PrimExpr(iv->var); }); auto vars =
loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
auto lhs = loop_layout_->ForwardThread(vars, NullOpt); auto lhs = loop_layout_->ForwardThread(vars, NullOpt);
auto rhs = fragment->ForwardThread(indice_map_[buffer], NullOpt); auto rhs = fragment->ForwardThread(indice_map_[buffer], NullOpt);
auto diff = analyzer_.Simplify(lhs - rhs); auto diff = analyzer_.Simplify(lhs - rhs);
ICHECK(is_zero(diff)) << "Layout infer conflict for " << buffer << " " << source_buffer ICHECK(is_zero(diff))
<< "\nLHS = " << lhs << "\nRHS = " << rhs; << "Layout infer conflict for " << buffer << " " << source_buffer
<< "\nLHS = " << lhs << "\nRHS = " << rhs;
} }
} }
// Step 3: Infer other fragment's layout from the loop's partition // Step 3: Infer other fragment's layout from the loop's partition
LayoutMap results; LayoutMap results;
for (const auto& [buffer, _] : indice_map_) { for (const auto &[buffer, _] : indice_map_) {
if (!T.layout_map.count(buffer)) results.Set(buffer, CompleteBufferFragment(buffer)); if (!T.layout_map.count(buffer))
results.Set(buffer, CompleteBufferFragment(buffer));
} }
return results; return results;
} }
...@@ -221,18 +239,20 @@ Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const { ...@@ -221,18 +239,20 @@ Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const {
} }
} }
Fragment ParallelOp::CompleteBufferFragment(const Buffer& buffer) { Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) {
ICHECK(loop_layout_.defined()); ICHECK(loop_layout_.defined());
if (IsCommonAccessIndice(buffer)) return loop_layout_; if (IsCommonAccessIndice(buffer))
return loop_layout_;
PrimExpr rep_b = PrimExpr rep_b = MakeFlattenedExpression(
MakeFlattenedExpression(DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_)); DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
auto bijective_indice = indice_map_[buffer]; auto bijective_indice = indice_map_[buffer];
bijective_indice.push_back(rep_b); bijective_indice.push_back(rep_b);
Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse(); Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();
PrimExpr indice_rep_extent = ind_inv->InputShape().back(); // this is the size of rep_b PrimExpr indice_rep_extent =
ind_inv->InputShape().back(); // this is the size of rep_b
PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent(); PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent; PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
...@@ -242,11 +262,12 @@ Fragment ParallelOp::CompleteBufferFragment(const Buffer& buffer) { ...@@ -242,11 +262,12 @@ Fragment ParallelOp::CompleteBufferFragment(const Buffer& buffer) {
} }
fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent)); fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
PrimExpr thd_b = loop_layout_->ForwardThread( PrimExpr thd_b = loop_layout_->ForwardThread(
ind_inv->Forward(fwd), FloorDiv(ReplicationPlaceholder(), indice_rep_extent)); ind_inv->Forward(fwd),
FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent, NullOpt) return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent, NullOpt)
->CondenseReplicateVar(); ->CondenseReplicateVar();
} }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -23,30 +23,30 @@ using namespace tir; ...@@ -23,30 +23,30 @@ using namespace tir;
class ParallelOp; class ParallelOp;
class ParallelLoopNestVisitor : public StmtExprVisitor { class ParallelLoopNestVisitor : public StmtExprVisitor {
private: private:
ParallelLoopNestVisitor(ParallelOp* op) : p(op){}; ParallelLoopNestVisitor(ParallelOp *op) : p(op){};
void VisitStmt_(const ForNode* op) final; void VisitStmt_(const ForNode *op) final;
void VisitStmt_(const BufferStoreNode* op) final; void VisitStmt_(const BufferStoreNode *op) final;
void VisitExpr_(const BufferLoadNode* op) final; void VisitExpr_(const BufferLoadNode *op) final;
ParallelOp* p; ParallelOp *p;
friend class ParallelOp; friend class ParallelOp;
}; };
class ParallelOp : public Operator { class ParallelOp : public Operator {
public: public:
ParallelOp(For root); ParallelOp(For root);
LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level) final; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
Fragment GetLoopLayout() const { return loop_layout_; } Fragment GetLoopLayout() const { return loop_layout_; }
For GetRoot() const { return root_; } For GetRoot() const { return root_; }
Map<Buffer, Array<PrimExpr>> GetIndiceMap() const { return indice_map_; } Map<Buffer, Array<PrimExpr>> GetIndiceMap() const { return indice_map_; }
Optional<PrimExpr> GetPredicate(Var thread_var) const; Optional<PrimExpr> GetPredicate(Var thread_var) const;
private: private:
Fragment CompleteBufferFragment(const Buffer& buffer); Fragment CompleteBufferFragment(const Buffer &buffer);
bool IsCommonAccessIndice(const Buffer& buffer) const; bool IsCommonAccessIndice(const Buffer &buffer) const;
void AddPredicate(PrimExpr expr) { void AddPredicate(PrimExpr expr) {
predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr; predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
} }
...@@ -66,7 +66,7 @@ class ParallelOp : public Operator { ...@@ -66,7 +66,7 @@ class ParallelOp : public Operator {
friend class ParallelLoopNestVisitor; friend class ParallelLoopNestVisitor;
}; };
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
#endif // TVM_TL_OP_PARALLEL_H_ #endif // TVM_TL_OP_PARALLEL_H_
...@@ -41,57 +41,58 @@ ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) { ...@@ -41,57 +41,58 @@ ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
PrimExpr ReduceOp::MakeInitValue() const { PrimExpr ReduceOp::MakeInitValue() const {
switch (type) { switch (type) {
case ReduceType::kSum: case ReduceType::kSum:
return make_zero(dst->dtype); return make_zero(dst->dtype);
case ReduceType::kAbsSum: case ReduceType::kAbsSum:
return make_zero(dst->dtype); return make_zero(dst->dtype);
case ReduceType::kMax: case ReduceType::kMax:
return make_const(dst->dtype, -INFINITY); return make_const(dst->dtype, -INFINITY);
case ReduceType::kMin: case ReduceType::kMin:
return make_const(dst->dtype, INFINITY); return make_const(dst->dtype, INFINITY);
default: default:
ICHECK(0); ICHECK(0);
} }
} }
PrimExpr ReduceOp::MakeReduce(const PrimExpr& a, const PrimExpr& b) const { PrimExpr ReduceOp::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
PrimExpr lhs = a, rhs = b; PrimExpr lhs = a, rhs = b;
if (lhs->dtype != rhs->dtype) { if (lhs->dtype != rhs->dtype) {
rhs = Cast(lhs->dtype, rhs); rhs = Cast(lhs->dtype, rhs);
} }
switch (type) { switch (type) {
case ReduceType::kSum: case ReduceType::kSum:
return lhs + rhs; return lhs + rhs;
case ReduceType::kAbsSum: case ReduceType::kAbsSum:
return lhs + Max(rhs, -rhs); return lhs + Max(rhs, -rhs);
case ReduceType::kMax: case ReduceType::kMax:
return Max(lhs, rhs); return Max(lhs, rhs);
case ReduceType::kMin: case ReduceType::kMin:
return Min(lhs, rhs); return Min(lhs, rhs);
default: default:
ICHECK(0); ICHECK(0);
return PrimExpr(0); return PrimExpr(0);
} }
} }
std::string ReduceOp::MakeCodegenReducer() const { std::string ReduceOp::MakeCodegenReducer() const {
switch (type) { switch (type) {
case ReduceType::kSum: case ReduceType::kSum:
return "tl::SumOp"; return "tl::SumOp";
case ReduceType::kAbsSum: case ReduceType::kAbsSum:
return "tl::SumOp"; return "tl::SumOp";
case ReduceType::kMax: case ReduceType::kMax:
return "tl::MaxOp"; return "tl::MaxOp";
case ReduceType::kMin: case ReduceType::kMin:
return "tl::MinOp"; return "tl::MinOp";
default: default:
ICHECK(0); ICHECK(0);
return ""; return "";
} }
} }
Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const { Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
ICHECK(this->src.scope() == "local.fragment" && this->dst.scope() == "local.fragment") ICHECK(this->src.scope() == "local.fragment" &&
this->dst.scope() == "local.fragment")
<< "Reduce for shared memory not implemented."; << "Reduce for shared memory not implemented.";
auto src_buffer = T.buffer_remap[this->src]; auto src_buffer = T.buffer_remap[this->src];
auto dst_buffer = T.buffer_remap[this->dst]; auto dst_buffer = T.buffer_remap[this->dst];
...@@ -101,20 +102,24 @@ Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const { ...@@ -101,20 +102,24 @@ Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
Array<IterVar> dst_vars; Array<IterVar> dst_vars;
for (size_t i = 0; i < dst_layout->InputDim(); i++) { for (size_t i = 0; i < dst_layout->InputDim(); i++) {
Var var = Var(std::string{char('i' + i)}); Var var = Var(std::string{char('i' + i)});
dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var, IterVarType::kDataPar)); dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var,
IterVarType::kDataPar));
} }
Array<IterVar> src_vars = dst_vars; Array<IterVar> src_vars = dst_vars;
src_vars.insert(src_vars.begin() + this->dim, {Range(0, src_layout->InputShape()[this->dim]), src_vars.insert(src_vars.begin() + this->dim,
Var("rv"), IterVarType::kDataPar}); {Range(0, src_layout->InputShape()[this->dim]), Var("rv"),
Array<PrimExpr> src_indices = IterVarType::kDataPar});
src_layout->Forward(src_vars.Map([](const auto& iv) { return PrimExpr(iv->var); })); Array<PrimExpr> src_indices = src_layout->Forward(
Array<PrimExpr> dst_indices = src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
dst_layout->Forward(dst_vars.Map([](const auto& iv) { return PrimExpr(iv->var); })); Array<PrimExpr> dst_indices = dst_layout->Forward(
dst_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
Array<Stmt> stmts; Array<Stmt> stmts;
// make reduce-init stmt // make reduce-init stmt
if (this->clear) stmts.push_back(BufferStore(dst_buffer, this->MakeInitValue(), dst_indices)); if (this->clear)
stmts.push_back(
BufferStore(dst_buffer, this->MakeInitValue(), dst_indices));
// make thread-local reduce // make thread-local reduce
Array<PrimExpr> src_indice_compressed; Array<PrimExpr> src_indice_compressed;
...@@ -122,45 +127,50 @@ Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const { ...@@ -122,45 +127,50 @@ Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
for (size_t i = 0; i < src_layout->OutputDim(); i++) { for (size_t i = 0; i < src_layout->OutputDim(); i++) {
PrimExpr expr; PrimExpr expr;
IterVar var; IterVar var;
std::tie(expr, var) = std::tie(expr, var) = CompressIterator(src_indices[i], src_vars,
CompressIterator(src_indices[i], src_vars, src_vars[this->dim]->var, analyzer); src_vars[this->dim]->var, analyzer);
src_indice_compressed.push_back(expr); src_indice_compressed.push_back(expr);
src_var_compressed.push_back(var); src_var_compressed.push_back(var);
} }
Stmt reduce_local = BufferStore(dst_buffer, Stmt reduce_local = BufferStore(
this->MakeReduce(BufferLoad(dst_buffer, dst_indices), dst_buffer,
BufferLoad(src_buffer, src_indice_compressed)), this->MakeReduce(BufferLoad(dst_buffer, dst_indices),
dst_indices); BufferLoad(src_buffer, src_indice_compressed)),
dst_indices);
for (int i = src_layout->OutputDim() - 1; i >= 0; i--) { for (int i = src_layout->OutputDim() - 1; i >= 0; i--) {
reduce_local = reduce_local =
For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent, ForKind::kUnrolled, For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent,
reduce_local, NullOpt, {{tir::attr::pragma_unroll_explicit, Bool(false)}}); ForKind::kUnrolled, reduce_local, NullOpt,
{{tir::attr::pragma_unroll_explicit, Bool(false)}});
} }
stmts.push_back(reduce_local); stmts.push_back(reduce_local);
// make inter-thread reduce // make inter-thread reduce
PrimExpr src_thread = PrimExpr src_thread = src_layout->ForwardThread(
src_layout->ForwardThread(src_vars.Map([](const auto& iv) { return PrimExpr(iv->var); }), {}); src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }), {});
auto iter_sum = arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer); auto iter_sum =
for (const auto& iter_split : iter_sum->args) { arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer);
for (const auto &iter_split : iter_sum->args) {
auto mark = iter_split->source->source.as<Var>(); auto mark = iter_split->source->source.as<Var>();
ICHECK(mark.defined()); ICHECK(mark.defined());
if (mark.value().same_as(src_vars[this->dim]->var)) { if (mark.value().same_as(src_vars[this->dim]->var)) {
auto scale = as_const_int(iter_split->scale); auto scale = as_const_int(iter_split->scale);
auto extent = as_const_int(iter_split->extent); auto extent = as_const_int(iter_split->extent);
ICHECK(scale != nullptr && extent != nullptr); ICHECK(scale != nullptr && extent != nullptr);
if (*extent == 1) continue; if (*extent == 1)
continue;
int reducing_threads = (*extent) * (*scale); int reducing_threads = (*extent) * (*scale);
std::stringstream ss; std::stringstream ss;
ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", " << reducing_threads << ", " ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
<< (*scale) << ">::run"; << reducing_threads << ", " << (*scale) << ">::run";
Array<PrimExpr> thread_reduce_args = {StringImm(ss.str()), Array<PrimExpr> thread_reduce_args = {
BufferLoad(dst_buffer, dst_indices)}; StringImm(ss.str()), BufferLoad(dst_buffer, dst_indices)};
if (reducing_threads >= 32) { if (reducing_threads >= 32) {
PrimExpr workspace = T.AddWorkspace(T.block_size, dst_buffer->dtype); PrimExpr workspace = T.AddWorkspace(T.block_size, dst_buffer->dtype);
thread_reduce_args.push_back(workspace); thread_reduce_args.push_back(workspace);
} }
auto call = Call(dst_buffer->dtype, builtin::call_extern(), thread_reduce_args); auto call =
Call(dst_buffer->dtype, builtin::call_extern(), thread_reduce_args);
stmts.push_back(BufferStore(dst_buffer, call, dst_indices)); stmts.push_back(BufferStore(dst_buffer, call, dst_indices));
} }
} }
...@@ -170,15 +180,17 @@ Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const { ...@@ -170,15 +180,17 @@ Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
// make the outer spatial loop // make the outer spatial loop
Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0]; Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0];
for (int i = dst_layout->InputDim() - 1; i >= 0; i--) { for (int i = dst_layout->InputDim() - 1; i >= 0; i--) {
body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent, ForKind::kParallel, body); body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent,
ForKind::kParallel, body);
} }
body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer, dst_layout); body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer, dst_layout);
return body; return body;
} }
LayoutMap ReduceOp::InferLayout(const LayoutInferArgs& T, InferLevel level) { LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
if (level >= InferLevel::kStrict) return {}; if (level >= InferLevel::kStrict)
return {};
if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" && if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
T.layout_map.count(src) && !T.layout_map.count(dst)) { T.layout_map.count(src) && !T.layout_map.count(dst)) {
auto src_layout = T.layout_map[src].as<Fragment>().value(); auto src_layout = T.layout_map[src].as<Fragment>().value();
...@@ -197,10 +209,11 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs& T, InferLevel level) { ...@@ -197,10 +209,11 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
fwd.push_back(InputPlaceholder(i - 1)); fwd.push_back(InputPlaceholder(i - 1));
} }
} }
auto thd = auto thd = src_layout->ForwardThread(
src_layout->ForwardThread(fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent)); fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
Fragment dst_layout = Fragment dst_layout =
Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, NullOpt)->CondenseReplicateVar(); Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, NullOpt)
->CondenseReplicateVar();
return {{dst, dst_layout}}; return {{dst, dst_layout}};
} }
return {}; return {};
...@@ -208,7 +221,8 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs& T, InferLevel level) { ...@@ -208,7 +221,8 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
TIR_REGISTER_TL_OP(ReduceOp, reduce) TIR_REGISTER_TL_OP(ReduceOp, reduce)
.set_num_inputs(4) .set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque)); .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
\ No newline at end of file \ No newline at end of file
...@@ -18,13 +18,13 @@ namespace tl { ...@@ -18,13 +18,13 @@ namespace tl {
using namespace tir; using namespace tir;
class ReduceOp : public Operator { class ReduceOp : public Operator {
public: public:
ReduceOp(Array<PrimExpr> args, BufferMap vmap); ReduceOp(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const final; Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level) final; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
static const Op& Get(); static const Op &Get();
private: private:
tir::Buffer src, dst; tir::Buffer src, dst;
int dim; int dim;
enum class ReduceType { enum class ReduceType {
...@@ -36,11 +36,11 @@ class ReduceOp : public Operator { ...@@ -36,11 +36,11 @@ class ReduceOp : public Operator {
bool clear; bool clear;
PrimExpr MakeInitValue() const; PrimExpr MakeInitValue() const;
PrimExpr MakeReduce(const PrimExpr& a, const PrimExpr& b) const; PrimExpr MakeReduce(const PrimExpr &a, const PrimExpr &b) const;
std::string MakeCodegenReducer() const; std::string MakeCodegenReducer() const;
}; };
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
#endif // TVM_TL_OP_REDUCE_H_ #endif // TVM_TL_OP_REDUCE_H_
\ No newline at end of file \ No newline at end of file
...@@ -17,12 +17,12 @@ namespace tl { ...@@ -17,12 +17,12 @@ namespace tl {
using namespace runtime; using namespace runtime;
template <typename T> template <typename T> static std::string ArrayToStr(const T *ptr, size_t n) {
static std::string ArrayToStr(const T* ptr, size_t n) {
std::stringstream ss; std::stringstream ss;
ss << "["; ss << "[";
for (size_t i = 0; i < n; i++) { for (size_t i = 0; i < n; i++) {
if (i > 0) ss << ", "; if (i > 0)
ss << ", ";
ss << ptr[i]; ss << ptr[i];
} }
ss << "]"; ss << "]";
...@@ -30,10 +30,10 @@ static std::string ArrayToStr(const T* ptr, size_t n) { ...@@ -30,10 +30,10 @@ static std::string ArrayToStr(const T* ptr, size_t n) {
} }
struct TensorMapArgs { struct TensorMapArgs {
CUtensorMap* map; CUtensorMap *map;
CUtensorMapDataType type; CUtensorMapDataType type;
cuuint32_t tensorRank; cuuint32_t tensorRank;
void* globalAddress; void *globalAddress;
cuuint64_t globalDim[5], globalStride[5]; cuuint64_t globalDim[5], globalStride[5];
cuuint32_t boxDim[5], elementStrides[5]; cuuint32_t boxDim[5], elementStrides[5];
CUtensorMapInterleave interleave; CUtensorMapInterleave interleave;
...@@ -45,8 +45,9 @@ struct TensorMapArgs { ...@@ -45,8 +45,9 @@ struct TensorMapArgs {
TensorMapArgs T; TensorMapArgs T;
int idx = 0; int idx = 0;
ICHECK(args.num_args >= 8); ICHECK(args.num_args >= 8);
T.map = reinterpret_cast<CUtensorMap*>(static_cast<void*>(args[idx++])); T.map = reinterpret_cast<CUtensorMap *>(static_cast<void *>(args[idx++]));
T.type = static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++])); T.type =
static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++]));
T.tensorRank = static_cast<cuuint32_t>(static_cast<int64_t>(args[idx++])); T.tensorRank = static_cast<cuuint32_t>(static_cast<int64_t>(args[idx++]));
T.globalAddress = args[idx++]; T.globalAddress = args[idx++];
ICHECK(T.tensorRank >= 1 && T.tensorRank <= 5); ICHECK(T.tensorRank >= 1 && T.tensorRank <= 5);
...@@ -63,10 +64,14 @@ struct TensorMapArgs { ...@@ -63,10 +64,14 @@ struct TensorMapArgs {
for (size_t i = 0; i < T.tensorRank; i++) { for (size_t i = 0; i < T.tensorRank; i++) {
T.elementStrides[i] = static_cast<cuuint64_t>(args[idx++]); T.elementStrides[i] = static_cast<cuuint64_t>(args[idx++]);
} }
T.interleave = static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++])); T.interleave =
T.swizzle = static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++])); static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++]));
T.l2Promotion = static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++])); T.swizzle =
T.oobFill = static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++])); static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
T.l2Promotion =
static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++]));
T.oobFill =
static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++]));
return T; return T;
} }
...@@ -79,7 +84,8 @@ struct TensorMapArgs { ...@@ -79,7 +84,8 @@ struct TensorMapArgs {
<< "globalDim " << ArrayToStr(globalDim, tensorRank) << std::endl << "globalDim " << ArrayToStr(globalDim, tensorRank) << std::endl
<< "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl << "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl
<< "boxDim " << ArrayToStr(boxDim, tensorRank) << std::endl << "boxDim " << ArrayToStr(boxDim, tensorRank) << std::endl
<< "elementStrides " << ArrayToStr(elementStrides, tensorRank) << std::endl << "elementStrides " << ArrayToStr(elementStrides, tensorRank)
<< std::endl
<< "interleave " << interleave << std::endl << "interleave " << interleave << std::endl
<< "swizzle " << swizzle << std::endl << "swizzle " << swizzle << std::endl
<< "l2Promotion " << l2Promotion << std::endl << "l2Promotion " << l2Promotion << std::endl
...@@ -89,23 +95,26 @@ struct TensorMapArgs { ...@@ -89,23 +95,26 @@ struct TensorMapArgs {
}; };
// set device api // set device api
TVM_REGISTER_GLOBAL(tvm_tensormap_create_tiled).set_body([](TVMArgs args, TVMRetValue* ret) { TVM_REGISTER_GLOBAL(tvm_tensormap_create_tiled)
TensorMapArgs T = TensorMapArgs::Extract(args); .set_body([](TVMArgs args, TVMRetValue *ret) {
CUresult result = cuTensorMapEncodeTiled( TensorMapArgs T = TensorMapArgs::Extract(args);
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim, T.globalStride + 1, T.boxDim, CUresult result = cuTensorMapEncodeTiled(
T.elementStrides, T.interleave, T.swizzle, T.l2Promotion, T.oobFill); T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
if (result != CUDA_SUCCESS) { T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave,
LOG_FATAL << "Failed to initialize the TMA descriptor " << result << std::endl T.swizzle, T.l2Promotion, T.oobFill);
<< T.ToDebugString(); if (result != CUDA_SUCCESS) {
} LOG_FATAL << "Failed to initialize the TMA descriptor " << result
*ret = static_cast<int>(result); << std::endl
}); << T.ToDebugString();
}
*ret = static_cast<int>(result);
});
struct TensorMapIm2ColArgs { struct TensorMapIm2ColArgs {
CUtensorMap* map; CUtensorMap *map;
CUtensorMapDataType type; CUtensorMapDataType type;
cuuint32_t tensorRank; cuuint32_t tensorRank;
void* globalAddress; void *globalAddress;
cuuint64_t globalDim[5], globalStride[5]; cuuint64_t globalDim[5], globalStride[5];
cuuint32_t elementStrides[5]; cuuint32_t elementStrides[5];
int pixelBoxLowerCorner[3], pixelBoxUpperCorner[3]; int pixelBoxLowerCorner[3], pixelBoxUpperCorner[3];
...@@ -119,8 +128,9 @@ struct TensorMapIm2ColArgs { ...@@ -119,8 +128,9 @@ struct TensorMapIm2ColArgs {
TensorMapIm2ColArgs T; TensorMapIm2ColArgs T;
int idx = 0; int idx = 0;
ICHECK(args.num_args >= 8); ICHECK(args.num_args >= 8);
T.map = reinterpret_cast<CUtensorMap*>(static_cast<void*>(args[idx++])); T.map = reinterpret_cast<CUtensorMap *>(static_cast<void *>(args[idx++]));
T.type = static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++])); T.type =
static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++]));
T.tensorRank = static_cast<cuuint32_t>(static_cast<int64_t>(args[idx++])); T.tensorRank = static_cast<cuuint32_t>(static_cast<int64_t>(args[idx++]));
T.globalAddress = args[idx++]; T.globalAddress = args[idx++];
ICHECK(T.tensorRank >= 3 && T.tensorRank <= 5); ICHECK(T.tensorRank >= 3 && T.tensorRank <= 5);
...@@ -142,10 +152,14 @@ struct TensorMapIm2ColArgs { ...@@ -142,10 +152,14 @@ struct TensorMapIm2ColArgs {
} }
T.smem_box_pixel = static_cast<cuuint64_t>(args[idx++]); T.smem_box_pixel = static_cast<cuuint64_t>(args[idx++]);
T.smem_box_channel = static_cast<cuuint64_t>(args[idx++]); T.smem_box_channel = static_cast<cuuint64_t>(args[idx++]);
T.interleave = static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++])); T.interleave =
T.swizzle = static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++])); static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++]));
T.l2Promotion = static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++])); T.swizzle =
T.oobFill = static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++])); static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
T.l2Promotion =
static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++]));
T.oobFill =
static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++]));
return T; return T;
} }
...@@ -159,9 +173,12 @@ struct TensorMapIm2ColArgs { ...@@ -159,9 +173,12 @@ struct TensorMapIm2ColArgs {
<< "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl << "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl
<< "smem_box_pixel " << smem_box_pixel << std::endl << "smem_box_pixel " << smem_box_pixel << std::endl
<< "smem_box_channel " << smem_box_channel << std::endl << "smem_box_channel " << smem_box_channel << std::endl
<< "pixelBoxLowerCorner " << ArrayToStr(pixelBoxLowerCorner, tensorRank - 2) << std::endl << "pixelBoxLowerCorner "
<< "pixelBoxUpperCorner " << ArrayToStr(pixelBoxUpperCorner, tensorRank - 2) << std::endl << ArrayToStr(pixelBoxLowerCorner, tensorRank - 2) << std::endl
<< "elementStrides " << ArrayToStr(elementStrides, tensorRank) << std::endl << "pixelBoxUpperCorner "
<< ArrayToStr(pixelBoxUpperCorner, tensorRank - 2) << std::endl
<< "elementStrides " << ArrayToStr(elementStrides, tensorRank)
<< std::endl
<< "interleave " << interleave << std::endl << "interleave " << interleave << std::endl
<< "swizzle " << swizzle << std::endl << "swizzle " << swizzle << std::endl
<< "l2Promotion " << l2Promotion << std::endl << "l2Promotion " << l2Promotion << std::endl
...@@ -170,18 +187,21 @@ struct TensorMapIm2ColArgs { ...@@ -170,18 +187,21 @@ struct TensorMapIm2ColArgs {
} }
}; };
TVM_REGISTER_GLOBAL(tvm_tensormap_create_im2col).set_body([](TVMArgs args, TVMRetValue* ret) { TVM_REGISTER_GLOBAL(tvm_tensormap_create_im2col)
TensorMapIm2ColArgs T = TensorMapIm2ColArgs::Extract(args); .set_body([](TVMArgs args, TVMRetValue *ret) {
CUresult result = cuTensorMapEncodeIm2col( TensorMapIm2ColArgs T = TensorMapIm2ColArgs::Extract(args);
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim, T.globalStride + 1, CUresult result = cuTensorMapEncodeIm2col(
T.pixelBoxLowerCorner, T.pixelBoxUpperCorner, T.smem_box_channel, T.smem_box_pixel, T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
T.elementStrides, T.interleave, T.swizzle, T.l2Promotion, T.oobFill); T.globalStride + 1, T.pixelBoxLowerCorner, T.pixelBoxUpperCorner,
if (result != CUDA_SUCCESS) { T.smem_box_channel, T.smem_box_pixel, T.elementStrides, T.interleave,
LOG_FATAL << "Failed to initialize the TMA descriptor " << result << std::endl T.swizzle, T.l2Promotion, T.oobFill);
<< T.ToDebugString(); if (result != CUDA_SUCCESS) {
} LOG_FATAL << "Failed to initialize the TMA descriptor " << result
*ret = static_cast<int>(result); << std::endl
}); << T.ToDebugString();
}
*ret = static_cast<int>(result);
});
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -13,9 +13,11 @@ ...@@ -13,9 +13,11 @@
namespace tvm { namespace tvm {
namespace tl { namespace tl {
constexpr const char* tvm_tensormap_create_tiled = "__tvm_tensormap_create_tiled"; constexpr const char *tvm_tensormap_create_tiled =
constexpr const char* tvm_tensormap_create_im2col = "__tvm_tensormap_create_im2col"; "__tvm_tensormap_create_tiled";
} // namespace tl constexpr const char *tvm_tensormap_create_im2col =
} // namespace tvm "__tvm_tensormap_create_im2col";
} // namespace tl
} // namespace tvm
#endif // TVM_TL_RUNTIME_RUNTIME_H_ #endif // TVM_TL_RUNTIME_RUNTIME_H_
\ No newline at end of file \ No newline at end of file
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
*/ */
#include "codegen_cuda.h" #include "codegen_cuda.h"
#include <tvm/tir/index_map.h>
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <cmath> #include <cmath>
...@@ -23,41 +23,51 @@ ...@@ -23,41 +23,51 @@
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
CodeGenTileLangCUDA::CodeGenTileLangCUDA() { restrict_keyword_ = "__restrict__"; } CodeGenTileLangCUDA::CodeGenTileLangCUDA() {
restrict_keyword_ = "__restrict__";
}
void CodeGenTileLangCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ "; } void CodeGenTileLangCUDA::PrintFuncPrefix(std::ostream &os) {
os << "extern \"C\" __global__ ";
}
class LaunchConfigExtractor : public tir::StmtVisitor { class LaunchConfigExtractor : public tir::StmtVisitor {
private: private:
void VisitStmt_(const AttrStmtNode* op) final { void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) { if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node); IterVar iv = Downcast<IterVar>(op->node);
if (iv->var->name_hint == "threadIdx.x" || iv->thread_tag == "threadIdx.x") { if (iv->var->name_hint == "threadIdx.x" ||
iv->thread_tag == "threadIdx.x") {
threadIdx_x_ext = op->value; threadIdx_x_ext = op->value;
} else if (iv->var->name_hint == "threadIdx.y" || iv->thread_tag == "threadIdx.y") { } else if (iv->var->name_hint == "threadIdx.y" ||
iv->thread_tag == "threadIdx.y") {
threadIdx_y_ext = op->value; threadIdx_y_ext = op->value;
} else if (iv->var->name_hint == "threadIdx.z" || iv->thread_tag == "threadIdx.z") { } else if (iv->var->name_hint == "threadIdx.z" ||
iv->thread_tag == "threadIdx.z") {
threadIdx_z_ext = op->value; threadIdx_z_ext = op->value;
} }
} }
StmtVisitor::VisitStmt_(op); StmtVisitor::VisitStmt_(op);
} }
public: public:
PrimExpr threadIdx_x_ext = Integer(1); PrimExpr threadIdx_x_ext = Integer(1);
PrimExpr threadIdx_y_ext = Integer(1); PrimExpr threadIdx_y_ext = Integer(1);
PrimExpr threadIdx_z_ext = Integer(1); PrimExpr threadIdx_z_ext = Integer(1);
}; };
void CodeGenTileLangCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) { void CodeGenTileLangCUDA::PrintExtraAttrs(const PrimFunc &f, std::ostream &os) {
LaunchConfigExtractor extractor; LaunchConfigExtractor extractor;
extractor(f->body); extractor(f->body);
arith::Analyzer analyzer; arith::Analyzer analyzer;
PrimExpr threadIdx_ext = analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext * PrimExpr threadIdx_ext =
extractor.threadIdx_z_ext); analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext *
if (const IntImmNode* const threadIdx_ext_int = threadIdx_ext.as<IntImmNode>()) { extractor.threadIdx_z_ext);
if (const IntImmNode *const threadIdx_ext_int =
threadIdx_ext.as<IntImmNode>()) {
if (threadIdx_ext_int->value == 1) { if (threadIdx_ext_int->value == 1) {
// unable to extract the number of threads per block, hence directly return // unable to extract the number of threads per block, hence directly
// return
return; return;
} }
stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")"; stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
...@@ -77,19 +87,20 @@ std::string CodeGenTileLangCUDA::Finish() { ...@@ -77,19 +87,20 @@ std::string CodeGenTileLangCUDA::Finish() {
return CodeGenC::Finish(); return CodeGenC::Finish();
} }
void CodeGenTileLangCUDA::VisitStmt_(const tir::ForNode* op) { void CodeGenTileLangCUDA::VisitStmt_(const tir::ForNode *op) {
if (op->kind == tir::ForKind::kUnrolled) { if (op->kind == tir::ForKind::kUnrolled) {
PrintIndent(); PrintIndent();
stream << "#pragma unroll\n"; stream << "#pragma unroll\n";
} }
std::string extent = PrintExpr(arith::Analyzer().Simplify(op->extent + op->min)); std::string extent =
PrintExpr(arith::Analyzer().Simplify(op->extent + op->min));
PrintIndent(); PrintIndent();
std::string vid = AllocVarID(op->loop_var.get()); std::string vid = AllocVarID(op->loop_var.get());
std::string start = PrintExpr(op->min); std::string start = PrintExpr(op->min);
stream << "for ("; stream << "for (";
PrintType(op->loop_var.dtype(), stream); PrintType(op->loop_var.dtype(), stream);
stream << ' ' << vid << " = " << start << "; " << vid << " < " << extent << "; ++" << vid stream << ' ' << vid << " = " << start << "; " << vid << " < " << extent
<< ") {\n"; << "; ++" << vid << ") {\n";
int for_scope = BeginScope(); int for_scope = BeginScope();
PrintStmt(op->body); PrintStmt(op->body);
this->EndScope(for_scope); this->EndScope(for_scope);
...@@ -97,12 +108,13 @@ void CodeGenTileLangCUDA::VisitStmt_(const tir::ForNode* op) { ...@@ -97,12 +108,13 @@ void CodeGenTileLangCUDA::VisitStmt_(const tir::ForNode* op) {
stream << "}\n"; stream << "}\n";
} }
void CodeGenTileLangCUDA::BindThreadIndex(const IterVar& iv) { void CodeGenTileLangCUDA::BindThreadIndex(const IterVar &iv) {
ICHECK(!var_idmap_.count(iv->var.get())); ICHECK(!var_idmap_.count(iv->var.get()));
var_idmap_[iv->var.get()] = CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype()); var_idmap_[iv->var.get()] =
CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype());
} }
void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
int lanes = t.lanes(); int lanes = t.lanes();
if (t.is_handle()) { if (t.is_handle()) {
ICHECK(t.is_scalar()) << "do not yet support vector types"; ICHECK(t.is_scalar()) << "do not yet support vector types";
...@@ -123,51 +135,54 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(* ...@@ -123,51 +135,54 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*
bool fail = false; bool fail = false;
if (t.is_float()) { if (t.is_float()) {
switch (t.bits()) { switch (t.bits()) {
case 16: case 16:
if (t.is_scalar()) { if (t.is_scalar()) {
os << "half_t"; os << "half_t";
} else if (lanes <= 8) { } else if (lanes <= 8) {
// Emit CUDA code to access fp16 vector elements. // Emit CUDA code to access fp16 vector elements.
// //
// half4 is stored as uint2 // half4 is stored as uint2
// //
// h4.x is emitted as *(half2*)(&(u2.x)).x // h4.x is emitted as *(half2*)(&(u2.x)).x
// h4.y is emitted as *(half2*)(&(u2.x)).y // h4.y is emitted as *(half2*)(&(u2.x)).y
// h4.z is emitted as *(half2*)(&(u2.y)).x // h4.z is emitted as *(half2*)(&(u2.y)).x
// h4.w is emitted as *(half2*)(&(u2.y)).y // h4.w is emitted as *(half2*)(&(u2.y)).y
// //
ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
os << "uint" << lanes / 2; os << "uint" << lanes / 2;
} else { } else {
fail = true;
}
break;
case 32:
if (lanes <= 4) {
os << "float";
} else if (lanes <= 8) {
// Emit CUDA code to access fp32 vector elements for 4 < lanes <= 8.
//
// float8 is stored as ulonglong4
//
// f8.v1 is emitted as *(float2*)(&(ul4.x)).x
// f8.v2 is emitted as *(float2*)(&(ul4.x)).y
//
ICHECK_EQ(lanes % 2, 0) << "only support even lane for float type with lanes > 4";
os << "ulonglong" << lanes / 2;
} else {
fail = true;
}
break;
case 64:
os << "double";
break;
default:
fail = true; fail = true;
break; }
break;
case 32:
if (lanes <= 4) {
os << "float";
} else if (lanes <= 8) {
// Emit CUDA code to access fp32 vector elements for 4 < lanes <= 8.
//
// float8 is stored as ulonglong4
//
// f8.v1 is emitted as *(float2*)(&(ul4.x)).x
// f8.v2 is emitted as *(float2*)(&(ul4.x)).y
//
ICHECK_EQ(lanes % 2, 0)
<< "only support even lane for float type with lanes > 4";
os << "ulonglong" << lanes / 2;
} else {
fail = true;
}
break;
case 64:
os << "double";
break;
default:
fail = true;
break;
} }
if (!fail && (t.is_scalar() || t.bits() == 16)) return; if (!fail && (t.is_scalar() || t.bits() == 16))
if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32)) return; return;
if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32))
return;
if (!fail && (lanes >= 2 && lanes <= 4)) { if (!fail && (lanes >= 2 && lanes <= 4)) {
os << lanes; os << lanes;
return; return;
...@@ -181,18 +196,21 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(* ...@@ -181,18 +196,21 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*
} else { } else {
fail = true; fail = true;
} }
if (!fail) return; if (!fail)
return;
} else if (t.is_float8()) { } else if (t.is_float8()) {
if (t.is_scalar()) { if (t.is_scalar()) {
os << "unsigned char"; // __nv_fp8_storage_t is an alias of unsigned char os << "unsigned char"; // __nv_fp8_storage_t is an alias of unsigned char
} else if (lanes == 2) { } else if (lanes == 2) {
os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of unsigned short os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of
// unsigned short
} else if (lanes == 4) { } else if (lanes == 4) {
os << "unsigned int"; // __nv_fp8x4_storage_t is an alias of unsigned int os << "unsigned int"; // __nv_fp8x4_storage_t is an alias of unsigned int
} else { } else {
fail = true; fail = true;
} }
if (!fail) return; if (!fail)
return;
} else if (t == DataType::Bool()) { } else if (t == DataType::Bool()) {
os << "bool"; os << "bool";
return; return;
...@@ -209,133 +227,135 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(* ...@@ -209,133 +227,135 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*
os << "u"; os << "u";
} }
switch (t.bits()) { switch (t.bits()) {
case 1: { case 1: {
if (t.is_scalar()) { if (t.is_scalar()) {
os << "int"; os << "int";
return; return;
} else if (t.lanes() == 8) { } else if (t.lanes() == 8) {
os << "int8_t"; os << "int8_t";
return; return;
} else if (t.lanes() == 16) { } else if (t.lanes() == 16) {
os << "int16_t"; os << "int16_t";
return; return;
} else if (t.lanes() == 32) { } else if (t.lanes() == 32) {
os << "int"; os << "int";
return; return;
} else { } else {
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!"; LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
}
}
case 4: {
if (t.is_scalar()) {
os << "int";
return;
} else if (t.lanes() == 4) {
os << "int16_t";
return;
} else if (t.lanes() == 8) {
// directly 8 4-bit int in integer.
os << "int";
return;
} else if (t.lanes() == 16) {
os << "int2";
return;
} else if (t.lanes() == 32) {
os << "int4";
return;
} else if (t.lanes() == 64) {
os << "int8";
return;
} else {
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
}
} }
case 8: { }
if (t.lanes() == 4) { case 4: {
// directly 4 8 bit int in integer. if (t.is_scalar()) {
os << "int";
// We use int for int8x4 instead of char4 because using char4 is return;
// likely to produce extra instructions to pack four int8 elements } else if (t.lanes() == 4) {
// into 32-bit data. os << "int16_t";
os << "int"; return;
return; } else if (t.lanes() == 8) {
} else if (t.lanes() == 8) { // directly 8 4-bit int in integer.
os << "int2"; os << "int";
return; return;
} else if (t.lanes() == 16) { } else if (t.lanes() == 16) {
os << "int4"; os << "int2";
return; return;
} else if (!t.is_uint() && t.is_scalar()) { } else if (t.lanes() == 32) {
os << "signed char"; os << "int4";
break; return;
} else { } else if (t.lanes() == 64) {
os << "char"; os << "int8";
break; return;
} } else {
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
} }
case 16: { }
if (t.is_scalar()) { case 8: {
os << "short"; if (t.lanes() == 4) {
} else if (t.lanes() <= 4) { // directly 4 8 bit int in integer.
os << "short" << lanes;
} else if (t.lanes() <= 8) { // We use int for int8x4 instead of char4 because using char4 is
// Emit CUDA code to access int16 vector elements. // likely to produce extra instructions to pack four int8 elements
// // into 32-bit data.
// short4 is stored as int2 os << "int";
// return;
// s4.x is emitted as *(short2*)(&(i2.x)).x } else if (t.lanes() == 8) {
// s4.y is emitted as *(short2*)(&(i2.x)).y os << "int2";
// s4.z is emitted as *(short2*)(&(i2.y)).x return;
// s4.w is emitted as *(short2*)(&(i2.y)).y } else if (t.lanes() == 16) {
// os << "int4";
ICHECK_EQ(t.lanes() % 2, 0) << "only support even lane for shorT type with lanes > 4"; return;
os << "int" << t.lanes() / 2; } else if (!t.is_uint() && t.is_scalar()) {
} else { os << "signed char";
fail = true;
}
if (!fail) {
return;
}
break; break;
} } else {
case 32: { os << "char";
if (t.is_scalar()) {
os << "int";
} else if (t.lanes() <= 4) {
os << "int" << t.lanes();
} else if (t.lanes() <= 8) {
// Emit CUDA code to access int32 vector elements for 4 < lanes <= 8.
//
// int8 is stored as longlong4
//
// i8.v1 is emitted as *(int2*)(&(l4.x)).x
// i8.v2 is emitted as *(int2*)(&(l4.x)).y
//
ICHECK_EQ(lanes % 2, 0) << "only support even lane for int32 type with lanes > 4";
os << "longlong" << lanes / 2;
} else {
fail = true;
}
if (!fail) {
return;
}
break; break;
} }
case 64: { }
if (t.is_scalar()) { case 16: {
os << "int64_t"; if (t.is_scalar()) {
} else if (t.lanes() == 2) { os << "short";
os << "longlong2"; } else if (t.lanes() <= 4) {
} else if (t.lanes() == 3) { os << "short" << lanes;
os << "longlong3"; } else if (t.lanes() <= 8) {
} else if (t.lanes() == 4) { // Emit CUDA code to access int16 vector elements.
os << "longlong4"; //
} // short4 is stored as int2
//
// s4.x is emitted as *(short2*)(&(i2.x)).x
// s4.y is emitted as *(short2*)(&(i2.x)).y
// s4.z is emitted as *(short2*)(&(i2.y)).x
// s4.w is emitted as *(short2*)(&(i2.y)).y
//
ICHECK_EQ(t.lanes() % 2, 0)
<< "only support even lane for shorT type with lanes > 4";
os << "int" << t.lanes() / 2;
} else {
fail = true;
}
if (!fail) {
return; return;
} }
default: break;
}
case 32: {
if (t.is_scalar()) {
os << "int";
} else if (t.lanes() <= 4) {
os << "int" << t.lanes();
} else if (t.lanes() <= 8) {
// Emit CUDA code to access int32 vector elements for 4 < lanes <= 8.
//
// int8 is stored as longlong4
//
// i8.v1 is emitted as *(int2*)(&(l4.x)).x
// i8.v2 is emitted as *(int2*)(&(l4.x)).y
//
ICHECK_EQ(lanes % 2, 0)
<< "only support even lane for int32 type with lanes > 4";
os << "longlong" << lanes / 2;
} else {
fail = true; fail = true;
break; }
if (!fail) {
return;
}
break;
}
case 64: {
if (t.is_scalar()) {
os << "int64_t";
} else if (t.lanes() == 2) {
os << "longlong2";
} else if (t.lanes() == 3) {
os << "longlong3";
} else if (t.lanes() == 4) {
os << "longlong4";
}
return;
}
default:
fail = true;
break;
} }
if (!fail && lanes == 1) { if (!fail && lanes == 1) {
return; return;
...@@ -348,8 +368,9 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(* ...@@ -348,8 +368,9 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type"; LOG(FATAL) << "Cannot convert type " << t << " to CUDA type";
} }
void CodeGenTileLangCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, void CodeGenTileLangCUDA::PrintVecBinaryOp(const std::string &op, DataType t,
std::ostream& os) { // NOLINT(*) PrimExpr lhs, PrimExpr rhs,
std::ostream &os) { // NOLINT(*)
// Declare the result. // Declare the result.
std::string sret = name_supply_->FreshName("_"); std::string sret = name_supply_->FreshName("_");
this->PrintIndent(); this->PrintIndent();
...@@ -383,15 +404,18 @@ void CodeGenTileLangCUDA::PrintVecBinaryOp(const std::string& op, DataType t, Pr ...@@ -383,15 +404,18 @@ void CodeGenTileLangCUDA::PrintVecBinaryOp(const std::string& op, DataType t, Pr
os << sret; os << sret;
} }
void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string &vec, DataType t,
std::ostream& os) { // NOLINT(*) int i,
std::ostream &os) { // NOLINT(*)
if (t.is_scalar()) { if (t.is_scalar()) {
os << vec; os << vec;
return; return;
} }
static const char access[] = {'x', 'y', 'z', 'w'}; static const char access[] = {'x', 'y', 'z', 'w'};
ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4)); ICHECK(i >= 0 && i < (t.bits() == 8 ? 16
: (t.bits() == 16 || t.bits() == 32) ? 8
: 4));
if (t.bits() == 8 && (t.is_int() || t.is_uint())) { if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
std::string type_name = t.is_int() ? "char" : "unsigned char"; std::string type_name = t.is_int() ? "char" : "unsigned char";
if (t.lanes() == 2 || t.lanes() == 3) { if (t.lanes() == 2 || t.lanes() == 3) {
...@@ -401,9 +425,11 @@ void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string& vec, DataType t, i ...@@ -401,9 +425,11 @@ void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string& vec, DataType t, i
os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))"; os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))";
} }
} else if (t.is_float16()) { } else if (t.is_float16()) {
os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2];
} else if (t.is_bfloat16()) { } else if (t.is_bfloat16()) {
os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2];
} else if (t.lanes() > 4 && t.lanes() <= 8) { } else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name; std::string type_name;
if (t.bits() == 16) { if (t.bits() == 16) {
...@@ -422,20 +448,24 @@ void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string& vec, DataType t, i ...@@ -422,20 +448,24 @@ void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string& vec, DataType t, i
} }
} }
ICHECK(!type_name.empty()); ICHECK(!type_name.empty());
os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
<< ")))->" << access[i % 2];
} else { } else {
os << vec << "." << access[i]; os << vec << "." << access[i];
} }
} }
void CodeGenTileLangCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i, void CodeGenTileLangCUDA::PrintVecElemStore(const std::string &vec, DataType t,
const std::string& value) { int i, const std::string &value) {
this->PrintIndent(); this->PrintIndent();
static const char access[] = {'x', 'y', 'z', 'w'}; static const char access[] = {'x', 'y', 'z', 'w'};
ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4)); ICHECK(i >= 0 && i < (t.bits() == 8 ? 16
: (t.bits() == 16 || t.bits() == 32) ? 8
: 4));
if (t.bits() == 8 && (t.is_int() || t.is_uint())) { if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (t.lanes() == 2 || t.lanes() == 3) { if (t.lanes() == 2 || t.lanes() == 3) {
stream << vec << '.' << access[i % t.lanes()] << "=" << "(" << value << ");\n"; stream << vec << '.' << access[i % t.lanes()] << "="
<< "(" << value << ");\n";
} else { } else {
std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]); std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
stream << ac << "="; stream << ac << "=";
...@@ -446,11 +476,11 @@ void CodeGenTileLangCUDA::PrintVecElemStore(const std::string& vec, DataType t, ...@@ -446,11 +476,11 @@ void CodeGenTileLangCUDA::PrintVecElemStore(const std::string& vec, DataType t,
stream << "(" << value << " << " << i % 4 * 8 << ");\n"; stream << "(" << value << " << " << i % 4 * 8 << ");\n";
} }
} else if (t.is_float16()) { } else if (t.is_float16()) {
stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = " stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
<< value << ";\n"; << access[i % 2] << " = " << value << ";\n";
} else if (t.is_bfloat16()) { } else if (t.is_bfloat16()) {
stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
<< " = " << value << ";\n"; << access[i % 2] << " = " << value << ";\n";
} else if (t.lanes() > 4 && t.lanes() <= 8) { } else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name; std::string type_name;
if (t.bits() == 16) { if (t.bits() == 16) {
...@@ -469,15 +499,15 @@ void CodeGenTileLangCUDA::PrintVecElemStore(const std::string& vec, DataType t, ...@@ -469,15 +499,15 @@ void CodeGenTileLangCUDA::PrintVecElemStore(const std::string& vec, DataType t,
} }
} }
ICHECK(!type_name.empty()); ICHECK(!type_name.empty());
stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->" stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
<< access[i % 2] << " = " << value << ";\n"; << ")))->" << access[i % 2] << " = " << value << ";\n";
} else { } else {
stream << vec << "." << access[i] << " = " << value << ";\n"; stream << vec << "." << access[i] << " = " << value << ";\n";
} }
} }
void CodeGenTileLangCUDA::PrintStorageSync(const CallNode* op) { void CodeGenTileLangCUDA::PrintStorageSync(const CallNode *op) {
const std::string& sync = op->args[0].as<StringImmNode>()->value; const std::string &sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") { if (sync == "warp") {
// DO nothing. // DO nothing.
} else if (sync == "shared" || sync == "shared.dyn") { } else if (sync == "shared" || sync == "shared.dyn") {
...@@ -486,9 +516,11 @@ void CodeGenTileLangCUDA::PrintStorageSync(const CallNode* op) { ...@@ -486,9 +516,11 @@ void CodeGenTileLangCUDA::PrintStorageSync(const CallNode* op) {
} }
} }
void CodeGenTileLangCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) void CodeGenTileLangCUDA::PrintStorageScope(const std::string &scope,
ICHECK_NE(scope, "global") << "Cannot allocate global memory when targeting CUDA. You must pass " std::ostream &os) { // NOLINT(*)
"all global arrays as input instead"; ICHECK_NE(scope, "global")
<< "Cannot allocate global memory when targeting CUDA. You must pass "
"all global arrays as input instead";
if (scope == "shared") { if (scope == "shared") {
os << "__shared__ "; os << "__shared__ ";
} else if (scope == "shared.dyn") { } else if (scope == "shared.dyn") {
...@@ -496,13 +528,16 @@ void CodeGenTileLangCUDA::PrintStorageScope(const std::string& scope, std::ostre ...@@ -496,13 +528,16 @@ void CodeGenTileLangCUDA::PrintStorageScope(const std::string& scope, std::ostre
} }
} }
std::string CodeGenTileLangCUDA::CastFromTo(std::string value, DataType from, DataType target) { std::string CodeGenTileLangCUDA::CastFromTo(std::string value, DataType from,
if (from == target) return value; DataType target) {
if (from == target)
return value;
std::ostringstream os; std::ostringstream os;
os << "(("; os << "((";
this->PrintType(target, os); this->PrintType(target, os);
os << ")"; os << ")";
if (from.is_float16() && (target.is_int() || target.is_uint()) && target.bits() == 8) { if (from.is_float16() && (target.is_int() || target.is_uint()) &&
target.bits() == 8) {
os << "("; os << "(";
if (target.is_uint()) { if (target.is_uint()) {
os << "u"; os << "u";
...@@ -513,13 +548,14 @@ std::string CodeGenTileLangCUDA::CastFromTo(std::string value, DataType from, Da ...@@ -513,13 +548,14 @@ std::string CodeGenTileLangCUDA::CastFromTo(std::string value, DataType from, Da
return os.str(); return os.str();
} }
void CodeGenTileLangCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
DataType from_ty = op->value.dtype(); DataType from_ty = op->value.dtype();
DataType target_ty = op->dtype; DataType target_ty = op->dtype;
ICHECK_EQ(target_ty.lanes(), from_ty.lanes()); ICHECK_EQ(target_ty.lanes(), from_ty.lanes());
// Emit simple C-style type conversion. // Emit simple C-style type conversion.
if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os); if (from_ty.is_scalar())
return CodeGenC::VisitExpr_(op, os);
// We could emit make_float4 like calls, but the emitted code looks // We could emit make_float4 like calls, but the emitted code looks
// too compact to read. Emit this as vectorized unary ops. // too compact to read. Emit this as vectorized unary ops.
...@@ -542,8 +578,10 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { ...@@ -542,8 +578,10 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
os << sret; os << sret;
} }
void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args, void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol,
bool skip_first_arg, std::ostream& os) { // NOLINT(*) const Array<PrimExpr> &args,
bool skip_first_arg,
std::ostream &os) { // NOLINT(*)
DataType ret_dtype = GetRuntimeDataType(ret_type); DataType ret_dtype = GetRuntimeDataType(ret_type);
if (ret_dtype.is_vector()) { if (ret_dtype.is_vector()) {
// //
...@@ -583,7 +621,8 @@ void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol, c ...@@ -583,7 +621,8 @@ void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol, c
std::ostringstream scall; std::ostringstream scall;
scall << global_symbol << "("; scall << global_symbol << "(";
for (size_t j = 0; j < sargs.size(); ++j) { for (size_t j = 0; j < sargs.size(); ++j) {
if (j > 0) scall << ", "; if (j > 0)
scall << ", ";
PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall); PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall);
} }
scall << ")"; scall << ")";
...@@ -592,13 +631,16 @@ void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol, c ...@@ -592,13 +631,16 @@ void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol, c
} }
os << sret; os << sret;
} else { } else {
CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, os); CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg,
os);
} }
} }
// Print a reference expression to a buffer. // Print a reference expression to a buffer.
std::string CodeGenTileLangCUDA::GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) { std::string CodeGenTileLangCUDA::GetBufferRef(DataType t,
const VarNode* buffer_var = buffer->data.get(); const BufferNode *buffer,
PrimExpr index) {
const VarNode *buffer_var = buffer->data.get();
std::ostringstream os; std::ostringstream os;
std::string vid = GetVarID(buffer_var); std::string vid = GetVarID(buffer_var);
std::string scope; std::string scope;
...@@ -654,12 +696,13 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t, const BufferNode* buff ...@@ -654,12 +696,13 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t, const BufferNode* buff
return os.str(); return os.str();
} }
void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
auto print_extern_call_stmt = [&](std::string name, size_t offset = 0) { auto print_extern_call_stmt = [&](std::string name, size_t offset = 0) {
this->PrintIndent(); this->PrintIndent();
this->stream << name << "("; this->stream << name << "(";
for (size_t i = offset; i < op->args.size(); i++) { for (size_t i = offset; i < op->args.size(); i++) {
if (i > offset) this->stream << ", "; if (i > offset)
this->stream << ", ";
this->stream << this->PrintExpr(op->args[i]); this->stream << this->PrintExpr(op->args[i]);
} }
this->stream << ");\n"; this->stream << ");\n";
...@@ -670,16 +713,18 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -670,16 +713,18 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string src = this->PrintExpr(op->args[2]); std::string src = this->PrintExpr(op->args[2]);
std::string src_offset = this->PrintExpr(op->args[3]); std::string src_offset = this->PrintExpr(op->args[3]);
std::string size = this->PrintExpr(op->args[4]); std::string size = this->PrintExpr(op->args[4]);
// use size of argument list to indicate whether or not to use predicated cp.async // use size of argument list to indicate whether or not to use predicated
// cp.async
if (op->args.size() == 5) { if (op->args.size() == 5) {
this->PrintIndent(); this->PrintIndent();
this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+" << dst_offset << ", " << src this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+"
<< "+" << src_offset << ");\n"; << dst_offset << ", " << src << "+" << src_offset << ");\n";
} else { } else {
std::string condition = this->PrintExpr(op->args[5]); std::string condition = this->PrintExpr(op->args[5]);
this->PrintIndent(); this->PrintIndent();
this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst << "+" << dst_offset this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst
<< ", " << src << "+" << src_offset << ", " << condition << ");\n"; << "+" << dst_offset << ", " << src << "+" << src_offset
<< ", " << condition << ");\n";
} }
} else if (op->op.same_as(builtin::ptx_commit_group())) { } else if (op->op.same_as(builtin::ptx_commit_group())) {
print_extern_call_stmt("tl::cp_async_commit"); print_extern_call_stmt("tl::cp_async_commit");
...@@ -691,7 +736,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -691,7 +736,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
this->PrintIndent(); this->PrintIndent();
int barrier_count = Downcast<IntImm>(op->args[0])->value; int barrier_count = Downcast<IntImm>(op->args[0])->value;
std::string barrier_name = "_mbarrier"; std::string barrier_name = "_mbarrier";
this->stream << "__shared__ uint64_t " << barrier_name << "[" << barrier_count << "];\n"; this->stream << "__shared__ uint64_t " << barrier_name << "["
<< barrier_count << "];\n";
} else if (op->op.same_as(tl::GetMBarrierOp())) { } else if (op->op.same_as(tl::GetMBarrierOp())) {
std::string barrier_name = "_mbarrier"; std::string barrier_name = "_mbarrier";
std::string barrier_id = this->PrintExpr(op->args[0]); std::string barrier_id = this->PrintExpr(op->args[0]);
...@@ -720,13 +766,15 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -720,13 +766,15 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
int trans = Downcast<IntImm>(op->args[0])->value; int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value; int num = Downcast<IntImm>(op->args[1])->value;
std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num); std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
if (trans == 1) func_name += "_trans"; if (trans == 1)
func_name += "_trans";
print_extern_call_stmt(func_name, 2); print_extern_call_stmt(func_name, 2);
} else if (op->op.same_as(tl::STMatrixOp())) { } else if (op->op.same_as(tl::STMatrixOp())) {
int trans = Downcast<IntImm>(op->args[0])->value; int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value; int num = Downcast<IntImm>(op->args[1])->value;
std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num); std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num);
if (trans == 1) func_name += "_trans"; if (trans == 1)
func_name += "_trans";
print_extern_call_stmt(func_name, 2); print_extern_call_stmt(func_name, 2);
} else if (op->op.same_as(tl::FenceProxyAsyncOp())) { } else if (op->op.same_as(tl::FenceProxyAsyncOp())) {
print_extern_call_stmt("tl::fence_proxy_async"); print_extern_call_stmt("tl::fence_proxy_async");
...@@ -734,15 +782,16 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -734,15 +782,16 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
this->PrintIndent(); this->PrintIndent();
int nreg = Downcast<IntImm>(op->args[0])->value; int nreg = Downcast<IntImm>(op->args[0])->value;
int is_inc = Downcast<IntImm>(op->args[1])->value; int is_inc = Downcast<IntImm>(op->args[1])->value;
std::string func_name = is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc"; std::string func_name =
is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc";
this->stream << func_name << "<" << std::to_string(nreg) << ">();\n"; this->stream << func_name << "<" << std::to_string(nreg) << ">();\n";
} else if (op->op.same_as(tl::WaitWgmma())) { } else if (op->op.same_as(tl::WaitWgmma())) {
this->PrintIndent(); this->PrintIndent();
int num_mma = Downcast<IntImm>(op->args[0])->value; int num_mma = Downcast<IntImm>(op->args[0])->value;
this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n"; this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n";
} else if (op->op.same_as(tl::PackB16Op())) { } else if (op->op.same_as(tl::PackB16Op())) {
os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", " << this->PrintExpr(op->args[1]) os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", "
<< ")"; << this->PrintExpr(op->args[1]) << ")";
} else if (op->op.same_as(builtin::tvm_fill_fragment())) { } else if (op->op.same_as(builtin::tvm_fill_fragment())) {
need_mma_h_ = true; need_mma_h_ = true;
ICHECK_EQ(op->args.size(), 6U); ICHECK_EQ(op->args.size(), 6U);
...@@ -776,7 +825,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -776,7 +825,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
this->PrintExpr(op->args[4], os); this->PrintExpr(op->args[4], os);
os << "], "; os << "], ";
this->PrintExpr(op->args[6], os); this->PrintExpr(op->args[6], os);
if (const StringImmNode* str = op->args[7].as<StringImmNode>()) { if (const StringImmNode *str = op->args[7].as<StringImmNode>()) {
os << ", nvcuda::wmma::mem_" << str->value; os << ", nvcuda::wmma::mem_" << str->value;
} else { } else {
LOG(FATAL) << "Invalid parameters"; LOG(FATAL) << "Invalid parameters";
...@@ -831,10 +880,11 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -831,10 +880,11 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string c_ref = this->PrintExpr(op->args[10]); std::string c_ref = this->PrintExpr(op->args[10]);
std::string c_bias = this->PrintExpr(op->args[11]); std::string c_bias = this->PrintExpr(op->args[11]);
bool saturate = Downcast<Bool>(op->args[12])->value; bool saturate = Downcast<Bool>(op->args[12])->value;
std::string bit_op = op->args.size() > 13 ? Downcast<StringImm>(op->args[13])->value : ""; std::string bit_op =
std::string asm_code = op->args.size() > 13 ? Downcast<StringImm>(op->args[13])->value : "";
PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, b_ref, std::string asm_code = PrintMMAAssembly(
b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate); shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias,
b_ref, b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate);
this->stream << asm_code; this->stream << asm_code;
} else if (op->op.same_as(builtin::ptx_mma_sp())) { } else if (op->op.same_as(builtin::ptx_mma_sp())) {
...@@ -872,8 +922,9 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -872,8 +922,9 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string sparse_selector = this->PrintExpr(op->args[14]); std::string sparse_selector = this->PrintExpr(op->args[14]);
bool saturate = Downcast<Bool>(op->args[15])->value; bool saturate = Downcast<Bool>(op->args[15])->value;
std::string asm_code = PrintMMAAssembly( std::string asm_code = PrintMMAAssembly(
shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, b_ref, b_offset, shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset,
c_ref, c_offset, metadata, metadata_offset, sparse_selector, "", true, saturate); b_ref, b_offset, c_ref, c_offset, metadata, metadata_offset,
sparse_selector, "", true, saturate);
this->stream << asm_code; this->stream << asm_code;
} else if (op->op.same_as(builtin::ptx_ldmatrix())) { } else if (op->op.same_as(builtin::ptx_ldmatrix())) {
// arg 0: whether the matrix is loaded in column major format or not. // arg 0: whether the matrix is loaded in column major format or not.
...@@ -882,7 +933,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -882,7 +933,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
// arg 3: pointer to local buffer. // arg 3: pointer to local buffer.
// arg 4: The offset of the element to store in the local buffer. // arg 4: The offset of the element to store in the local buffer.
// arg 5: pointer to the shared memory buffer to load. // arg 5: pointer to the shared memory buffer to load.
// arg 6: The offset of the start element of the row to load in shared memory. // arg 6: The offset of the start element of the row to load in shared
// memory.
ICHECK_EQ(op->args.size(), 7U); ICHECK_EQ(op->args.size(), 7U);
bool trans = Downcast<Bool>(op->args[0])->value; bool trans = Downcast<Bool>(op->args[0])->value;
int num = Downcast<Integer>(op->args[1])->value; int num = Downcast<Integer>(op->args[1])->value;
...@@ -891,20 +943,23 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -891,20 +943,23 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string local_elem_offset = this->PrintExpr(op->args[4]); std::string local_elem_offset = this->PrintExpr(op->args[4]);
std::string smem_ptr = this->PrintExpr(op->args[5]); std::string smem_ptr = this->PrintExpr(op->args[5]);
if (trans && op->dtype.bits() == 8) { if (trans && op->dtype.bits() == 8) {
// Since ldmatrix assumes that a matrix element is 16 bit, it cannot properly transpose an // Since ldmatrix assumes that a matrix element is 16 bit, it cannot
// int8 matrix. // properly transpose an int8 matrix.
std::string smem_stride = this->PrintExpr(op->args[6]); std::string smem_stride = this->PrintExpr(op->args[6]);
ICHECK(num == 4); ICHECK(num == 4);
os << "for (int i = 0; i < 16; ++i) {\n"; os << "for (int i = 0; i < 16; ++i) {\n";
os << local_ptr << "[" + local_elem_offset + " + i] = " << smem_ptr os << local_ptr << "[" + local_elem_offset + " + i] = " << smem_ptr
<< "[(i % 8) / 4 * " + smem_stride + " * 16 + (threadIdx.x % 4) * 4 * " + smem_stride + << "[(i % 8) / 4 * " + smem_stride +
"+ (i % 4) * " + smem_stride + " + threadIdx.x / 4 + (i / 8) * 8];\n"; " * 16 + (threadIdx.x % 4) * 4 * " + smem_stride +
"+ (i % 4) * " + smem_stride +
" + threadIdx.x / 4 + (i / 8) * 8];\n";
os << "}\n"; os << "}\n";
} else { } else {
std::string smem_elem_offset = this->PrintExpr(op->args[6]); std::string smem_elem_offset = this->PrintExpr(op->args[6]);
need_cast_smem_ptr_to_int_ = true; need_cast_smem_ptr_to_int_ = true;
this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset, this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr,
smem_ptr, smem_elem_offset); local_elem_offset, smem_ptr,
smem_elem_offset);
} }
} else if (op->op.same_as(builtin::mma_store())) { } else if (op->op.same_as(builtin::mma_store())) {
int m = Downcast<Integer>(op->args[0])->value; int m = Downcast<Integer>(op->args[0])->value;
...@@ -914,29 +969,31 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -914,29 +969,31 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string src_offset = this->PrintExpr(op->args[4]); std::string src_offset = this->PrintExpr(op->args[4]);
PrimExpr stride = op->args[5]; PrimExpr stride = op->args[5];
ICHECK(m == 16 && n == 16) << "Only m == 16 && n == 16 case supported for now"; ICHECK(m == 16 && n == 16)
<< "Only m == 16 && n == 16 case supported for now";
// Each thread in a warp holds a certain number of elements of an MMA output. // Each thread in a warp holds a certain number of elements of an MMA
// For example, if we compute a 16x16 tile using MMA, each thread holds 8 elements // output. For example, if we compute a 16x16 tile using MMA, each thread
// in its registers. So conceptually, a warp memory is organized as a 32x8 block. // holds 8 elements in its registers. So conceptually, a warp memory is
// A map from a 16x16 tile to a 32x8 block of memory is specified by the index map below. // organized as a 32x8 block. A map from a 16x16 tile to a 32x8 block of
// memory is specified by the index map below.
// To store the 32x8 output back to a 16x16 tile in shared or global memory, we invert this map // To store the 32x8 output back to a 16x16 tile in shared or global memory,
// to determine the output location for each 8 element. // we invert this map to determine the output location for each 8 element.
const auto* index_map_func = const auto *index_map_func =
runtime::Registry::Get("tir.index_map.shared_16x16_to_mma_32x8_layout"); runtime::Registry::Get("tir.index_map.shared_16x16_to_mma_32x8_layout");
IndexMap index_map; IndexMap index_map;
if (!index_map_func) { if (!index_map_func) {
Var i, j; Var i, j;
// The index map is defined as follows: // The index map is defined as follows:
index_map = IndexMap({i, j}, { index_map = IndexMap(
4 * FloorMod(i, 8) + FloorDiv(FloorMod(j, 8), 2), 4 * FloorDiv(j, 8) + FloorDiv(i, 8) * 2 + FloorMod(j, 2) {i, j}, {4 * FloorMod(i, 8) + FloorDiv(FloorMod(j, 8), 2),
}); 4 * FloorDiv(j, 8) + FloorDiv(i, 8) * 2 + FloorMod(j, 2)});
} else{ } else {
index_map = IndexMap::FromFunc(2, *index_map_func); index_map = IndexMap::FromFunc(2, *index_map_func);
} }
arith::Analyzer analyzer; arith::Analyzer analyzer;
...@@ -944,20 +1001,21 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -944,20 +1001,21 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
index_map.Inverse({Range(0, m), Range(0, n)}, &analyzer); index_map.Inverse({Range(0, m), Range(0, n)}, &analyzer);
auto indices_16x16 = inverse_index_map->final_indices; auto indices_16x16 = inverse_index_map->final_indices;
// "//" and "%" in the index map are translated to FloorDiv/Mod, but the plain Div/Mod are fine. // "//" and "%" in the index map are translated to FloorDiv/Mod, but the
// FloorDiv/Mod are supposed to be lowered before they reach codegen, so manually replace them // plain Div/Mod are fine. FloorDiv/Mod are supposed to be lowered before
// to the plain ones here. // they reach codegen, so manually replace them to the plain ones here.
class LowerFloorDivMod : public ExprMutator { class LowerFloorDivMod : public ExprMutator {
public: public:
PrimExpr VisitExpr_(const FloorDivNode* op) { PrimExpr VisitExpr_(const FloorDivNode *op) {
return tir::Div(this->VisitExpr(op->a), this->VisitExpr(op->b)); return tir::Div(this->VisitExpr(op->a), this->VisitExpr(op->b));
} }
PrimExpr VisitExpr_(const FloorModNode* op) { PrimExpr VisitExpr_(const FloorModNode *op) {
return tir::Mod(this->VisitExpr(op->a), this->VisitExpr(op->b)); return tir::Mod(this->VisitExpr(op->a), this->VisitExpr(op->b));
} }
}; };
auto dst_ind = LowerFloorDivMod()(indices_16x16[0] * stride + indices_16x16[1]); auto dst_ind =
LowerFloorDivMod()(indices_16x16[0] * stride + indices_16x16[1]);
var_idmap_[inverse_index_map->initial_indices[0].get()] = "threadIdx.x"; var_idmap_[inverse_index_map->initial_indices[0].get()] = "threadIdx.x";
var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id"; var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id";
...@@ -967,8 +1025,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -967,8 +1025,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
<< " = " << " = "
<< "*((uint *)&" << src << "[" << src_offset << " + local_id]);\n"; << "*((uint *)&" << src << "[" << src_offset << " + local_id]);\n";
os << "}\n"; os << "}\n";
} } else {
else {
os << "for (int local_id = 0; local_id < 8; ++local_id) {\n"; os << "for (int local_id = 0; local_id < 8; ++local_id) {\n";
os << dst << "[" + this->PrintExpr(dst_ind) + "]" os << dst << "[" + this->PrintExpr(dst_ind) + "]"
<< " = " << src << "[" << src_offset << " + local_id];\n"; << " = " << src << "[" << src_offset << " + local_id];\n";
...@@ -990,12 +1047,14 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -990,12 +1047,14 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string src_offset = this->PrintExpr(op->args[3]); std::string src_offset = this->PrintExpr(op->args[3]);
std::string size = this->PrintExpr(op->args[4]); std::string size = this->PrintExpr(op->args[4]);
need_cast_smem_ptr_to_int_ = true; need_cast_smem_ptr_to_int_ = true;
// use size of argument list to indicate whether or not to use predicated cp.async // use size of argument list to indicate whether or not to use predicated
// cp.async
if (op->args.size() == 5) { if (op->args.size() == 5) {
this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size); this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset,
size);
} else { } else {
this->stream << PrintPredicatedCpAsyncAssembly(dst, dst_offset, src, src_offset, size, this->stream << PrintPredicatedCpAsyncAssembly(
this->PrintExpr(op->args[5])); dst, dst_offset, src, src_offset, size, this->PrintExpr(op->args[5]));
} }
} else if (op->op.same_as(builtin::ptx_cp_async_bulk())) { } else if (op->op.same_as(builtin::ptx_cp_async_bulk())) {
need_cast_smem_ptr_to_int_ = true; need_cast_smem_ptr_to_int_ = true;
...@@ -1006,44 +1065,52 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -1006,44 +1065,52 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string size = this->PrintExpr(op->args[4]); std::string size = this->PrintExpr(op->args[4]);
int barrier_id = Downcast<IntImm>(op->args[5])->value; int barrier_id = Downcast<IntImm>(op->args[5])->value;
CHECK(barrier_id < barrier_count_); CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; std::string barrier =
this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, size, barrier); barrier_name_ + "[" + std::to_string(barrier_id) + "]";
this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, size,
barrier);
} else if (op->op.same_as(builtin::ptx_commit_group())) { } else if (op->op.same_as(builtin::ptx_commit_group())) {
this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n"; this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n";
} else if (op->op.same_as(builtin::ptx_wait_group())) { } else if (op->op.same_as(builtin::ptx_wait_group())) {
int n = Downcast<IntImm>(op->args[0])->value; int n = Downcast<IntImm>(op->args[0])->value;
this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n << ";\");\n\n"; this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n
<< ";\");\n\n";
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) { } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
need_cast_smem_ptr_to_int_ = true; need_cast_smem_ptr_to_int_ = true;
int barrier_id = Downcast<IntImm>(op->args[0])->value; int barrier_id = Downcast<IntImm>(op->args[0])->value;
CHECK(barrier_id < barrier_count_); CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; std::string barrier =
barrier_name_ + "[" + std::to_string(barrier_id) + "]";
this->stream << PrintCpAsyncBarrierAsm(barrier); this->stream << PrintCpAsyncBarrierAsm(barrier);
} else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) { } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
need_cast_smem_ptr_to_int_ = true; need_cast_smem_ptr_to_int_ = true;
int barrier_id = Downcast<IntImm>(op->args[0])->value; int barrier_id = Downcast<IntImm>(op->args[0])->value;
CHECK(barrier_id < barrier_count_); CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; std::string barrier =
barrier_name_ + "[" + std::to_string(barrier_id) + "]";
std::string thread_count = this->PrintExpr(op->args[1]); std::string thread_count = this->PrintExpr(op->args[1]);
this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count); this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count);
} else if (op->op.same_as(builtin::ptx_arrive_barrier())) { } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
need_cast_smem_ptr_to_int_ = true; need_cast_smem_ptr_to_int_ = true;
int barrier_id = Downcast<IntImm>(op->args[0])->value; int barrier_id = Downcast<IntImm>(op->args[0])->value;
CHECK(barrier_id < barrier_count_); CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; std::string barrier =
barrier_name_ + "[" + std::to_string(barrier_id) + "]";
this->stream << PrintArriveBarrierAsm(barrier); this->stream << PrintArriveBarrierAsm(barrier);
} else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) { } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
need_cast_smem_ptr_to_int_ = true; need_cast_smem_ptr_to_int_ = true;
int barrier_id = Downcast<IntImm>(op->args[0])->value; int barrier_id = Downcast<IntImm>(op->args[0])->value;
CHECK(barrier_id < barrier_count_); CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; std::string barrier =
barrier_name_ + "[" + std::to_string(barrier_id) + "]";
std::string byte_count = this->PrintExpr(op->args[1]); std::string byte_count = this->PrintExpr(op->args[1]);
this->stream << PrintArriveBarrierExpectTxAsm(barrier, byte_count); this->stream << PrintArriveBarrierExpectTxAsm(barrier, byte_count);
} else if (op->op.same_as(builtin::ptx_wait_barrier())) { } else if (op->op.same_as(builtin::ptx_wait_barrier())) {
need_cast_smem_ptr_to_int_ = true; need_cast_smem_ptr_to_int_ = true;
int barrier_id = Downcast<IntImm>(op->args[0])->value; int barrier_id = Downcast<IntImm>(op->args[0])->value;
CHECK(barrier_id < barrier_count_); CHECK(barrier_id < barrier_count_);
std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; std::string barrier =
barrier_name_ + "[" + std::to_string(barrier_id) + "]";
this->stream << PrintWaitBarrierAsm(barrier); this->stream << PrintWaitBarrierAsm(barrier);
} else if (op->op.same_as(builtin::create_barriers())) { } else if (op->op.same_as(builtin::create_barriers())) {
CHECK_EQ(barrier_count_, -1); CHECK_EQ(barrier_count_, -1);
...@@ -1052,13 +1119,15 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -1052,13 +1119,15 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
CHECK_EQ(barrier_alignment_bytes_ % sizeof(uint64_t), 0); CHECK_EQ(barrier_alignment_bytes_ % sizeof(uint64_t), 0);
int barrier_alignment_count = barrier_alignment_bytes_ / sizeof(uint64_t); int barrier_alignment_count = barrier_alignment_bytes_ / sizeof(uint64_t);
if (barrier_count % barrier_alignment_count != 0) { if (barrier_count % barrier_alignment_count != 0) {
barrier_count = ((barrier_count / barrier_alignment_count) + 1) * barrier_alignment_count; barrier_count = ((barrier_count / barrier_alignment_count) + 1) *
barrier_alignment_count;
} }
barrier_count_ = barrier_count; barrier_count_ = barrier_count;
this->stream << "__shared__ __align__(" << barrier_alignment_bytes_ << ") uint64_t " this->stream << "__shared__ __align__(" << barrier_alignment_bytes_
<< barrier_name_ << "[" << barrier_count << "];\n"; << ") uint64_t " << barrier_name_ << "[" << barrier_count
this->stream << "for (int i = 0; i < " << barrier_count << "; ++i) { " << barrier_name_ << "];\n";
<< "[i] = 0; }\n"; this->stream << "for (int i = 0; i < " << barrier_count << "; ++i) { "
<< barrier_name_ << "[i] = 0; }\n";
} else if (op->op.same_as(builtin::ptx_ldg32())) { } else if (op->op.same_as(builtin::ptx_ldg32())) {
/* /*
asm volatile ( asm volatile (
...@@ -1075,7 +1144,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -1075,7 +1144,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string reg = this->PrintExpr(op->args[0]); std::string reg = this->PrintExpr(op->args[0]);
// get guard // get guard
std::string guard = this->PrintExpr(op->args[1]); std::string guard = this->PrintExpr(op->args[1]);
const BufferLoadNode* addr_buffer = op->args[2].as<BufferLoadNode>(); const BufferLoadNode *addr_buffer = op->args[2].as<BufferLoadNode>();
std::string global_addr = this->PrintExpr(addr_buffer->indices[0]); std::string global_addr = this->PrintExpr(addr_buffer->indices[0]);
std::string global_buffer = this->PrintExpr(addr_buffer->buffer->data); std::string global_buffer = this->PrintExpr(addr_buffer->buffer->data);
std::string local_addr = this->PrintExpr(op->args[3]); std::string local_addr = this->PrintExpr(op->args[3]);
...@@ -1087,26 +1156,27 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -1087,26 +1156,27 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
// stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ; // stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ;
stream << ": \"=f\"(" << reg << "[" << local_addr << "]" stream << ": \"=f\"(" << reg << "[" << local_addr << "]"
<< ")\n"; << ")\n";
stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr << ")), \"r\"((int)" stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr
<< guard << ")\n"; << ")), \"r\"((int)" << guard << ")\n";
stream << ");\n"; stream << ");\n";
} else { } else {
CodeGenC::VisitExpr_(op, os); CodeGenC::VisitExpr_(op, os);
} }
} }
void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode* op) { void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode *op) {
if (op->attr_key == tir::attr::fragment_shape) { if (op->attr_key == tir::attr::fragment_shape) {
const VarNode* buffer = op->node.as<VarNode>(); const VarNode *buffer = op->node.as<VarNode>();
const StringImmNode* shape_str = op->value.as<StringImmNode>(); const StringImmNode *shape_str = op->value.as<StringImmNode>();
fragment_shapes[buffer] = shape_str->value; fragment_shapes[buffer] = shape_str->value;
} else if (op->attr_key == tir::attr::fragment_layout) { } else if (op->attr_key == tir::attr::fragment_layout) {
const VarNode* buffer = op->node.as<VarNode>(); const VarNode *buffer = op->node.as<VarNode>();
const StringImmNode* layout_str = op->value.as<StringImmNode>(); const StringImmNode *layout_str = op->value.as<StringImmNode>();
fragment_layouts[buffer] = layout_str->value; fragment_layouts[buffer] = layout_str->value;
} else if (op->attr_key == tir::attr::async_commit_queue_scope) { } else if (op->attr_key == tir::attr::async_commit_queue_scope) {
const IntImmNode* queue_id = op->value.as<IntImmNode>(); const IntImmNode *queue_id = op->value.as<IntImmNode>();
ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0."; ICHECK(queue_id && queue_id->value == 0)
<< "For CUDA, the index of an async queue must be 0.";
this->VisitStmt(op->body); this->VisitStmt(op->body);
auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {}); auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {});
this->VisitExpr(commit_group, this->stream); this->VisitExpr(commit_group, this->stream);
...@@ -1114,9 +1184,11 @@ void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode* op) { ...@@ -1114,9 +1184,11 @@ void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode* op) {
} else if (op->attr_key == tir::attr::async_wait_queue_scope) { } else if (op->attr_key == tir::attr::async_wait_queue_scope) {
auto wait_attrs = GetAsyncWaitAttributes(op); auto wait_attrs = GetAsyncWaitAttributes(op);
auto queue_id = wait_attrs.first.as<IntImmNode>(); auto queue_id = wait_attrs.first.as<IntImmNode>();
ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0."; ICHECK(queue_id && queue_id->value == 0)
<< "For CUDA, the index of an async queue must be 0.";
auto wait_cnt = wait_attrs.second; auto wait_cnt = wait_attrs.second;
auto wait_group = Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt}); auto wait_group =
Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
this->VisitExpr(wait_group, this->stream); this->VisitExpr(wait_group, this->stream);
auto inner = op->body.as<AttrStmtNode>(); auto inner = op->body.as<AttrStmtNode>();
ICHECK(inner); ICHECK(inner);
...@@ -1124,7 +1196,7 @@ void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode* op) { ...@@ -1124,7 +1196,7 @@ void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode* op) {
return; return;
} else if (op->attr_key == "threadblock_swizzle_pattern") { } else if (op->attr_key == "threadblock_swizzle_pattern") {
this->PrintIndent(); this->PrintIndent();
const StringImmNode* pattern = op->value.as<StringImmNode>(); const StringImmNode *pattern = op->value.as<StringImmNode>();
ICHECK(pattern); ICHECK(pattern);
this->stream << "const dim3 blockIdx = " << pattern->value << "();\n"; this->stream << "const dim3 blockIdx = " << pattern->value << "();\n";
this->VisitStmt(op->body); this->VisitStmt(op->body);
...@@ -1133,28 +1205,28 @@ void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode* op) { ...@@ -1133,28 +1205,28 @@ void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode* op) {
CodeGenC::VisitStmt_(op); CodeGenC::VisitStmt_(op);
} }
void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode* op) { void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
ICHECK(!is_zero(op->condition)); ICHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get()); std::string vid = AllocVarID(op->buffer_var.get());
this->PrintIndent(); this->PrintIndent();
std::string scope = GetPtrStorageScope(op->buffer_var); std::string scope = GetPtrStorageScope(op->buffer_var);
const VarNode* buffer = op->buffer_var.as<VarNode>(); const VarNode *buffer = op->buffer_var.as<VarNode>();
if (scope.find("wmma.") == 0) { if (scope.find("wmma.") == 0) {
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) || ICHECK(op->dtype == DataType::Float(16) ||
op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) || op->dtype == DataType::Int(8) || op->dtype == DataType::UInt(8) ||
op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1) || op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
op->dtype == DataType::BFloat(16)) op->dtype == DataType::Int(1) || op->dtype == DataType::BFloat(16))
<< "Matrix_a and matrix_b only support half or char or unsigned char " << "Matrix_a and matrix_b only support half or char or unsigned char "
<< "or uint4 or int4 or int1 type for now"; << "or uint4 or int4 or int1 type for now";
} else { } else {
ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) || ICHECK(op->dtype == DataType::Float(16) ||
op->dtype == DataType::Int(32)) op->dtype == DataType::Float(32) || op->dtype == DataType::Int(32))
<< "Accumulator only support half, float and int type for now"; << "Accumulator only support half, float and int type for now";
} }
PrintWmmaScope(scope, op->dtype, buffer, stream); PrintWmmaScope(scope, op->dtype, buffer, stream);
} else{ } else {
PrintStorageScope(scope, stream); PrintStorageScope(scope, stream);
PrintType(op->dtype, stream); PrintType(op->dtype, stream);
} }
...@@ -1163,7 +1235,8 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode* op) { ...@@ -1163,7 +1235,8 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode* op) {
stream << ' ' << vid << "[];\n"; stream << ' ' << vid << "[];\n";
} else { } else {
size_t constant_size = op->ConstantAllocationSize(); size_t constant_size = op->ConstantAllocationSize();
ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; ICHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now";
if (scope.find("wmma.") == 0) { if (scope.find("wmma.") == 0) {
constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
} }
...@@ -1179,7 +1252,7 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode* op) { ...@@ -1179,7 +1252,7 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode* op) {
this->PrintStmt(op->body); this->PrintStmt(op->body);
} }
void CodeGenTileLangCUDA::VisitExpr_(const RampNode* op, std::ostream& os) { void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) {
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value); int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
CHECK_LE(lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed."; CHECK_LE(lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed.";
os << "(make_"; os << "(make_";
...@@ -1188,16 +1261,19 @@ void CodeGenTileLangCUDA::VisitExpr_(const RampNode* op, std::ostream& os) { ...@@ -1188,16 +1261,19 @@ void CodeGenTileLangCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
for (int i = 0; i < lanes; i++) { for (int i = 0; i < lanes; i++) {
os << "(" << PrintExpr(op->base) << ")" os << "(" << PrintExpr(op->base) << ")"
<< "+(" << PrintExpr(op->stride) << "*" << i << ")"; << "+(" << PrintExpr(op->stride) << "*" << i << ")";
if (i != lanes - 1) os << ", "; if (i != lanes - 1)
os << ", ";
} }
os << "))"; os << "))";
} }
void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
std::ostream &os) { // NOLINT(*)
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value); int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && lanes == 4) { if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 &&
lanes == 4) {
// make_int8x4 // make_int8x4
const int64_t* p = as_const_int(op->value); const int64_t *p = as_const_int(op->value);
ICHECK(p); ICHECK(p);
int64_t v = *p & 0xFF; int64_t v = *p & 0xFF;
v = (v << 24) | (v << 16) | (v << 8) | v; v = (v << 24) | (v << 16) | (v << 8) | v;
...@@ -1215,7 +1291,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) ...@@ -1215,7 +1291,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os)
PrintType(op->dtype, os); PrintType(op->dtype, os);
os << '('; os << '(';
for (int i = 0; i < lanes / 2; ++i) { for (int i = 0; i < lanes / 2; ++i) {
if (i != 0) os << ", "; if (i != 0)
os << ", ";
os << "__pack_half2(" << v << ", " << v << ")"; os << "__pack_half2(" << v << ", " << v << ")";
} }
os << ')'; os << ')';
...@@ -1228,18 +1305,21 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) ...@@ -1228,18 +1305,21 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os)
PrintType(op->dtype, os); PrintType(op->dtype, os);
os << '('; os << '(';
for (int i = 0; i < lanes / 2; ++i) { for (int i = 0; i < lanes / 2; ++i) {
if (i != 0) os << ", "; if (i != 0)
os << ", ";
os << "__pack_nv_bfloat162(" << v << ", " << v << ")"; os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
} }
os << ')'; os << ')';
return; return;
} }
if (op->dtype.is_float() && op->dtype.bits() == 32 && op->dtype.lanes() == 8) { if (op->dtype.is_float() && op->dtype.bits() == 32 &&
op->dtype.lanes() == 8) {
std::string v = PrintExpr(op->value); std::string v = PrintExpr(op->value);
os << "make_ulonglong4("; os << "make_ulonglong4(";
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
if (i != 0) os << ", "; if (i != 0)
os << ", ";
os << "*(unsigned long long*)&make_float2(" << v << ", " << v << ")"; os << "*(unsigned long long*)&make_float2(" << v << ", " << v << ")";
} }
os << ')'; os << ')';
...@@ -1248,7 +1328,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) ...@@ -1248,7 +1328,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os)
if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) { if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) {
bool fail = false; bool fail = false;
const int64_t* p = as_const_int(op->value); const int64_t *p = as_const_int(op->value);
ICHECK(p); ICHECK(p);
int64_t v = *p & 0xF; int64_t v = *p & 0xF;
...@@ -1260,7 +1340,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) ...@@ -1260,7 +1340,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os)
os << "(int16_t)" << v; os << "(int16_t)" << v;
} }
} else { } else {
v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v; v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) |
(v << 4) | v;
if (lanes == 8) { if (lanes == 8) {
if (op->dtype.is_uint()) { if (op->dtype.is_uint()) {
os << "(uint)" << v; os << "(uint)" << v;
...@@ -1272,7 +1353,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) ...@@ -1272,7 +1353,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os)
PrintType(op->dtype, os); PrintType(op->dtype, os);
os << '('; os << '(';
for (int i = 0; i < lanes / 8; ++i) { for (int i = 0; i < lanes / 8; ++i) {
if (i != 0) os << ", "; if (i != 0)
os << ", ";
if (op->dtype.is_uint()) { if (op->dtype.is_uint()) {
os << "(uint)" << v; os << "(uint)" << v;
} else { } else {
...@@ -1295,13 +1377,15 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) ...@@ -1295,13 +1377,15 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os)
PrintType(op->dtype, os); PrintType(op->dtype, os);
os << '('; os << '(';
for (int i = 0; i < lanes; ++i) { for (int i = 0; i < lanes; ++i) {
if (i != 0) os << ", "; if (i != 0)
os << ", ";
os << v; os << v;
} }
os << ')'; os << ')';
} }
inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLangCUDA* p) { // NOLINT(*) inline void PrintConst(const FloatImmNode *op, std::ostream &os,
CodeGenTileLangCUDA *p) { // NOLINT(*)
// Type code is kBFloat // Type code is kBFloat
if (op->dtype.is_bfloat16()) { if (op->dtype.is_bfloat16()) {
os << "bfloat16_t"; os << "bfloat16_t";
...@@ -1310,46 +1394,49 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLang ...@@ -1310,46 +1394,49 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLang
} }
// Type code is kFloat // Type code is kFloat
switch (op->dtype.bits()) { switch (op->dtype.bits()) {
case 64: case 64:
case 32: { case 32: {
std::ostringstream temp; std::ostringstream temp;
if (std::isinf(op->value)) { if (std::isinf(op->value)) {
if (op->value < 0) { if (op->value < 0) {
temp << "-"; temp << "-";
}
temp << ((op->dtype.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF");
} else if (std::isnan(op->value)) {
temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN");
} else {
temp << std::scientific << op->value;
if (op->dtype.bits() == 32) temp << 'f';
} }
p->MarkConst(temp.str()); temp << ((op->dtype.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF");
os << temp.str(); } else if (std::isnan(op->value)) {
break; temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN");
} } else {
case 16: { temp << std::scientific << op->value;
os << "half_t" << '('; if (op->dtype.bits() == 32)
FloatImm const_f32 = FloatImm(DataType::Float(32), op->value); temp << 'f';
PrintConst(const_f32.get(), os, p);
os << ')';
break;
} }
default: p->MarkConst(temp.str());
LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; os << temp.str();
break;
}
case 16: {
os << "half_t" << '(';
FloatImm const_f32 = FloatImm(DataType::Float(32), op->value);
PrintConst(const_f32.get(), os, p);
os << ')';
break;
}
default:
LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
} }
} }
void CodeGenTileLangCUDA::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) void CodeGenTileLangCUDA::VisitExpr_(const FloatImmNode *op,
std::ostream &os) { // NOLINT(*)
PrintConst(op, os, this); PrintConst(op, os, this);
} }
void CodeGenTileLangCUDA::PrintWmmaScope(const std::string& scope, DataType t, void CodeGenTileLangCUDA::PrintWmmaScope(const std::string &scope, DataType t,
const VarNode* variable, std::ostream& os) { const VarNode *variable,
std::ostream &os) {
std::stringstream type; std::stringstream type;
PrintType(t, type); PrintType(t, type);
ICHECK(fragment_shapes.count(variable)) << "Cannot find shape of the wmma fragment " ICHECK(fragment_shapes.count(variable))
<< variable->name_hint; << "Cannot find shape of the wmma fragment " << variable->name_hint;
std::string shape_str = fragment_shapes.at(variable); std::string shape_str = fragment_shapes.at(variable);
if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) { if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) {
type.str(std::string()); type.str(std::string());
...@@ -1372,23 +1459,24 @@ void CodeGenTileLangCUDA::PrintWmmaScope(const std::string& scope, DataType t, ...@@ -1372,23 +1459,24 @@ void CodeGenTileLangCUDA::PrintWmmaScope(const std::string& scope, DataType t,
if (scope == "wmma.matrix_a") { if (scope == "wmma.matrix_a") {
std::string layout_str = fragment_layouts[variable]; std::string layout_str = fragment_layouts[variable];
ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_a"; ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_a";
os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, " << shape_str << ", " << type.str() os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, " << shape_str << ", "
<< ", nvcuda::wmma::" << layout_str << ">"; << type.str() << ", nvcuda::wmma::" << layout_str << ">";
} else if (scope == "wmma.matrix_b") { } else if (scope == "wmma.matrix_b") {
std::string layout_str = fragment_layouts[variable]; std::string layout_str = fragment_layouts[variable];
ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_b"; ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_b";
os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, " << shape_str << ", " << type.str() os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, " << shape_str << ", "
<< ", nvcuda::wmma::" << layout_str << ">"; << type.str() << ", nvcuda::wmma::" << layout_str << ">";
} else if (scope == "wmma.accumulator") { } else if (scope == "wmma.accumulator") {
os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, " << shape_str << ", " << type.str() os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, " << shape_str
<< ">"; << ", " << type.str() << ">";
} }
} }
int32_t CodeGenTileLangCUDA::GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, int32_t CodeGenTileLangCUDA::GetWmmaFragmentSize(const std::string &scope,
const VarNode *variable,
int32_t size) { int32_t size) {
ICHECK(fragment_shapes.count(variable)) << "Cannot find shape of the wmma fragment " ICHECK(fragment_shapes.count(variable))
<< variable->name_hint; << "Cannot find shape of the wmma fragment " << variable->name_hint;
std::string shape_str = fragment_shapes.at(variable); std::string shape_str = fragment_shapes.at(variable);
std::pair<int32_t, int32_t> dim = GetWmmaFragmentDimSize(shape_str, scope); std::pair<int32_t, int32_t> dim = GetWmmaFragmentDimSize(shape_str, scope);
if (dim.first * dim.second != 0) if (dim.first * dim.second != 0)
...@@ -1397,12 +1485,14 @@ int32_t CodeGenTileLangCUDA::GetWmmaFragmentSize(const std::string& scope, const ...@@ -1397,12 +1485,14 @@ int32_t CodeGenTileLangCUDA::GetWmmaFragmentSize(const std::string& scope, const
return 0; return 0;
} }
void CodeGenTileLangCUDA::HandleVolatileLoads(const std::string& value, const BufferLoadNode* op, void CodeGenTileLangCUDA::HandleVolatileLoads(const std::string &value,
std::ostream& os) { const BufferLoadNode *op,
std::ostream &os) {
// Cast away volatile qualifier for fp16 types. That is, only loads and // Cast away volatile qualifier for fp16 types. That is, only loads and
// stores are volatile. The loaded objects are not marked as volatile. // stores are volatile. The loaded objects are not marked as volatile.
// //
if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer->data.get())) { if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) &&
IsVolatile(op->buffer->data.get())) {
os << "("; os << "(";
PrintType(op->dtype, os); PrintType(op->dtype, os);
os << ")(" << value << ")"; os << ")(" << value << ")";
...@@ -1411,15 +1501,17 @@ void CodeGenTileLangCUDA::HandleVolatileLoads(const std::string& value, const Bu ...@@ -1411,15 +1501,17 @@ void CodeGenTileLangCUDA::HandleVolatileLoads(const std::string& value, const Bu
} }
} }
void CodeGenTileLangCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, void CodeGenTileLangCUDA::PrintVecElemLoadExpr(DataType t, int i,
std::ostream& os) { const std::string &value,
std::ostream &os) {
ICHECK_GT(t.lanes(), 1); ICHECK_GT(t.lanes(), 1);
if (t.bits() == 8 && (t.is_int() || t.is_uint())) { if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (!(t.lanes() == 2 || t.lanes() == 3)) { if (!(t.lanes() == 2 || t.lanes() == 3)) {
if (i != 0) { if (i != 0) {
os << "|"; os << "|";
} }
os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))"; os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8
<< "))";
return; return;
} }
} }
...@@ -1476,7 +1568,7 @@ void CodeGenTileLangCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::str ...@@ -1476,7 +1568,7 @@ void CodeGenTileLangCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::str
return; return;
} }
void CodeGenTileLangCUDA::AddFunction(const PrimFunc& f) { void CodeGenTileLangCUDA::AddFunction(const PrimFunc &f) {
// clear previous generated state. // clear previous generated state.
this->InitFuncState(f); this->InitFuncState(f);
// reserve keywords // reserve keywords
...@@ -1495,10 +1587,11 @@ void CodeGenTileLangCUDA::AddFunction(const PrimFunc& f) { ...@@ -1495,10 +1587,11 @@ void CodeGenTileLangCUDA::AddFunction(const PrimFunc& f) {
for (size_t i = 0; i < f->params.size(); ++i) { for (size_t i = 0; i < f->params.size(); ++i) {
tir::Var v = f->params[i]; tir::Var v = f->params[i];
std::string vid = AllocVarID(v.get()); std::string vid = AllocVarID(v.get());
if (i != 0) stream << ", "; if (i != 0)
stream << ", ";
if (v.dtype().is_handle()) { if (v.dtype().is_handle()) {
// work around for grid constant parameters. // work around for grid constant parameters.
if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) { if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
if (ptr->storage_scope == "grid_constant") { if (ptr->storage_scope == "grid_constant") {
stream << "__grid_constant__ const "; stream << "__grid_constant__ const ";
CodeGenC::PrintType(ptr->element_type, stream); CodeGenC::PrintType(ptr->element_type, stream);
...@@ -1513,8 +1606,8 @@ void CodeGenTileLangCUDA::AddFunction(const PrimFunc& f) { ...@@ -1513,8 +1606,8 @@ void CodeGenTileLangCUDA::AddFunction(const PrimFunc& f) {
} }
CodeGenC::PrintType(GetType(v), stream); CodeGenC::PrintType(GetType(v), stream);
if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) { if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
if (auto* prim = ptr->element_type.as<PrimTypeNode>()) { if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
RegisterHandleType(v.get(), prim->dtype); RegisterHandleType(v.get(), prim->dtype);
} }
} }
...@@ -1536,5 +1629,5 @@ void CodeGenTileLangCUDA::AddFunction(const PrimFunc& f) { ...@@ -1536,5 +1629,5 @@ void CodeGenTileLangCUDA::AddFunction(const PrimFunc& f) {
this->stream << "}\n\n"; this->stream << "}\n\n";
} }
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -21,50 +21,58 @@ namespace tvm { ...@@ -21,50 +21,58 @@ namespace tvm {
namespace codegen { namespace codegen {
class CodeGenTileLangCUDA final : public CodeGenC { class CodeGenTileLangCUDA final : public CodeGenC {
public: public:
CodeGenTileLangCUDA(); CodeGenTileLangCUDA();
std::string Finish(); std::string Finish();
// override behavior // override behavior
void PrintFuncPrefix(std::ostream& os) final; void PrintFuncPrefix(std::ostream &os) final;
void PrintExtraAttrs(const PrimFunc& f, std::ostream& os) final; void PrintExtraAttrs(const PrimFunc &f, std::ostream &os) final;
void VisitStmt_(const ForNode* op) final; void VisitStmt_(const ForNode *op) final;
void PrintStorageSync(const CallNode* op) final; void PrintStorageSync(const CallNode *op) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageScope(const std::string &scope,
void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, std::ostream &os) final; // NOLINT(*)
std::ostream& os) final; // NOLINT(*) void PrintVecBinaryOp(const std::string &op, DataType t, PrimExpr lhs,
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) PrimExpr rhs,
void PrintVecElemLoad(const std::string& vec, DataType t, int i, std::ostream &os) final; // NOLINT(*)
std::ostream& os) final; // NOLINT(*) void PrintType(DataType t, std::ostream &os) final; // NOLINT(*)
void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final; void PrintVecElemLoad(const std::string &vec, DataType t, int i,
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) std::ostream &os) final; // NOLINT(*)
void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) final; void PrintVecElemStore(const std::string &vec, DataType t, int i,
std::string CastFromTo(std::string value, DataType from, DataType target) final; const std::string &value) final;
void BindThreadIndex(const IterVar &iv) final; // NOLINT(*)
void PrintVecElemLoadExpr(DataType t, int i, const std::string &value,
std::ostream &os) final;
std::string CastFromTo(std::string value, DataType from,
DataType target) final;
// overload visitor // overload visitor
void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const RampNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; void VisitExpr_(const FloatImmNode *op, std::ostream &os) final;
void VisitExpr_(const CallNode* op, std::ostream& os) final; void VisitExpr_(const CallNode *op, std::ostream &os) final;
void VisitExpr_(const CastNode* op, std::ostream& os) final; void VisitExpr_(const CastNode *op, std::ostream &os) final;
void VisitStmt_(const AllocateNode* op) final; void VisitStmt_(const AllocateNode *op) final;
void VisitStmt_(const AttrStmtNode* op) final; void VisitStmt_(const AttrStmtNode *op) final;
// Override this as a work around for __grid_constant__ parameter // Override this as a work around for __grid_constant__ parameter
void AddFunction(const PrimFunc& f); void AddFunction(const PrimFunc &f);
protected: protected:
virtual std::string GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) final; virtual std::string GetBufferRef(DataType t, const BufferNode *buffer,
void PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args, PrimExpr index) final;
bool skip_first_arg, std::ostream& os) final; // NOLINT(*) void PrintCallExtern(Type ret_type, String global_symbol,
const Array<PrimExpr> &args, bool skip_first_arg,
std::ostream &os) final; // NOLINT(*)
private: private:
// Handle volatile loads // Handle volatile loads
void HandleVolatileLoads(const std::string& value, const BufferLoadNode* op, void HandleVolatileLoads(const std::string &value, const BufferLoadNode *op,
std::ostream& os) final; std::ostream &os) final;
// Whether scope such as "__shared__" or "__constant__" is part of type. // Whether scope such as "__shared__" or "__constant__" is part of type.
bool IsScopePartOfType() const final { return false; } bool IsScopePartOfType() const final { return false; }
friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLangCUDA* p); friend void PrintConst(const FloatImmNode *op, std::ostream &os,
CodeGenTileLangCUDA *p);
// The size of the barrier array in shared memory // The size of the barrier array in shared memory
int barrier_count_ = -1; int barrier_count_ = -1;
// whether need mma.h // whether need mma.h
...@@ -77,15 +85,17 @@ class CodeGenTileLangCUDA final : public CodeGenC { ...@@ -77,15 +85,17 @@ class CodeGenTileLangCUDA final : public CodeGenC {
// Set to 16 to maintain minimum alignment requirements for async bulk copy // Set to 16 to maintain minimum alignment requirements for async bulk copy
const int barrier_alignment_bytes_ = 16; const int barrier_alignment_bytes_ = 16;
std::unordered_map<const VarNode*, std::string> fragment_shapes; std::unordered_map<const VarNode *, std::string> fragment_shapes;
std::unordered_map<const VarNode*, std::string> fragment_layouts; std::unordered_map<const VarNode *, std::string> fragment_layouts;
friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLangCUDA* p); friend void PrintConst(const FloatImmNode *op, std::ostream &os,
void PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable, CodeGenTileLangCUDA *p);
std::ostream& os); void PrintWmmaScope(const std::string &scope, DataType t,
int32_t GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, int32_t size); const VarNode *variable, std::ostream &os);
int32_t GetWmmaFragmentSize(const std::string &scope, const VarNode *variable,
int32_t size);
}; };
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
#endif // TVM_TL_TARGET_CODEGEN_CUDA_H_ #endif // TVM_TL_TARGET_CODEGEN_CUDA_H_
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
*/ */
#include "codegen_hip.h" #include "codegen_hip.h"
#include <tvm/tir/index_map.h>
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <cmath> #include <cmath>
...@@ -28,12 +28,13 @@ namespace codegen { ...@@ -28,12 +28,13 @@ namespace codegen {
* \note should use std::format instead when codebase is ported to C++20. * \note should use std::format instead when codebase is ported to C++20.
*/ */
class Replacer { class Replacer {
public: public:
void register_rule(const std::string& pattern, const std::string& replacement) { void register_rule(const std::string &pattern,
const std::string &replacement) {
_rules.emplace_back(pattern, replacement); _rules.emplace_back(pattern, replacement);
} }
std::string rewrite(std::string str) { std::string rewrite(std::string str) {
for (auto&& rule : _rules) { for (auto &&rule : _rules) {
auto [pattern, replacement] = rule; auto [pattern, replacement] = rule;
size_t len = pattern.size(); size_t len = pattern.size();
size_t new_len = replacement.size(); size_t new_len = replacement.size();
...@@ -47,46 +48,53 @@ class Replacer { ...@@ -47,46 +48,53 @@ class Replacer {
} }
void empty_rules() { _rules.clear(); } void empty_rules() { _rules.clear(); }
private: private:
std::vector<std::pair<std::string, std::string>> _rules; std::vector<std::pair<std::string, std::string>> _rules;
}; };
CodeGenTileLangHIP::CodeGenTileLangHIP() { restrict_keyword_ = "__restrict__"; } CodeGenTileLangHIP::CodeGenTileLangHIP() { restrict_keyword_ = "__restrict__"; }
void CodeGenTileLangHIP::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ "; } void CodeGenTileLangHIP::PrintFuncPrefix(std::ostream &os) {
os << "extern \"C\" __global__ ";
}
class LaunchConfigExtractor : public tir::StmtVisitor { class LaunchConfigExtractor : public tir::StmtVisitor {
private: private:
void VisitStmt_(const AttrStmtNode* op) final { void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) { if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node); IterVar iv = Downcast<IterVar>(op->node);
if (iv->var->name_hint == "threadIdx.x" || iv->thread_tag == "threadIdx.x") { if (iv->var->name_hint == "threadIdx.x" ||
iv->thread_tag == "threadIdx.x") {
threadIdx_x_ext = op->value; threadIdx_x_ext = op->value;
} else if (iv->var->name_hint == "threadIdx.y" || iv->thread_tag == "threadIdx.y") { } else if (iv->var->name_hint == "threadIdx.y" ||
iv->thread_tag == "threadIdx.y") {
threadIdx_y_ext = op->value; threadIdx_y_ext = op->value;
} else if (iv->var->name_hint == "threadIdx.z" || iv->thread_tag == "threadIdx.z") { } else if (iv->var->name_hint == "threadIdx.z" ||
iv->thread_tag == "threadIdx.z") {
threadIdx_z_ext = op->value; threadIdx_z_ext = op->value;
} }
} }
StmtVisitor::VisitStmt_(op); StmtVisitor::VisitStmt_(op);
} }
public: public:
PrimExpr threadIdx_x_ext = Integer(1); PrimExpr threadIdx_x_ext = Integer(1);
PrimExpr threadIdx_y_ext = Integer(1); PrimExpr threadIdx_y_ext = Integer(1);
PrimExpr threadIdx_z_ext = Integer(1); PrimExpr threadIdx_z_ext = Integer(1);
}; };
void CodeGenTileLangHIP::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) { void CodeGenTileLangHIP::PrintExtraAttrs(const PrimFunc &f, std::ostream &os) {
LaunchConfigExtractor extractor; LaunchConfigExtractor extractor;
extractor(f->body); extractor(f->body);
arith::Analyzer analyzer; arith::Analyzer analyzer;
PrimExpr threadIdx_ext = analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext * PrimExpr threadIdx_ext =
extractor.threadIdx_z_ext); analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext *
if (const IntImmNode* const threadIdx_ext_int = threadIdx_ext.as<IntImmNode>()) { extractor.threadIdx_z_ext);
if (const IntImmNode *const threadIdx_ext_int =
threadIdx_ext.as<IntImmNode>()) {
if (threadIdx_ext_int->value == 1) { if (threadIdx_ext_int->value == 1) {
// unable to extract the number of threads per block, hence directly return // unable to extract the number of threads per block, hence directly
// return
return; return;
} }
stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")"; stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
...@@ -108,19 +116,20 @@ std::string CodeGenTileLangHIP::Finish() { ...@@ -108,19 +116,20 @@ std::string CodeGenTileLangHIP::Finish() {
return CodeGenC::Finish(); return CodeGenC::Finish();
} }
void CodeGenTileLangHIP::VisitStmt_(const tir::ForNode* op) { void CodeGenTileLangHIP::VisitStmt_(const tir::ForNode *op) {
if (op->kind == tir::ForKind::kUnrolled) { if (op->kind == tir::ForKind::kUnrolled) {
PrintIndent(); PrintIndent();
stream << "#pragma unroll\n"; stream << "#pragma unroll\n";
} }
std::string extent = PrintExpr(arith::Analyzer().Simplify(op->extent + op->min)); std::string extent =
PrintExpr(arith::Analyzer().Simplify(op->extent + op->min));
PrintIndent(); PrintIndent();
std::string vid = AllocVarID(op->loop_var.get()); std::string vid = AllocVarID(op->loop_var.get());
std::string start = PrintExpr(op->min); std::string start = PrintExpr(op->min);
stream << "for ("; stream << "for (";
PrintType(op->loop_var.dtype(), stream); PrintType(op->loop_var.dtype(), stream);
stream << ' ' << vid << " = " << start << "; " << vid << " < " << extent << "; ++" << vid stream << ' ' << vid << " = " << start << "; " << vid << " < " << extent
<< ") {\n"; << "; ++" << vid << ") {\n";
int for_scope = BeginScope(); int for_scope = BeginScope();
PrintStmt(op->body); PrintStmt(op->body);
this->EndScope(for_scope); this->EndScope(for_scope);
...@@ -128,12 +137,13 @@ void CodeGenTileLangHIP::VisitStmt_(const tir::ForNode* op) { ...@@ -128,12 +137,13 @@ void CodeGenTileLangHIP::VisitStmt_(const tir::ForNode* op) {
stream << "}\n"; stream << "}\n";
} }
void CodeGenTileLangHIP::BindThreadIndex(const IterVar& iv) { void CodeGenTileLangHIP::BindThreadIndex(const IterVar &iv) {
ICHECK(!var_idmap_.count(iv->var.get())); ICHECK(!var_idmap_.count(iv->var.get()));
var_idmap_[iv->var.get()] = CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype()); var_idmap_[iv->var.get()] =
CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype());
} }
void CodeGenTileLangHIP::PrintType(DataType t, std::ostream& os) { // NOLINT(*) void CodeGenTileLangHIP::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
int lanes = t.lanes(); int lanes = t.lanes();
if (t.is_handle()) { if (t.is_handle()) {
ICHECK(t.is_scalar()) << "do not yet support vector types"; ICHECK(t.is_scalar()) << "do not yet support vector types";
...@@ -154,51 +164,54 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream& os) { // NOLINT(*) ...@@ -154,51 +164,54 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
bool fail = false; bool fail = false;
if (t.is_float()) { if (t.is_float()) {
switch (t.bits()) { switch (t.bits()) {
case 16: case 16:
if (t.is_scalar()) { if (t.is_scalar()) {
os << "half_t"; os << "half_t";
} else if (lanes <= 8) { } else if (lanes <= 8) {
// Emit CUDA code to access fp16 vector elements. // Emit CUDA code to access fp16 vector elements.
// //
// half4 is stored as uint2 // half4 is stored as uint2
// //
// h4.x is emitted as *(half2*)(&(u2.x)).x // h4.x is emitted as *(half2*)(&(u2.x)).x
// h4.y is emitted as *(half2*)(&(u2.x)).y // h4.y is emitted as *(half2*)(&(u2.x)).y
// h4.z is emitted as *(half2*)(&(u2.y)).x // h4.z is emitted as *(half2*)(&(u2.y)).x
// h4.w is emitted as *(half2*)(&(u2.y)).y // h4.w is emitted as *(half2*)(&(u2.y)).y
// //
ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
os << "uint" << lanes / 2; os << "uint" << lanes / 2;
} else { } else {
fail = true;
}
break;
case 32:
if (lanes <= 4) {
os << "float";
} else if (lanes <= 8) {
// Emit CUDA code to access fp32 vector elements for 4 < lanes <= 8.
//
// float8 is stored as ulonglong4
//
// f8.v1 is emitted as *(float2*)(&(ul4.x)).x
// f8.v2 is emitted as *(float2*)(&(ul4.x)).y
//
ICHECK_EQ(lanes % 2, 0) << "only support even lane for float type with lanes > 4";
os << "ulonglong" << lanes / 2;
} else {
fail = true;
}
break;
case 64:
os << "double";
break;
default:
fail = true; fail = true;
break; }
break;
case 32:
if (lanes <= 4) {
os << "float";
} else if (lanes <= 8) {
// Emit CUDA code to access fp32 vector elements for 4 < lanes <= 8.
//
// float8 is stored as ulonglong4
//
// f8.v1 is emitted as *(float2*)(&(ul4.x)).x
// f8.v2 is emitted as *(float2*)(&(ul4.x)).y
//
ICHECK_EQ(lanes % 2, 0)
<< "only support even lane for float type with lanes > 4";
os << "ulonglong" << lanes / 2;
} else {
fail = true;
}
break;
case 64:
os << "double";
break;
default:
fail = true;
break;
} }
if (!fail && (t.is_scalar() || t.bits() == 16)) return; if (!fail && (t.is_scalar() || t.bits() == 16))
if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32)) return; return;
if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32))
return;
if (!fail && (lanes >= 2 && lanes <= 4)) { if (!fail && (lanes >= 2 && lanes <= 4)) {
os << lanes; os << lanes;
return; return;
...@@ -212,18 +225,21 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream& os) { // NOLINT(*) ...@@ -212,18 +225,21 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
} else { } else {
fail = true; fail = true;
} }
if (!fail) return; if (!fail)
return;
} else if (t.is_float8()) { } else if (t.is_float8()) {
if (t.is_scalar()) { if (t.is_scalar()) {
os << "unsigned char"; // __nv_fp8_storage_t is an alias of unsigned char os << "unsigned char"; // __nv_fp8_storage_t is an alias of unsigned char
} else if (lanes == 2) { } else if (lanes == 2) {
os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of unsigned short os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of
// unsigned short
} else if (lanes == 4) { } else if (lanes == 4) {
os << "unsigned int"; // __nv_fp8x4_storage_t is an alias of unsigned int os << "unsigned int"; // __nv_fp8x4_storage_t is an alias of unsigned int
} else { } else {
fail = true; fail = true;
} }
if (!fail) return; if (!fail)
return;
} else if (t == DataType::Bool()) { } else if (t == DataType::Bool()) {
os << "bool"; os << "bool";
return; return;
...@@ -240,133 +256,135 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream& os) { // NOLINT(*) ...@@ -240,133 +256,135 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
os << "u"; os << "u";
} }
switch (t.bits()) { switch (t.bits()) {
case 1: { case 1: {
if (t.is_scalar()) { if (t.is_scalar()) {
os << "int"; os << "int";
return; return;
} else if (t.lanes() == 8) { } else if (t.lanes() == 8) {
os << "int8_t"; os << "int8_t";
return; return;
} else if (t.lanes() == 16) { } else if (t.lanes() == 16) {
os << "int16_t"; os << "int16_t";
return; return;
} else if (t.lanes() == 32) { } else if (t.lanes() == 32) {
os << "int"; os << "int";
return; return;
} else { } else {
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!"; LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
}
}
case 4: {
if (t.is_scalar()) {
os << "int";
return;
} else if (t.lanes() == 4) {
os << "int16_t";
return;
} else if (t.lanes() == 8) {
// directly 8 4-bit int in integer.
os << "int";
return;
} else if (t.lanes() == 16) {
os << "int2";
return;
} else if (t.lanes() == 32) {
os << "int4";
return;
} else if (t.lanes() == 64) {
os << "int8";
return;
} else {
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
}
} }
case 8: { }
if (t.lanes() == 4) { case 4: {
// directly 4 8 bit int in integer. if (t.is_scalar()) {
os << "int";
// We use int for int8x4 instead of char4 because using char4 is return;
// likely to produce extra instructions to pack four int8 elements } else if (t.lanes() == 4) {
// into 32-bit data. os << "int16_t";
os << "int"; return;
return; } else if (t.lanes() == 8) {
} else if (t.lanes() == 8) { // directly 8 4-bit int in integer.
os << "int2"; os << "int";
return; return;
} else if (t.lanes() == 16) { } else if (t.lanes() == 16) {
os << "int4"; os << "int2";
return; return;
} else if (!t.is_uint() && t.is_scalar()) { } else if (t.lanes() == 32) {
os << "signed char"; os << "int4";
break; return;
} else { } else if (t.lanes() == 64) {
os << "char"; os << "int8";
break; return;
} } else {
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
} }
case 16: { }
if (t.is_scalar()) { case 8: {
os << "short"; if (t.lanes() == 4) {
} else if (t.lanes() <= 4) { // directly 4 8 bit int in integer.
os << "short" << lanes;
} else if (t.lanes() <= 8) { // We use int for int8x4 instead of char4 because using char4 is
// Emit CUDA code to access int16 vector elements. // likely to produce extra instructions to pack four int8 elements
// // into 32-bit data.
// short4 is stored as int2 os << "int";
// return;
// s4.x is emitted as *(short2*)(&(i2.x)).x } else if (t.lanes() == 8) {
// s4.y is emitted as *(short2*)(&(i2.x)).y os << "int2";
// s4.z is emitted as *(short2*)(&(i2.y)).x return;
// s4.w is emitted as *(short2*)(&(i2.y)).y } else if (t.lanes() == 16) {
// os << "int4";
ICHECK_EQ(t.lanes() % 2, 0) << "only support even lane for shorT type with lanes > 4"; return;
os << "int" << t.lanes() / 2; } else if (!t.is_uint() && t.is_scalar()) {
} else { os << "signed char";
fail = true;
}
if (!fail) {
return;
}
break; break;
} } else {
case 32: { os << "char";
if (t.is_scalar()) {
os << "int";
} else if (t.lanes() <= 4) {
os << "int" << t.lanes();
} else if (t.lanes() <= 8) {
// Emit CUDA code to access int32 vector elements for 4 < lanes <= 8.
//
// int8 is stored as longlong4
//
// i8.v1 is emitted as *(int2*)(&(l4.x)).x
// i8.v2 is emitted as *(int2*)(&(l4.x)).y
//
ICHECK_EQ(lanes % 2, 0) << "only support even lane for int32 type with lanes > 4";
os << "longlong" << lanes / 2;
} else {
fail = true;
}
if (!fail) {
return;
}
break; break;
} }
case 64: { }
if (t.is_scalar()) { case 16: {
os << "int64_t"; if (t.is_scalar()) {
} else if (t.lanes() == 2) { os << "short";
os << "longlong2"; } else if (t.lanes() <= 4) {
} else if (t.lanes() == 3) { os << "short" << lanes;
os << "longlong3"; } else if (t.lanes() <= 8) {
} else if (t.lanes() == 4) { // Emit CUDA code to access int16 vector elements.
os << "longlong4"; //
} // short4 is stored as int2
//
// s4.x is emitted as *(short2*)(&(i2.x)).x
// s4.y is emitted as *(short2*)(&(i2.x)).y
// s4.z is emitted as *(short2*)(&(i2.y)).x
// s4.w is emitted as *(short2*)(&(i2.y)).y
//
ICHECK_EQ(t.lanes() % 2, 0)
<< "only support even lane for shorT type with lanes > 4";
os << "int" << t.lanes() / 2;
} else {
fail = true;
}
if (!fail) {
return; return;
} }
default: break;
}
case 32: {
if (t.is_scalar()) {
os << "int";
} else if (t.lanes() <= 4) {
os << "int" << t.lanes();
} else if (t.lanes() <= 8) {
// Emit CUDA code to access int32 vector elements for 4 < lanes <= 8.
//
// int8 is stored as longlong4
//
// i8.v1 is emitted as *(int2*)(&(l4.x)).x
// i8.v2 is emitted as *(int2*)(&(l4.x)).y
//
ICHECK_EQ(lanes % 2, 0)
<< "only support even lane for int32 type with lanes > 4";
os << "longlong" << lanes / 2;
} else {
fail = true; fail = true;
break; }
if (!fail) {
return;
}
break;
}
case 64: {
if (t.is_scalar()) {
os << "int64_t";
} else if (t.lanes() == 2) {
os << "longlong2";
} else if (t.lanes() == 3) {
os << "longlong3";
} else if (t.lanes() == 4) {
os << "longlong4";
}
return;
}
default:
fail = true;
break;
} }
if (!fail && lanes == 1) { if (!fail && lanes == 1) {
return; return;
...@@ -379,8 +397,9 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream& os) { // NOLINT(*) ...@@ -379,8 +397,9 @@ void CodeGenTileLangHIP::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type"; LOG(FATAL) << "Cannot convert type " << t << " to CUDA type";
} }
void CodeGenTileLangHIP::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, void CodeGenTileLangHIP::PrintVecBinaryOp(const std::string &op, DataType t,
std::ostream& os) { // NOLINT(*) PrimExpr lhs, PrimExpr rhs,
std::ostream &os) { // NOLINT(*)
// Declare the result. // Declare the result.
std::string sret = name_supply_->FreshName("_"); std::string sret = name_supply_->FreshName("_");
this->PrintIndent(); this->PrintIndent();
...@@ -414,15 +433,18 @@ void CodeGenTileLangHIP::PrintVecBinaryOp(const std::string& op, DataType t, Pri ...@@ -414,15 +433,18 @@ void CodeGenTileLangHIP::PrintVecBinaryOp(const std::string& op, DataType t, Pri
os << sret; os << sret;
} }
void CodeGenTileLangHIP::PrintVecElemLoad(const std::string& vec, DataType t, int i, void CodeGenTileLangHIP::PrintVecElemLoad(const std::string &vec, DataType t,
std::ostream& os) { // NOLINT(*) int i,
std::ostream &os) { // NOLINT(*)
if (t.is_scalar()) { if (t.is_scalar()) {
os << vec; os << vec;
return; return;
} }
static const char access[] = {'x', 'y', 'z', 'w'}; static const char access[] = {'x', 'y', 'z', 'w'};
ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4)); ICHECK(i >= 0 && i < (t.bits() == 8 ? 16
: (t.bits() == 16 || t.bits() == 32) ? 8
: 4));
if (t.bits() == 8 && (t.is_int() || t.is_uint())) { if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
std::string type_name = t.is_int() ? "char" : "unsigned char"; std::string type_name = t.is_int() ? "char" : "unsigned char";
if (t.lanes() == 2 || t.lanes() == 3) { if (t.lanes() == 2 || t.lanes() == 3) {
...@@ -432,9 +454,11 @@ void CodeGenTileLangHIP::PrintVecElemLoad(const std::string& vec, DataType t, in ...@@ -432,9 +454,11 @@ void CodeGenTileLangHIP::PrintVecElemLoad(const std::string& vec, DataType t, in
os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))"; os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))";
} }
} else if (t.is_float16()) { } else if (t.is_float16()) {
os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2];
} else if (t.is_bfloat16()) { } else if (t.is_bfloat16()) {
os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2];
} else if (t.lanes() > 4 && t.lanes() <= 8) { } else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name; std::string type_name;
if (t.bits() == 16) { if (t.bits() == 16) {
...@@ -453,20 +477,24 @@ void CodeGenTileLangHIP::PrintVecElemLoad(const std::string& vec, DataType t, in ...@@ -453,20 +477,24 @@ void CodeGenTileLangHIP::PrintVecElemLoad(const std::string& vec, DataType t, in
} }
} }
ICHECK(!type_name.empty()); ICHECK(!type_name.empty());
os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
<< ")))->" << access[i % 2];
} else { } else {
os << vec << "." << access[i]; os << vec << "." << access[i];
} }
} }
void CodeGenTileLangHIP::PrintVecElemStore(const std::string& vec, DataType t, int i, void CodeGenTileLangHIP::PrintVecElemStore(const std::string &vec, DataType t,
const std::string& value) { int i, const std::string &value) {
this->PrintIndent(); this->PrintIndent();
static const char access[] = {'x', 'y', 'z', 'w'}; static const char access[] = {'x', 'y', 'z', 'w'};
ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4)); ICHECK(i >= 0 && i < (t.bits() == 8 ? 16
: (t.bits() == 16 || t.bits() == 32) ? 8
: 4));
if (t.bits() == 8 && (t.is_int() || t.is_uint())) { if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (t.lanes() == 2 || t.lanes() == 3) { if (t.lanes() == 2 || t.lanes() == 3) {
stream << vec << '.' << access[i % t.lanes()] << "=" << "(" << value << ");\n"; stream << vec << '.' << access[i % t.lanes()] << "="
<< "(" << value << ");\n";
} else { } else {
std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]); std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
stream << ac << "="; stream << ac << "=";
...@@ -477,11 +505,11 @@ void CodeGenTileLangHIP::PrintVecElemStore(const std::string& vec, DataType t, i ...@@ -477,11 +505,11 @@ void CodeGenTileLangHIP::PrintVecElemStore(const std::string& vec, DataType t, i
stream << "(" << value << " << " << i % 4 * 8 << ");\n"; stream << "(" << value << " << " << i % 4 * 8 << ");\n";
} }
} else if (t.is_float16()) { } else if (t.is_float16()) {
stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = " stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
<< value << ";\n"; << access[i % 2] << " = " << value << ";\n";
} else if (t.is_bfloat16()) { } else if (t.is_bfloat16()) {
stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
<< " = " << value << ";\n"; << access[i % 2] << " = " << value << ";\n";
} else if (t.lanes() > 4 && t.lanes() <= 8) { } else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name; std::string type_name;
if (t.bits() == 16) { if (t.bits() == 16) {
...@@ -500,15 +528,15 @@ void CodeGenTileLangHIP::PrintVecElemStore(const std::string& vec, DataType t, i ...@@ -500,15 +528,15 @@ void CodeGenTileLangHIP::PrintVecElemStore(const std::string& vec, DataType t, i
} }
} }
ICHECK(!type_name.empty()); ICHECK(!type_name.empty());
stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->" stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
<< access[i % 2] << " = " << value << ";\n"; << ")))->" << access[i % 2] << " = " << value << ";\n";
} else { } else {
stream << vec << "." << access[i] << " = " << value << ";\n"; stream << vec << "." << access[i] << " = " << value << ";\n";
} }
} }
void CodeGenTileLangHIP::PrintStorageSync(const CallNode* op) { void CodeGenTileLangHIP::PrintStorageSync(const CallNode *op) {
const std::string& sync = op->args[0].as<StringImmNode>()->value; const std::string &sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") { if (sync == "warp") {
// DO nothing. // DO nothing.
} else if (sync == "shared" || sync == "shared.dyn") { } else if (sync == "shared" || sync == "shared.dyn") {
...@@ -517,9 +545,11 @@ void CodeGenTileLangHIP::PrintStorageSync(const CallNode* op) { ...@@ -517,9 +545,11 @@ void CodeGenTileLangHIP::PrintStorageSync(const CallNode* op) {
} }
} }
void CodeGenTileLangHIP::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) void CodeGenTileLangHIP::PrintStorageScope(const std::string &scope,
ICHECK_NE(scope, "global") << "Cannot allocate global memory when targeting CUDA. You must pass " std::ostream &os) { // NOLINT(*)
"all global arrays as input instead"; ICHECK_NE(scope, "global")
<< "Cannot allocate global memory when targeting CUDA. You must pass "
"all global arrays as input instead";
if (scope == "shared") { if (scope == "shared") {
os << "__shared__ "; os << "__shared__ ";
} else if (scope == "shared.dyn") { } else if (scope == "shared.dyn") {
...@@ -527,13 +557,16 @@ void CodeGenTileLangHIP::PrintStorageScope(const std::string& scope, std::ostrea ...@@ -527,13 +557,16 @@ void CodeGenTileLangHIP::PrintStorageScope(const std::string& scope, std::ostrea
} }
} }
std::string CodeGenTileLangHIP::CastFromTo(std::string value, DataType from, DataType target) { std::string CodeGenTileLangHIP::CastFromTo(std::string value, DataType from,
if (from == target) return value; DataType target) {
if (from == target)
return value;
std::ostringstream os; std::ostringstream os;
os << "(("; os << "((";
this->PrintType(target, os); this->PrintType(target, os);
os << ")"; os << ")";
if (from.is_float16() && (target.is_int() || target.is_uint()) && target.bits() == 8) { if (from.is_float16() && (target.is_int() || target.is_uint()) &&
target.bits() == 8) {
os << "("; os << "(";
if (target.is_uint()) { if (target.is_uint()) {
os << "u"; os << "u";
...@@ -544,13 +577,14 @@ std::string CodeGenTileLangHIP::CastFromTo(std::string value, DataType from, Dat ...@@ -544,13 +577,14 @@ std::string CodeGenTileLangHIP::CastFromTo(std::string value, DataType from, Dat
return os.str(); return os.str();
} }
void CodeGenTileLangHIP::VisitExpr_(const CastNode* op, std::ostream& os) { void CodeGenTileLangHIP::VisitExpr_(const CastNode *op, std::ostream &os) {
DataType from_ty = op->value.dtype(); DataType from_ty = op->value.dtype();
DataType target_ty = op->dtype; DataType target_ty = op->dtype;
ICHECK_EQ(target_ty.lanes(), from_ty.lanes()); ICHECK_EQ(target_ty.lanes(), from_ty.lanes());
// Emit simple C-style type conversion. // Emit simple C-style type conversion.
if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os); if (from_ty.is_scalar())
return CodeGenC::VisitExpr_(op, os);
// We could emit make_float4 like calls, but the emitted code looks // We could emit make_float4 like calls, but the emitted code looks
// too compact to read. Emit this as vectorized unary ops. // too compact to read. Emit this as vectorized unary ops.
...@@ -573,8 +607,10 @@ void CodeGenTileLangHIP::VisitExpr_(const CastNode* op, std::ostream& os) { ...@@ -573,8 +607,10 @@ void CodeGenTileLangHIP::VisitExpr_(const CastNode* op, std::ostream& os) {
os << sret; os << sret;
} }
void CodeGenTileLangHIP::PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args, void CodeGenTileLangHIP::PrintCallExtern(Type ret_type, String global_symbol,
bool skip_first_arg, std::ostream& os) { // NOLINT(*) const Array<PrimExpr> &args,
bool skip_first_arg,
std::ostream &os) { // NOLINT(*)
DataType ret_dtype = GetRuntimeDataType(ret_type); DataType ret_dtype = GetRuntimeDataType(ret_type);
if (ret_dtype.is_vector()) { if (ret_dtype.is_vector()) {
// //
...@@ -614,7 +650,8 @@ void CodeGenTileLangHIP::PrintCallExtern(Type ret_type, String global_symbol, co ...@@ -614,7 +650,8 @@ void CodeGenTileLangHIP::PrintCallExtern(Type ret_type, String global_symbol, co
std::ostringstream scall; std::ostringstream scall;
scall << global_symbol << "("; scall << global_symbol << "(";
for (size_t j = 0; j < sargs.size(); ++j) { for (size_t j = 0; j < sargs.size(); ++j) {
if (j > 0) scall << ", "; if (j > 0)
scall << ", ";
PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall); PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall);
} }
scall << ")"; scall << ")";
...@@ -623,13 +660,16 @@ void CodeGenTileLangHIP::PrintCallExtern(Type ret_type, String global_symbol, co ...@@ -623,13 +660,16 @@ void CodeGenTileLangHIP::PrintCallExtern(Type ret_type, String global_symbol, co
} }
os << sret; os << sret;
} else { } else {
CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, os); CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg,
os);
} }
} }
// Print a reference expression to a buffer. // Print a reference expression to a buffer.
std::string CodeGenTileLangHIP::GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) { std::string CodeGenTileLangHIP::GetBufferRef(DataType t,
const VarNode* buffer_var = buffer->data.get(); const BufferNode *buffer,
PrimExpr index) {
const VarNode *buffer_var = buffer->data.get();
std::ostringstream os; std::ostringstream os;
std::string vid = GetVarID(buffer_var); std::string vid = GetVarID(buffer_var);
std::string scope; std::string scope;
...@@ -685,12 +725,13 @@ std::string CodeGenTileLangHIP::GetBufferRef(DataType t, const BufferNode* buffe ...@@ -685,12 +725,13 @@ std::string CodeGenTileLangHIP::GetBufferRef(DataType t, const BufferNode* buffe
return os.str(); return os.str();
} }
void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) { void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
auto print_extern_call_stmt = [&](std::string name, size_t offset = 0) { auto print_extern_call_stmt = [&](std::string name, size_t offset = 0) {
this->PrintIndent(); this->PrintIndent();
this->stream << name << "("; this->stream << name << "(";
for (size_t i = offset; i < op->args.size(); i++) { for (size_t i = offset; i < op->args.size(); i++) {
if (i > offset) this->stream << ", "; if (i > offset)
this->stream << ", ";
this->stream << this->PrintExpr(op->args[i]); this->stream << this->PrintExpr(op->args[i]);
} }
this->stream << ");\n"; this->stream << ");\n";
...@@ -701,16 +742,18 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -701,16 +742,18 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string src = this->PrintExpr(op->args[2]); std::string src = this->PrintExpr(op->args[2]);
std::string src_offset = this->PrintExpr(op->args[3]); std::string src_offset = this->PrintExpr(op->args[3]);
std::string size = this->PrintExpr(op->args[4]); std::string size = this->PrintExpr(op->args[4]);
// use size of argument list to indicate whether or not to use predicated cp.async // use size of argument list to indicate whether or not to use predicated
// cp.async
if (op->args.size() == 5) { if (op->args.size() == 5) {
this->PrintIndent(); this->PrintIndent();
this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+" << dst_offset << ", " << src this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+"
<< "+" << src_offset << ");\n"; << dst_offset << ", " << src << "+" << src_offset << ");\n";
} else { } else {
std::string condition = this->PrintExpr(op->args[5]); std::string condition = this->PrintExpr(op->args[5]);
this->PrintIndent(); this->PrintIndent();
this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst << "+" << dst_offset this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst
<< ", " << src << "+" << src_offset << ", " << condition << ");\n"; << "+" << dst_offset << ", " << src << "+" << src_offset
<< ", " << condition << ");\n";
} }
} else if (op->op.same_as(builtin::ptx_commit_group())) { } else if (op->op.same_as(builtin::ptx_commit_group())) {
print_extern_call_stmt("tl::cp_async_commit"); print_extern_call_stmt("tl::cp_async_commit");
...@@ -722,7 +765,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -722,7 +765,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
this->PrintIndent(); this->PrintIndent();
int barrier_count = Downcast<IntImm>(op->args[0])->value; int barrier_count = Downcast<IntImm>(op->args[0])->value;
std::string barrier_name = "_mbarrier"; std::string barrier_name = "_mbarrier";
this->stream << "__shared__ uint64_t " << barrier_name << "[" << barrier_count << "];\n"; this->stream << "__shared__ uint64_t " << barrier_name << "["
<< barrier_count << "];\n";
} else if (op->op.same_as(tl::GetMBarrierOp())) { } else if (op->op.same_as(tl::GetMBarrierOp())) {
std::string barrier_name = "_mbarrier"; std::string barrier_name = "_mbarrier";
std::string barrier_id = this->PrintExpr(op->args[0]); std::string barrier_id = this->PrintExpr(op->args[0]);
...@@ -751,13 +795,15 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -751,13 +795,15 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
int trans = Downcast<IntImm>(op->args[0])->value; int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value; int num = Downcast<IntImm>(op->args[1])->value;
std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num); std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
if (trans == 1) func_name += "_trans"; if (trans == 1)
func_name += "_trans";
print_extern_call_stmt(func_name, 2); print_extern_call_stmt(func_name, 2);
} else if (op->op.same_as(tl::STMatrixOp())) { } else if (op->op.same_as(tl::STMatrixOp())) {
int trans = Downcast<IntImm>(op->args[0])->value; int trans = Downcast<IntImm>(op->args[0])->value;
int num = Downcast<IntImm>(op->args[1])->value; int num = Downcast<IntImm>(op->args[1])->value;
std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num); std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num);
if (trans == 1) func_name += "_trans"; if (trans == 1)
func_name += "_trans";
print_extern_call_stmt(func_name, 2); print_extern_call_stmt(func_name, 2);
} else if (op->op.same_as(tl::FenceProxyAsyncOp())) { } else if (op->op.same_as(tl::FenceProxyAsyncOp())) {
print_extern_call_stmt("tl::fence_proxy_async"); print_extern_call_stmt("tl::fence_proxy_async");
...@@ -765,15 +811,16 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -765,15 +811,16 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
this->PrintIndent(); this->PrintIndent();
int nreg = Downcast<IntImm>(op->args[0])->value; int nreg = Downcast<IntImm>(op->args[0])->value;
int is_inc = Downcast<IntImm>(op->args[1])->value; int is_inc = Downcast<IntImm>(op->args[1])->value;
std::string func_name = is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc"; std::string func_name =
is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc";
this->stream << func_name << "<" << std::to_string(nreg) << ">();\n"; this->stream << func_name << "<" << std::to_string(nreg) << ">();\n";
} else if (op->op.same_as(tl::WaitWgmma())) { } else if (op->op.same_as(tl::WaitWgmma())) {
this->PrintIndent(); this->PrintIndent();
int num_mma = Downcast<IntImm>(op->args[0])->value; int num_mma = Downcast<IntImm>(op->args[0])->value;
this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n"; this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n";
} else if (op->op.same_as(tl::PackB16Op())) { } else if (op->op.same_as(tl::PackB16Op())) {
os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", " << this->PrintExpr(op->args[1]) os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", "
<< ")"; << this->PrintExpr(op->args[1]) << ")";
} else if (op->op.same_as(builtin::tvm_fill_fragment())) { } else if (op->op.same_as(builtin::tvm_fill_fragment())) {
need_mma_h_ = true; need_mma_h_ = true;
ICHECK_EQ(op->args.size(), 6U); ICHECK_EQ(op->args.size(), 6U);
...@@ -807,7 +854,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -807,7 +854,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
this->PrintExpr(op->args[4], os); this->PrintExpr(op->args[4], os);
os << "], "; os << "], ";
this->PrintExpr(op->args[6], os); this->PrintExpr(op->args[6], os);
if (const StringImmNode* str = op->args[7].as<StringImmNode>()) { if (const StringImmNode *str = op->args[7].as<StringImmNode>()) {
os << ", nvcuda::wmma::mem_" << str->value; os << ", nvcuda::wmma::mem_" << str->value;
} else { } else {
LOG(FATAL) << "Invalid parameters"; LOG(FATAL) << "Invalid parameters";
...@@ -833,7 +880,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -833,7 +880,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
this->PrintExpr(op->args[i * 2 + 1], os); this->PrintExpr(op->args[i * 2 + 1], os);
os << "]" << ((i < 3) ? ", " : ")"); os << "]" << ((i < 3) ? ", " : ")");
} }
}else if (op->op.same_as(builtin::tvm_mfma())) { } else if (op->op.same_as(builtin::tvm_mfma())) {
// arg 0: prefix: {otype}_16x16x16{itype} // arg 0: prefix: {otype}_16x16x16{itype}
// arg 1: A layout: row/col // arg 1: A layout: row/col
// arg 2: B layout: row/col // arg 2: B layout: row/col
...@@ -847,7 +894,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -847,7 +894,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
// arg 10: C accumulator // arg 10: C accumulator
// arg 11: C accumulator index // arg 11: C accumulator index
ICHECK(op->args.size() == 12U) << "Invalid number of arguments for tvm_mfma"; ICHECK(op->args.size() == 12U)
<< "Invalid number of arguments for tvm_mfma";
std::string prefix = Downcast<StringImm>(op->args[0])->value; std::string prefix = Downcast<StringImm>(op->args[0])->value;
std::string A_layout = Downcast<StringImm>(op->args[1])->value; std::string A_layout = Downcast<StringImm>(op->args[1])->value;
std::string B_layout = Downcast<StringImm>(op->args[2])->value; std::string B_layout = Downcast<StringImm>(op->args[2])->value;
...@@ -860,7 +908,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -860,7 +908,8 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string b_bias = this->PrintExpr(op->args[9]); std::string b_bias = this->PrintExpr(op->args[9]);
std::string c_ref = this->PrintExpr(op->args[10]); std::string c_ref = this->PrintExpr(op->args[10]);
std::string c_bias = this->PrintExpr(op->args[11]); std::string c_bias = this->PrintExpr(op->args[11]);
ICHECK(A_layout == "row" || B_layout == "row") << "Matrix core only support row major"; ICHECK(A_layout == "row" || B_layout == "row")
<< "Matrix core only support row major";
// map for dtype -> float32x4 -> float4 // map for dtype -> float32x4 -> float4
std::unordered_map<std::string, std::string> dtype_map = { std::unordered_map<std::string, std::string> dtype_map = {
{"int8", "char"}, {"int8", "char"},
...@@ -873,8 +922,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -873,8 +922,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
{"float16x4", "float16x4"}, {"float16x4", "float16x4"},
{"bfloat16x4", "bfloat16x4"}, {"bfloat16x4", "bfloat16x4"},
{"float32x4", "float32x4"}, {"float32x4", "float32x4"},
{"float32x16", "float32x16"} {"float32x16", "float32x16"}};
};
std::string call_mfma_code = R"({ std::string call_mfma_code = R"({
*((({C_dytpe}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dytpe}*){a_ref}) + {a_bias}), *((({C_dytpe}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dytpe}*){a_ref}) + {a_bias}),
*((({B_dytpe}*){b_ref}) + {b_bias}), *((({B_dytpe}*){b_ref}) + {b_bias}),
...@@ -893,15 +941,16 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) { ...@@ -893,15 +941,16 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode* op, std::ostream& os) {
replacer.register_rule("{c_ref}", c_ref); replacer.register_rule("{c_ref}", c_ref);
replacer.register_rule("{c_bias}", c_bias); replacer.register_rule("{c_bias}", c_bias);
os << replacer.rewrite(call_mfma_code); os << replacer.rewrite(call_mfma_code);
} else { } else {
CodeGenC::VisitExpr_(op, os); CodeGenC::VisitExpr_(op, os);
} }
} }
void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode* op) { void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode *op) {
if (op->attr_key == tir::attr::async_commit_queue_scope) { if (op->attr_key == tir::attr::async_commit_queue_scope) {
const IntImmNode* queue_id = op->value.as<IntImmNode>(); const IntImmNode *queue_id = op->value.as<IntImmNode>();
ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0."; ICHECK(queue_id && queue_id->value == 0)
<< "For CUDA, the index of an async queue must be 0.";
this->VisitStmt(op->body); this->VisitStmt(op->body);
auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {}); auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {});
this->VisitExpr(commit_group, this->stream); this->VisitExpr(commit_group, this->stream);
...@@ -909,9 +958,11 @@ void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode* op) { ...@@ -909,9 +958,11 @@ void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode* op) {
} else if (op->attr_key == tir::attr::async_wait_queue_scope) { } else if (op->attr_key == tir::attr::async_wait_queue_scope) {
auto wait_attrs = GetAsyncWaitAttributes(op); auto wait_attrs = GetAsyncWaitAttributes(op);
auto queue_id = wait_attrs.first.as<IntImmNode>(); auto queue_id = wait_attrs.first.as<IntImmNode>();
ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0."; ICHECK(queue_id && queue_id->value == 0)
<< "For CUDA, the index of an async queue must be 0.";
auto wait_cnt = wait_attrs.second; auto wait_cnt = wait_attrs.second;
auto wait_group = Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt}); auto wait_group =
Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
this->VisitExpr(wait_group, this->stream); this->VisitExpr(wait_group, this->stream);
auto inner = op->body.as<AttrStmtNode>(); auto inner = op->body.as<AttrStmtNode>();
ICHECK(inner); ICHECK(inner);
...@@ -919,7 +970,7 @@ void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode* op) { ...@@ -919,7 +970,7 @@ void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode* op) {
return; return;
} else if (op->attr_key == "threadblock_swizzle_pattern") { } else if (op->attr_key == "threadblock_swizzle_pattern") {
this->PrintIndent(); this->PrintIndent();
const StringImmNode* pattern = op->value.as<StringImmNode>(); const StringImmNode *pattern = op->value.as<StringImmNode>();
ICHECK(pattern); ICHECK(pattern);
this->stream << "const dim3 blockIdx = " << pattern->value << "();\n"; this->stream << "const dim3 blockIdx = " << pattern->value << "();\n";
this->VisitStmt(op->body); this->VisitStmt(op->body);
...@@ -928,7 +979,7 @@ void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode* op) { ...@@ -928,7 +979,7 @@ void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode* op) {
CodeGenC::VisitStmt_(op); CodeGenC::VisitStmt_(op);
} }
void CodeGenTileLangHIP::VisitStmt_(const AllocateNode* op) { void CodeGenTileLangHIP::VisitStmt_(const AllocateNode *op) {
ICHECK(!is_zero(op->condition)); ICHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get()); std::string vid = AllocVarID(op->buffer_var.get());
...@@ -941,7 +992,8 @@ void CodeGenTileLangHIP::VisitStmt_(const AllocateNode* op) { ...@@ -941,7 +992,8 @@ void CodeGenTileLangHIP::VisitStmt_(const AllocateNode* op) {
stream << ' ' << vid << "[];\n"; stream << ' ' << vid << "[];\n";
} else { } else {
size_t constant_size = op->ConstantAllocationSize(); size_t constant_size = op->ConstantAllocationSize();
ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; ICHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now";
if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) || if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
op->dtype == DataType::Int(1)) && op->dtype == DataType::Int(1)) &&
...@@ -955,7 +1007,7 @@ void CodeGenTileLangHIP::VisitStmt_(const AllocateNode* op) { ...@@ -955,7 +1007,7 @@ void CodeGenTileLangHIP::VisitStmt_(const AllocateNode* op) {
this->PrintStmt(op->body); this->PrintStmt(op->body);
} }
void CodeGenTileLangHIP::VisitExpr_(const RampNode* op, std::ostream& os) { void CodeGenTileLangHIP::VisitExpr_(const RampNode *op, std::ostream &os) {
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value); int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
CHECK_LE(lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed."; CHECK_LE(lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed.";
os << "(make_"; os << "(make_";
...@@ -964,16 +1016,19 @@ void CodeGenTileLangHIP::VisitExpr_(const RampNode* op, std::ostream& os) { ...@@ -964,16 +1016,19 @@ void CodeGenTileLangHIP::VisitExpr_(const RampNode* op, std::ostream& os) {
for (int i = 0; i < lanes; i++) { for (int i = 0; i < lanes; i++) {
os << "(" << PrintExpr(op->base) << ")" os << "(" << PrintExpr(op->base) << ")"
<< "+(" << PrintExpr(op->stride) << "*" << i << ")"; << "+(" << PrintExpr(op->stride) << "*" << i << ")";
if (i != lanes - 1) os << ", "; if (i != lanes - 1)
os << ", ";
} }
os << "))"; os << "))";
} }
void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode *op,
std::ostream &os) { // NOLINT(*)
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value); int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && lanes == 4) { if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 &&
lanes == 4) {
// make_int8x4 // make_int8x4
const int64_t* p = as_const_int(op->value); const int64_t *p = as_const_int(op->value);
ICHECK(p); ICHECK(p);
int64_t v = *p & 0xFF; int64_t v = *p & 0xFF;
v = (v << 24) | (v << 16) | (v << 8) | v; v = (v << 24) | (v << 16) | (v << 8) | v;
...@@ -991,7 +1046,8 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) { ...@@ -991,7 +1046,8 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) {
PrintType(op->dtype, os); PrintType(op->dtype, os);
os << '('; os << '(';
for (int i = 0; i < lanes / 2; ++i) { for (int i = 0; i < lanes / 2; ++i) {
if (i != 0) os << ", "; if (i != 0)
os << ", ";
os << "__pack_half2(" << v << ", " << v << ")"; os << "__pack_half2(" << v << ", " << v << ")";
} }
os << ')'; os << ')';
...@@ -1004,18 +1060,21 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) { ...@@ -1004,18 +1060,21 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) {
PrintType(op->dtype, os); PrintType(op->dtype, os);
os << '('; os << '(';
for (int i = 0; i < lanes / 2; ++i) { for (int i = 0; i < lanes / 2; ++i) {
if (i != 0) os << ", "; if (i != 0)
os << ", ";
os << "__pack_nv_bfloat162(" << v << ", " << v << ")"; os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
} }
os << ')'; os << ')';
return; return;
} }
if (op->dtype.is_float() && op->dtype.bits() == 32 && op->dtype.lanes() == 8) { if (op->dtype.is_float() && op->dtype.bits() == 32 &&
op->dtype.lanes() == 8) {
std::string v = PrintExpr(op->value); std::string v = PrintExpr(op->value);
os << "make_ulonglong4("; os << "make_ulonglong4(";
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
if (i != 0) os << ", "; if (i != 0)
os << ", ";
os << "*(unsigned long long*)&make_float2(" << v << ", " << v << ")"; os << "*(unsigned long long*)&make_float2(" << v << ", " << v << ")";
} }
os << ')'; os << ')';
...@@ -1024,7 +1083,7 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) { ...@@ -1024,7 +1083,7 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) {
if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) { if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) {
bool fail = false; bool fail = false;
const int64_t* p = as_const_int(op->value); const int64_t *p = as_const_int(op->value);
ICHECK(p); ICHECK(p);
int64_t v = *p & 0xF; int64_t v = *p & 0xF;
...@@ -1036,7 +1095,8 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) { ...@@ -1036,7 +1095,8 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) {
os << "(int16_t)" << v; os << "(int16_t)" << v;
} }
} else { } else {
v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v; v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) |
(v << 4) | v;
if (lanes == 8) { if (lanes == 8) {
if (op->dtype.is_uint()) { if (op->dtype.is_uint()) {
os << "(uint)" << v; os << "(uint)" << v;
...@@ -1048,7 +1108,8 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) { ...@@ -1048,7 +1108,8 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) {
PrintType(op->dtype, os); PrintType(op->dtype, os);
os << '('; os << '(';
for (int i = 0; i < lanes / 8; ++i) { for (int i = 0; i < lanes / 8; ++i) {
if (i != 0) os << ", "; if (i != 0)
os << ", ";
if (op->dtype.is_uint()) { if (op->dtype.is_uint()) {
os << "(uint)" << v; os << "(uint)" << v;
} else { } else {
...@@ -1071,13 +1132,15 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) { ...@@ -1071,13 +1132,15 @@ void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) {
PrintType(op->dtype, os); PrintType(op->dtype, os);
os << '('; os << '(';
for (int i = 0; i < lanes; ++i) { for (int i = 0; i < lanes; ++i) {
if (i != 0) os << ", "; if (i != 0)
os << ", ";
os << v; os << v;
} }
os << ')'; os << ')';
} }
inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLangHIP* p) { // NOLINT(*) inline void PrintConst(const FloatImmNode *op, std::ostream &os,
CodeGenTileLangHIP *p) { // NOLINT(*)
// Type code is kBFloat // Type code is kBFloat
if (op->dtype.is_bfloat16()) { if (op->dtype.is_bfloat16()) {
os << "bfloat16_t"; os << "bfloat16_t";
...@@ -1086,46 +1149,50 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLang ...@@ -1086,46 +1149,50 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLang
} }
// Type code is kFloat // Type code is kFloat
switch (op->dtype.bits()) { switch (op->dtype.bits()) {
case 64: case 64:
case 32: { case 32: {
std::ostringstream temp; std::ostringstream temp;
if (std::isinf(op->value)) { if (std::isinf(op->value)) {
if (op->value < 0) { if (op->value < 0) {
temp << "-"; temp << "-";
}
temp << ((op->dtype.bits() == 32) ? "HIPRT_INF_F" : "HIPRT_INF");
} else if (std::isnan(op->value)) {
temp << ((op->dtype.bits() == 32) ? "HIPRT_NAN_F" : "HIPRT_NAN");
} else {
temp << std::scientific << op->value;
if (op->dtype.bits() == 32) temp << 'f';
} }
p->MarkConst(temp.str()); temp << ((op->dtype.bits() == 32) ? "HIPRT_INF_F" : "HIPRT_INF");
os << temp.str(); } else if (std::isnan(op->value)) {
break; temp << ((op->dtype.bits() == 32) ? "HIPRT_NAN_F" : "HIPRT_NAN");
} } else {
case 16: { temp << std::scientific << op->value;
os << "half_t" << '('; if (op->dtype.bits() == 32)
FloatImm const_f32 = FloatImm(DataType::Float(32), op->value); temp << 'f';
PrintConst(const_f32.get(), os, p);
os << ')';
break;
} }
default: p->MarkConst(temp.str());
LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; os << temp.str();
break;
}
case 16: {
os << "half_t" << '(';
FloatImm const_f32 = FloatImm(DataType::Float(32), op->value);
PrintConst(const_f32.get(), os, p);
os << ')';
break;
}
default:
LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
} }
} }
void CodeGenTileLangHIP::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) void CodeGenTileLangHIP::VisitExpr_(const FloatImmNode *op,
std::ostream &os) { // NOLINT(*)
PrintConst(op, os, this); PrintConst(op, os, this);
} }
void CodeGenTileLangHIP::HandleVolatileLoads(const std::string& value, const BufferLoadNode* op, void CodeGenTileLangHIP::HandleVolatileLoads(const std::string &value,
std::ostream& os) { const BufferLoadNode *op,
std::ostream &os) {
// Cast away volatile qualifier for fp16 types. That is, only loads and // Cast away volatile qualifier for fp16 types. That is, only loads and
// stores are volatile. The loaded objects are not marked as volatile. // stores are volatile. The loaded objects are not marked as volatile.
// //
if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer->data.get())) { if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) &&
IsVolatile(op->buffer->data.get())) {
os << "("; os << "(";
PrintType(op->dtype, os); PrintType(op->dtype, os);
os << ")(" << value << ")"; os << ")(" << value << ")";
...@@ -1134,15 +1201,17 @@ void CodeGenTileLangHIP::HandleVolatileLoads(const std::string& value, const Buf ...@@ -1134,15 +1201,17 @@ void CodeGenTileLangHIP::HandleVolatileLoads(const std::string& value, const Buf
} }
} }
void CodeGenTileLangHIP::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, void CodeGenTileLangHIP::PrintVecElemLoadExpr(DataType t, int i,
std::ostream& os) { const std::string &value,
std::ostream &os) {
ICHECK_GT(t.lanes(), 1); ICHECK_GT(t.lanes(), 1);
if (t.bits() == 8 && (t.is_int() || t.is_uint())) { if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (!(t.lanes() == 2 || t.lanes() == 3)) { if (!(t.lanes() == 2 || t.lanes() == 3)) {
if (i != 0) { if (i != 0) {
os << "|"; os << "|";
} }
os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))"; os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8
<< "))";
return; return;
} }
} }
...@@ -1199,7 +1268,7 @@ void CodeGenTileLangHIP::PrintVecElemLoadExpr(DataType t, int i, const std::stri ...@@ -1199,7 +1268,7 @@ void CodeGenTileLangHIP::PrintVecElemLoadExpr(DataType t, int i, const std::stri
return; return;
} }
void CodeGenTileLangHIP::AddFunction(const PrimFunc& f) { void CodeGenTileLangHIP::AddFunction(const PrimFunc &f) {
// clear previous generated state. // clear previous generated state.
this->InitFuncState(f); this->InitFuncState(f);
// reserve keywords // reserve keywords
...@@ -1218,10 +1287,11 @@ void CodeGenTileLangHIP::AddFunction(const PrimFunc& f) { ...@@ -1218,10 +1287,11 @@ void CodeGenTileLangHIP::AddFunction(const PrimFunc& f) {
for (size_t i = 0; i < f->params.size(); ++i) { for (size_t i = 0; i < f->params.size(); ++i) {
tir::Var v = f->params[i]; tir::Var v = f->params[i];
std::string vid = AllocVarID(v.get()); std::string vid = AllocVarID(v.get());
if (i != 0) stream << ", "; if (i != 0)
stream << ", ";
if (v.dtype().is_handle()) { if (v.dtype().is_handle()) {
// work around for grid constant parameters. // work around for grid constant parameters.
if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) { if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
if (ptr->storage_scope == "grid_constant") { if (ptr->storage_scope == "grid_constant") {
stream << "__grid_constant__ const "; stream << "__grid_constant__ const ";
CodeGenC::PrintType(ptr->element_type, stream); CodeGenC::PrintType(ptr->element_type, stream);
...@@ -1236,8 +1306,8 @@ void CodeGenTileLangHIP::AddFunction(const PrimFunc& f) { ...@@ -1236,8 +1306,8 @@ void CodeGenTileLangHIP::AddFunction(const PrimFunc& f) {
} }
CodeGenC::PrintType(GetType(v), stream); CodeGenC::PrintType(GetType(v), stream);
if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) { if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
if (auto* prim = ptr->element_type.as<PrimTypeNode>()) { if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
RegisterHandleType(v.get(), prim->dtype); RegisterHandleType(v.get(), prim->dtype);
} }
} }
...@@ -1259,5 +1329,5 @@ void CodeGenTileLangHIP::AddFunction(const PrimFunc& f) { ...@@ -1259,5 +1329,5 @@ void CodeGenTileLangHIP::AddFunction(const PrimFunc& f) {
this->stream << "}\n\n"; this->stream << "}\n\n";
} }
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -21,50 +21,58 @@ namespace tvm { ...@@ -21,50 +21,58 @@ namespace tvm {
namespace codegen { namespace codegen {
class CodeGenTileLangHIP final : public CodeGenC { class CodeGenTileLangHIP final : public CodeGenC {
public: public:
CodeGenTileLangHIP(); CodeGenTileLangHIP();
std::string Finish(); std::string Finish();
// override behavior // override behavior
void PrintFuncPrefix(std::ostream& os) final; void PrintFuncPrefix(std::ostream &os) final;
void PrintExtraAttrs(const PrimFunc& f, std::ostream& os) final; void PrintExtraAttrs(const PrimFunc &f, std::ostream &os) final;
void VisitStmt_(const ForNode* op) final; void VisitStmt_(const ForNode *op) final;
void PrintStorageSync(const CallNode* op) final; void PrintStorageSync(const CallNode *op) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageScope(const std::string &scope,
void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, std::ostream &os) final; // NOLINT(*)
std::ostream& os) final; // NOLINT(*) void PrintVecBinaryOp(const std::string &op, DataType t, PrimExpr lhs,
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) PrimExpr rhs,
void PrintVecElemLoad(const std::string& vec, DataType t, int i, std::ostream &os) final; // NOLINT(*)
std::ostream& os) final; // NOLINT(*) void PrintType(DataType t, std::ostream &os) final; // NOLINT(*)
void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final; void PrintVecElemLoad(const std::string &vec, DataType t, int i,
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) std::ostream &os) final; // NOLINT(*)
void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) final; void PrintVecElemStore(const std::string &vec, DataType t, int i,
std::string CastFromTo(std::string value, DataType from, DataType target) final; const std::string &value) final;
void BindThreadIndex(const IterVar &iv) final; // NOLINT(*)
void PrintVecElemLoadExpr(DataType t, int i, const std::string &value,
std::ostream &os) final;
std::string CastFromTo(std::string value, DataType from,
DataType target) final;
// overload visitor // overload visitor
void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const RampNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; void VisitExpr_(const FloatImmNode *op, std::ostream &os) final;
void VisitExpr_(const CallNode* op, std::ostream& os) final; void VisitExpr_(const CallNode *op, std::ostream &os) final;
void VisitExpr_(const CastNode* op, std::ostream& os) final; void VisitExpr_(const CastNode *op, std::ostream &os) final;
void VisitStmt_(const AllocateNode* op) final; void VisitStmt_(const AllocateNode *op) final;
void VisitStmt_(const AttrStmtNode* op) final; void VisitStmt_(const AttrStmtNode *op) final;
// Override this as a work around for __grid_constant__ parameter // Override this as a work around for __grid_constant__ parameter
void AddFunction(const PrimFunc& f); void AddFunction(const PrimFunc &f);
protected: protected:
virtual std::string GetBufferRef(DataType t, const BufferNode* buffer, PrimExpr index) final; virtual std::string GetBufferRef(DataType t, const BufferNode *buffer,
void PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args, PrimExpr index) final;
bool skip_first_arg, std::ostream& os) final; // NOLINT(*) void PrintCallExtern(Type ret_type, String global_symbol,
const Array<PrimExpr> &args, bool skip_first_arg,
std::ostream &os) final; // NOLINT(*)
private: private:
// Handle volatile loads // Handle volatile loads
void HandleVolatileLoads(const std::string& value, const BufferLoadNode* op, void HandleVolatileLoads(const std::string &value, const BufferLoadNode *op,
std::ostream& os) final; std::ostream &os) final;
// Whether scope such as "__shared__" or "__constant__" is part of type. // Whether scope such as "__shared__" or "__constant__" is part of type.
bool IsScopePartOfType() const final { return false; } bool IsScopePartOfType() const final { return false; }
friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenTileLangHIP* p); friend void PrintConst(const FloatImmNode *op, std::ostream &os,
CodeGenTileLangHIP *p);
// whether need math_constants.h // whether need math_constants.h
bool need_math_constants_h_{false}; bool need_math_constants_h_{false};
...@@ -83,7 +91,7 @@ class CodeGenTileLangHIP final : public CodeGenC { ...@@ -83,7 +91,7 @@ class CodeGenTileLangHIP final : public CodeGenC {
const int barrier_alignment_bytes_ = 16; const int barrier_alignment_bytes_ = 16;
}; };
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
#endif // TVM_TL_TARGET_CODEGEN_HIP_H_ #endif // TVM_TL_TARGET_CODEGEN_HIP_H_
This source diff could not be displayed because it is too large. You can view the blob instead.
// Copyright (c) Microsoft Corporation. // Copyright (c) Microsoft Corporation.
// Licensed under the MIT License. // Licensed under the MIT License.
#include "runtime/cuda/cuda_module.h"
#include "codegen_cuda.h" #include "codegen_cuda.h"
#include "runtime/cuda/cuda_module.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
static std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(const IRModule& mod) { static std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule &mod) {
std::unordered_map<std::string, runtime::FunctionInfo> fmap; std::unordered_map<std::string, runtime::FunctionInfo> fmap;
for (auto kv : mod->functions) { for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<tir::PrimFuncNode>()) << "Can only lower IR Module with PrimFuncs"; ICHECK(kv.second->IsInstance<tir::PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs";
auto f = Downcast<tir::PrimFunc>(kv.second); auto f = Downcast<tir::PrimFunc>(kv.second);
runtime::FunctionInfo info; runtime::FunctionInfo info;
...@@ -26,7 +28,7 @@ static std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co ...@@ -26,7 +28,7 @@ static std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
info.arg_types.push_back(f->params[i].dtype()); info.arg_types.push_back(f->params[i].dtype());
} }
if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) { if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
for (const auto& tag : opt.value()) { for (const auto &tag : opt.value()) {
info.launch_param_tags.push_back(tag); info.launch_param_tags.push_back(tag);
} }
} }
...@@ -43,7 +45,8 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) { ...@@ -43,7 +45,8 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
cg.Init(output_ssa); cg.Init(output_ssa);
for (auto kv : mod->functions) { for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenTileLangCUDA: Can only take PrimFunc"; ICHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenTileLangCUDA: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second); auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv); auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
...@@ -51,14 +54,15 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) { ...@@ -51,14 +54,15 @@ runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
} }
std::string code = cg.Finish(); std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) { if (const auto *f = Registry::Get("tvm_callback_cuda_postproc")) {
code = (*f)(code, target).operator std::string(); code = (*f)(code, target).operator std::string();
} }
std::string fmt = "ptx"; std::string fmt = "ptx";
std::string ptx; std::string ptx;
if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) { if (const auto *f = Registry::Get("tvm_callback_cuda_compile")) {
ptx = (*f)(code, target).operator std::string(); ptx = (*f)(code, target).operator std::string();
if (ptx[0] != '/') fmt = "cubin"; if (ptx[0] != '/')
fmt = "cubin";
} else { } else {
ICHECK(0); ICHECK(0);
} }
...@@ -72,7 +76,8 @@ String BuildTLDebug(IRModule mod, Target target) { ...@@ -72,7 +76,8 @@ String BuildTLDebug(IRModule mod, Target target) {
cg.Init(output_ssa); cg.Init(output_ssa);
for (auto kv : mod->functions) { for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenTileLangCUDA: Can only take PrimFunc"; ICHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenTileLangCUDA: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second); auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv); auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
...@@ -80,14 +85,16 @@ String BuildTLDebug(IRModule mod, Target target) { ...@@ -80,14 +85,16 @@ String BuildTLDebug(IRModule mod, Target target) {
} }
std::string code = cg.Finish(); std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) { if (const auto *f = Registry::Get("tvm_callback_cuda_postproc")) {
code = (*f)(code, target).operator std::string(); code = (*f)(code, target).operator std::string();
} }
return String(code); return String(code);
} }
// Register the two builders defined in this file as global functions:
// BuildTileLangCUDA under "target.build.tilelang_cuda" and BuildTLDebug under
// "target.build.tl_debug_codegen".
TVM_REGISTER_GLOBAL("target.build.tilelang_cuda")
    .set_body_typed(BuildTileLangCUDA);
TVM_REGISTER_GLOBAL("target.build.tl_debug_codegen")
    .set_body_typed(BuildTLDebug);
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
// Copyright (c) Microsoft Corporation. // Copyright (c) Microsoft Corporation.
// Licensed under the MIT License. // Licensed under the MIT License.
#if defined(__linux__) #if defined(__linux__)
#include <sys/stat.h> #include <sys/stat.h>
#endif #endif
...@@ -8,28 +8,28 @@ ...@@ -8,28 +8,28 @@
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <hip/hiprtc.h> #include <hip/hiprtc.h>
#include "runtime/rocm/rocm_module.h"
#include "codegen_hip.h" #include "codegen_hip.h"
#include "runtime/rocm/rocm_module.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
// Evaluates the HIPRTC expression `x` once. When the result is anything other
// than HIPRTC_SUCCESS, reports through LOG(FATAL) with the stringified call
// and the text from hiprtcGetErrorString.
#define HIPRTC_CALL(x)                                                         \
  {                                                                            \
    const hiprtcResult result = x;                                             \
    if (result != HIPRTC_SUCCESS) {                                            \
      LOG(FATAL) << "HiprtcError: " #x " failed with error: "                  \
                 << hiprtcGetErrorString(result);                              \
    }                                                                          \
  }
static std::string FindHIPIncludePath() { static std::string FindHIPIncludePath() {
...@@ -39,7 +39,7 @@ static std::string FindHIPIncludePath() { ...@@ -39,7 +39,7 @@ static std::string FindHIPIncludePath() {
const std::string delimiter = "/"; const std::string delimiter = "/";
#endif #endif
std::string hip_include_path; std::string hip_include_path;
const char* hip_path_env = std::getenv("HIP_PATH"); const char *hip_path_env = std::getenv("HIP_PATH");
if (hip_path_env != nullptr) { if (hip_path_env != nullptr) {
hip_include_path += hip_path_env; hip_include_path += hip_path_env;
hip_include_path += delimiter + "include"; hip_include_path += delimiter + "include";
...@@ -58,19 +58,24 @@ static std::string FindHIPIncludePath() { ...@@ -58,19 +58,24 @@ static std::string FindHIPIncludePath() {
} }
#endif #endif
LOG(FATAL) << "Cannot find HIP include path." LOG(FATAL) << "Cannot find HIP include path."
<< "HIP_PATH is not set or ROCm is not installed in the default installation path." << "HIP_PATH is not set or ROCm is not installed in the default "
"installation path."
<< "In other than linux, it is necessary to set HIP_PATH."; << "In other than linux, it is necessary to set HIP_PATH.";
return hip_include_path; return hip_include_path;
} }
static std::string HIPRTCCompile(const std::string& code, bool include_path = false) { static std::string HIPRTCCompile(const std::string &code,
bool include_path = false) {
std::vector<std::string> compile_params; std::vector<std::string> compile_params;
std::vector<const char*> param_cstrings{}; std::vector<const char *> param_cstrings{};
hiprtcProgram prog; hiprtcProgram prog;
std::string cc = "gfx900"; // Default target architecture (can be changed as needed) std::string cc =
"gfx900"; // Default target architecture (can be changed as needed)
int major, minor; int major, minor;
hipError_t e1 = hipDeviceGetAttribute(&major, hipDeviceAttributeComputeCapabilityMajor, 0); hipError_t e1 = hipDeviceGetAttribute(
hipError_t e2 = hipDeviceGetAttribute(&minor, hipDeviceAttributeComputeCapabilityMinor, 0); &major, hipDeviceAttributeComputeCapabilityMajor, 0);
hipError_t e2 = hipDeviceGetAttribute(
&minor, hipDeviceAttributeComputeCapabilityMinor, 0);
if (e1 == hipSuccess && e2 == hipSuccess) { if (e1 == hipSuccess && e2 == hipSuccess) {
cc = "gfx" + std::to_string(major * 100 + minor * 10); cc = "gfx" + std::to_string(major * 100 + minor * 10);
...@@ -86,10 +91,11 @@ static std::string HIPRTCCompile(const std::string& code, bool include_path = fa ...@@ -86,10 +91,11 @@ static std::string HIPRTCCompile(const std::string& code, bool include_path = fa
compile_params.push_back(include_option); compile_params.push_back(include_option);
} }
for (const auto& string : compile_params) { for (const auto &string : compile_params) {
param_cstrings.push_back(string.c_str()); param_cstrings.push_back(string.c_str());
} }
HIPRTC_CALL(hiprtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr)); HIPRTC_CALL(
hiprtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr));
hiprtcResult compile_res = hiprtcResult compile_res =
hiprtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data()); hiprtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data());
...@@ -110,11 +116,13 @@ static std::string HIPRTCCompile(const std::string& code, bool include_path = fa ...@@ -110,11 +116,13 @@ static std::string HIPRTCCompile(const std::string& code, bool include_path = fa
return code_out; return code_out;
} }
static std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(const IRModule& mod) { static std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule &mod) {
std::unordered_map<std::string, runtime::FunctionInfo> fmap; std::unordered_map<std::string, runtime::FunctionInfo> fmap;
for (auto kv : mod->functions) { for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<tir::PrimFuncNode>()) << "Can only lower IR Module with PrimFuncs"; ICHECK(kv.second->IsInstance<tir::PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs";
auto f = Downcast<tir::PrimFunc>(kv.second); auto f = Downcast<tir::PrimFunc>(kv.second);
runtime::FunctionInfo info; runtime::FunctionInfo info;
...@@ -129,7 +137,7 @@ static std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co ...@@ -129,7 +137,7 @@ static std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
info.arg_types.push_back(f->params[i].dtype()); info.arg_types.push_back(f->params[i].dtype());
} }
if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) { if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
for (const auto& tag : opt.value()) { for (const auto &tag : opt.value()) {
info.launch_param_tags.push_back(tag); info.launch_param_tags.push_back(tag);
} }
} }
...@@ -146,7 +154,8 @@ runtime::Module BuildTileLangHIP(IRModule mod, Target target) { ...@@ -146,7 +154,8 @@ runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
cg.Init(output_ssa); cg.Init(output_ssa);
for (auto kv : mod->functions) { for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenTileLangHIP: Can only take PrimFunc"; ICHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenTileLangHIP: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second); auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv); auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
...@@ -154,21 +163,23 @@ runtime::Module BuildTileLangHIP(IRModule mod, Target target) { ...@@ -154,21 +163,23 @@ runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
} }
std::string code = cg.Finish(); std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_hip_postproc")) { if (const auto *f = Registry::Get("tvm_callback_hip_postproc")) {
code = (*f)(code, target).operator std::string(); code = (*f)(code, target).operator std::string();
} }
std::string fmt = "ptx"; std::string fmt = "ptx";
std::string ptx; std::string ptx;
if (const auto* f = Registry::Get("tvm_callback_hip_compile")) { if (const auto *f = Registry::Get("tvm_callback_hip_compile")) {
ptx = (*f)(code, target).operator std::string(); ptx = (*f)(code, target).operator std::string();
if (ptx[0] != '/') fmt = "hsaco"; if (ptx[0] != '/')
fmt = "hsaco";
} else { } else {
ptx = HIPRTCCompile(code, false); ptx = HIPRTCCompile(code, false);
} }
return ROCMModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code, std::string()); return ROCMModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code, std::string());
} }
// Register BuildTileLangHIP as the global function
// "target.build.tilelang_hip".
TVM_REGISTER_GLOBAL("target.build.tilelang_hip")
    .set_body_typed(BuildTileLangHIP);
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -11,13 +11,17 @@ ...@@ -11,13 +11,17 @@
namespace tvm { namespace tvm {
namespace tl { namespace tl {
// True when the target's device type is kDLCUDA.
bool TargetIsCuda(Target target) {
  const auto device_type = target->GetTargetDeviceType();
  return device_type == kDLCUDA;
}

// True when the target's device type is kDLROCM.
bool TargetIsRocm(Target target) {
  const auto device_type = target->GetTargetDeviceType();
  return device_type == kDLROCM;
}
int GetArchInt(Target target) { int GetArchInt(Target target) {
auto s = target->GetAttr<String>("arch"); auto s = target->GetAttr<String>("arch");
ICHECK(s.defined()); ICHECK(s.defined());
const char* arch_str = s.value().c_str(); const char *arch_str = s.value().c_str();
ICHECK_EQ(arch_str[0], 's'); ICHECK_EQ(arch_str[0], 's');
ICHECK_EQ(arch_str[1], 'm'); ICHECK_EQ(arch_str[1], 'm');
ICHECK_EQ(arch_str[2], '_'); ICHECK_EQ(arch_str[2], '_');
...@@ -25,31 +29,36 @@ int GetArchInt(Target target) { ...@@ -25,31 +29,36 @@ int GetArchInt(Target target) {
} }
// True for CUDA targets whose GetArchInt value lies in [70, 75).
// Non-CUDA targets yield false without querying the arch.
bool TargetIsVolta(Target target) {
  if (TargetIsCuda(target)) {
    const int sm = GetArchInt(target);
    return 70 <= sm && sm < 75;
  }
  return false;
}
// True for CUDA targets whose GetArchInt value lies in [75, 80).
// Non-CUDA targets yield false without querying the arch.
bool TargetIsTuring(Target target) {
  if (TargetIsCuda(target)) {
    const int sm = GetArchInt(target);
    return 75 <= sm && sm < 80;
  }
  return false;
}
// True for CUDA targets whose GetArchInt value lies in [80, 90).
// Non-CUDA targets yield false without querying the arch.
bool TargetIsAmpere(Target target) {
  if (TargetIsCuda(target)) {
    const int sm = GetArchInt(target);
    return 80 <= sm && sm < 90;
  }
  return false;
}
// True for CUDA targets whose GetArchInt value is at least 90. There is no
// upper bound, so every later architecture number also matches.
// Short-circuit evaluation keeps GetArchInt from running on non-CUDA targets.
bool TargetIsHopper(Target target) {
  return TargetIsCuda(target) && GetArchInt(target) >= 90;
}
bool TargetIsCDNA(Target target) { bool TargetIsCDNA(Target target) {
if (!TargetIsRocm(target)) return false; if (!TargetIsRocm(target))
return false;
if (target->attrs.count("mcpu")) { if (target->attrs.count("mcpu")) {
std::string mcpu = Downcast<String>(target->attrs.at("mcpu")); std::string mcpu = Downcast<String>(target->attrs.at("mcpu"));
// if mcpu start with "gfx9", it is CDNA // if mcpu start with "gfx9", it is CDNA
...@@ -78,16 +87,18 @@ bool TargetHasAsyncCopy(Target target) { ...@@ -78,16 +87,18 @@ bool TargetHasAsyncCopy(Target target) {
return false; return false;
} }
// True for CUDA targets whose GetArchInt value is at least 75.
// Short-circuit evaluation keeps GetArchInt from running on non-CUDA targets.
bool TargetHasLdmatrix(Target target) {
  return TargetIsCuda(target) && GetArchInt(target) >= 75;
}
// True for CUDA targets whose GetArchInt value is at least 90.
// Short-circuit evaluation keeps GetArchInt from running on non-CUDA targets.
bool TargetHasStmatrix(Target target) {
  return TargetIsCuda(target) && GetArchInt(target) >= 90;
}
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -23,12 +23,12 @@ bool TargetIsTuring(Target target); ...@@ -23,12 +23,12 @@ bool TargetIsTuring(Target target);
bool TargetIsAmpere(Target target); bool TargetIsAmpere(Target target);
bool TargetIsHopper(Target target); bool TargetIsHopper(Target target);
bool TargetIsCDNA(Target target); bool TargetIsCDNA(Target target);
bool TargetHasAsyncCopy(Target target); bool TargetHasAsyncCopy(Target target);
bool TargetHasLdmatrix(Target target); bool TargetHasLdmatrix(Target target);
bool TargetHasStmatrix(Target target); bool TargetHasStmatrix(Target target);
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
#endif // TVM_TL_TARGET_UTILS_H_ #endif // TVM_TL_TARGET_UTILS_H_
...@@ -25,56 +25,57 @@ using cutlass::tfloat32_t; ...@@ -25,56 +25,57 @@ using cutlass::tfloat32_t;
// Pack two half values into one unsigned word. The unsigned short read from
// x's storage fills bits 0-15 and the one read from y's storage fills bits
// 16-31.
TL_DEVICE unsigned __pack_half2(const half x, const half y) {
  const unsigned lo = *reinterpret_cast<const unsigned short *>(&x);
  const unsigned hi = *reinterpret_cast<const unsigned short *>(&y);
  return lo | (hi << 16);
}
// Pack two half_t values into one unsigned word. The unsigned short read from
// x's storage fills bits 0-15 and the one read from y's storage fills bits
// 16-31.
TL_DEVICE unsigned __pack_half2(const half_t x, const half_t y) {
  const unsigned lo = *reinterpret_cast<const unsigned short *>(&x);
  const unsigned hi = *reinterpret_cast<const unsigned short *>(&y);
  return lo | (hi << 16);
}
// Pack two bfloat16_t values into one unsigned word. The unsigned short read
// from x's storage fills bits 0-15 and the one read from y's storage fills
// bits 16-31.
TL_DEVICE unsigned __pack_half2(const bfloat16_t x, const bfloat16_t y) {
  const unsigned lo = *reinterpret_cast<const unsigned short *>(&x);
  const unsigned hi = *reinterpret_cast<const unsigned short *>(&y);
  return lo | (hi << 16);
}
/// Convert a generic pointer to the value returned by
/// __cvta_generic_to_shared, narrowed to 32 bits.
TL_DEVICE uint32_t smem_ptr_to_uint(void const *const ptr) {
  const auto shared_addr = __cvta_generic_to_shared(ptr);
  return static_cast<uint32_t>(shared_addr);
}
// AtomicAdd for FP16, addend passed by value. Reinterprets the half_t pointer
// as half*, converts the addend to half, and forwards to the atomicAdd
// overload that accepts (half *, half).
TL_DEVICE void atomicAdd(half_t *address, half_t val) {
  half *const target = reinterpret_cast<half *>(address);
  const half addend = static_cast<half>(val);
  atomicAdd(target, addend);
}
// AtomicAdd for FP16, addend passed by pointer. Loads *val, converts it to
// half, and forwards to the atomicAdd overload that accepts (half *, half).
TL_DEVICE void atomicAdd(half_t *address, half_t *val) {
  half *const target = reinterpret_cast<half *>(address);
  const half addend = static_cast<half>(*val);
  atomicAdd(target, addend);
}
// Paired FP16 atomic add. Reinterprets both pointers as half2*, loads the
// half2 addend from val, and forwards to the atomicAdd overload that accepts
// (half2 *, half2).
// NOTE(review): presumably both pointers must satisfy half2 alignment —
// confirm against callers.
TL_DEVICE void atomicAddx2(half_t *address, half_t *val) {
  half2 *const target = reinterpret_cast<half2 *>(address);
  const half2 addend = *reinterpret_cast<half2 *>(val);
  atomicAdd(target, addend);
}
// AtomicAdd for FP16 with a float addend. Converts val with __float2half and
// forwards to the atomicAdd overload that accepts (half *, half).
TL_DEVICE void atomicAdd(half_t *address, float val) {
  half *const target = reinterpret_cast<half *>(address);
  const half addend = __float2half(val);
  atomicAdd(target, addend);
}
// DP4A: reads one int from each of a, b and c, passes the three values to
// __dp4a, and stores the result back through c.
// NOTE(review): assumes a and b each point at four packed 8-bit values and c
// at a 32-bit accumulator — confirm against callers.
template <typename InDatatype, typename OutDatatype>
TL_DEVICE void DP4A(InDatatype *a, InDatatype *b, OutDatatype *c) {
  const int lhs = *reinterpret_cast<const int *>(a);
  const int rhs = *reinterpret_cast<const int *>(b);
  const int acc = *reinterpret_cast<const int *>(c);
  *c = __dp4a(lhs, rhs, acc);
}
...@@ -10,10 +10,11 @@ ...@@ -10,10 +10,11 @@
namespace tl { namespace tl {
// Emits the PTX instruction cp.async.commit_group.
TL_DEVICE void cp_async_commit() {
  __asm__ __volatile__("cp.async.commit_group;\n" ::);
}
template <int N> template <int N> TL_DEVICE void cp_async_wait() {
TL_DEVICE void cp_async_wait() {
if constexpr (N == 0) { if constexpr (N == 0) {
asm volatile("cp.async.wait_all;\n" ::); asm volatile("cp.async.wait_all;\n" ::);
} else { } else {
...@@ -22,7 +23,7 @@ TL_DEVICE void cp_async_wait() { ...@@ -22,7 +23,7 @@ TL_DEVICE void cp_async_wait() {
} }
template <int N> template <int N>
TL_DEVICE void cp_async_gs(void const* const smem_addr, void* global_ptr) { TL_DEVICE void cp_async_gs(void const *const smem_addr, void *global_ptr) {
static_assert(N == 16 || N == 8 || N == 4); static_assert(N == 16 || N == 8 || N == 4);
unsigned int addr = smem_ptr_to_uint(smem_addr); unsigned int addr = smem_ptr_to_uint(smem_addr);
if constexpr (N == 16) { if constexpr (N == 16) {
...@@ -33,7 +34,7 @@ TL_DEVICE void cp_async_gs(void const* const smem_addr, void* global_ptr) { ...@@ -33,7 +34,7 @@ TL_DEVICE void cp_async_gs(void const* const smem_addr, void* global_ptr) {
"cp.async.cg.shared.global [%0], [%1], %2;" "cp.async.cg.shared.global [%0], [%1], %2;"
#endif #endif
::"r"(addr), ::"r"(addr),
"l"((void*)(global_ptr)), "n"(N)); "l"((void *)(global_ptr)), "n"(N));
} else { } else {
__asm__ __volatile__( __asm__ __volatile__(
#if TL_ENABLE_L2_PREFETCH #if TL_ENABLE_L2_PREFETCH
...@@ -42,12 +43,13 @@ TL_DEVICE void cp_async_gs(void const* const smem_addr, void* global_ptr) { ...@@ -42,12 +43,13 @@ TL_DEVICE void cp_async_gs(void const* const smem_addr, void* global_ptr) {
"cp.async.ca.shared.global [%0], [%1], %2;" "cp.async.ca.shared.global [%0], [%1], %2;"
#endif #endif
::"r"(addr), ::"r"(addr),
"l"((void*)(global_ptr)), "n"(N)); "l"((void *)(global_ptr)), "n"(N));
} }
} }
template <int N> template <int N>
TL_DEVICE void cp_async_gs_conditional(void const* const smem_addr, void* global_ptr, bool cond) { TL_DEVICE void cp_async_gs_conditional(void const *const smem_addr,
void *global_ptr, bool cond) {
static_assert(N == 16 || N == 8 || N == 4); static_assert(N == 16 || N == 8 || N == 4);
int bytes = cond ? N : 0; int bytes = cond ? N : 0;
unsigned int addr = smem_ptr_to_uint(smem_addr); unsigned int addr = smem_ptr_to_uint(smem_addr);
...@@ -59,7 +61,7 @@ TL_DEVICE void cp_async_gs_conditional(void const* const smem_addr, void* global ...@@ -59,7 +61,7 @@ TL_DEVICE void cp_async_gs_conditional(void const* const smem_addr, void* global
"cp.async.cg.shared.global [%0], [%1], %2, %3;" "cp.async.cg.shared.global [%0], [%1], %2, %3;"
#endif #endif
::"r"(addr), ::"r"(addr),
"l"((void*)(global_ptr)), "n"(N), "r"(bytes)); "l"((void *)(global_ptr)), "n"(N), "r"(bytes));
} else { } else {
__asm__ __volatile__( __asm__ __volatile__(
#if TL_ENABLE_L2_PREFETCH #if TL_ENABLE_L2_PREFETCH
...@@ -68,8 +70,8 @@ TL_DEVICE void cp_async_gs_conditional(void const* const smem_addr, void* global ...@@ -68,8 +70,8 @@ TL_DEVICE void cp_async_gs_conditional(void const* const smem_addr, void* global
"cp.async.ca.shared.global [%0], [%1], %2, %3;" "cp.async.ca.shared.global [%0], [%1], %2, %3;"
#endif #endif
::"r"(addr), ::"r"(addr),
"l"((void*)(global_ptr)), "n"(N), "r"(bytes)); "l"((void *)(global_ptr)), "n"(N), "r"(bytes));
} }
} }
} // namespace tl } // namespace tl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment