Commit 57ab687c authored by Lei Wang, committed by GitHub

[Initialization] Migration of Codebase from Dev Branch into Main (#10)



* Add format.sh script for code formatting and linting

* docs update

* center align the title

* lint fix

* add ignore

* Add .gitignore for 3rdparty directory

* Add requirements-dev.txt, requirements-test.txt, and requirements.txt

* 3rdparty

* Add gemm.h, CMakeLists.txt, _ffi_api.py, __init__.py, runtime.h, reduce.h, loop_partition.h, utils.h, and loop_vectorize.h

* Refactor CMakeLists.txt and include statements

- Update CMakeLists.txt to use a newer version of CMake and add project name
- Remove unnecessary include directories

Fix include paths in layout.cc, codegen.cc, codegen.h, rt_mod.cc, frontend_legalize.cc, inject_pipeline.cc, layout_inference.cc, loop_vectorize.cc, and lower_tile_op.cc

- Update include paths to use relative paths instead of absolute paths

* Update submodule for 3rdparty/tvm

* update

* load dll first

* Refactor CMakeLists.txt and include statements

* Refactor CMakeLists.txt and include statements

* git keep update

* Refactor CMakeLists.txt and include statements

* Refactor CMakeLists.txt and include statements

* refactor code structure

* Update Readme

* CMakeLists Customized

* update readme

* update README

* update readme

* update usage

* use TVM_IMPORT_PYTHON_PATH to handle Python imports for a custom TVM build

* annotate lower transform global func with `transform` prefix

* Migrate Simplify Pass from tilelang tvm branch

* enhance system environment handling with __init__ and CMake

* Initial commit

* CODE_OF_CONDUCT.md committed

* LICENSE committed

* README.md committed

* SECURITY.md committed

* SUPPORT.md committed

* CODE_OF_CONDUCT Commit

* LICENSE Commit

* SECURITY Commit

* SUPPORT Commit

* Modify Support

* Update README.md

* security ci update

* remove examples

* Update and implement clang-format

* add composable kernel components

* Migrate from latest update

* submodule update

* Test update

* Update License

* Spell check

* lint fix

* add clang-tidy to apply static analysis for c source

* update tilelang examples

* Update Install Docs

* Refactor filetree

* Enhance Install

* conflict resolved

* annotate_version

* Initial Update

* test fix

* install

* Implement setup.py

* lint fix

* Separate Init

* Separate test

* docker file commit

* add logo

* Update Readme and Examples

* update readme

* update logo

* Implement AMD Installation

* Add License

* Update AMD MI300x Benchmark

* update README

* update mi300 benchmark scripts

* update ignore

* enhance build script

* update image

* enhance setup.py to remove duplicated libraries

* remove debug files

* update readme

* update image

* update gemm examples

* update flashattention README

* readme update

* add cmake into requirements

* libinfo fix

* auto update submodule

* lint fix

* Fix AMD Build and Test

* Update check for transpose attribute for CDNA Arch

* typo fix for amd

* Implement Matmul Benchmark

* Refactor Code

* [TypoFix] Fix GEMM Example

* [Docs] Init Linear Attention README

* [TYPO] Typo fix

* [Lint] Lint Fix

* enhance example with intrinsics

* [Enhancement] Improve Buffer Collection during IR Parser

* [Dev] Introduce Current classmethod to get current frame

* submodule update

* fake test pass update

* support thread_extent_api

* code optimize

* Add GEMM function implementation for matrix multiplication

* Update logging format to reflect TileLang in logger messages

* Refactor CMakeLists.txt for improved readability and set default build type to Release

* Support Gemm SS Primitives Implementation

* [README] Upload Tile Language Logo (#5)

* update logo

* Update README.md to enhance formatting and center the title

---------
Co-authored-by: microsoft-github-operations[bot] <55726097+microsoft-github-operations[bot]@users.noreply.github.com>
Co-authored-by: Microsoft Open Source <microsoftopensource@users.noreply.github.com>
Co-authored-by: Yu Cheng <yu.cheng@pku.edu.cn>
parent 64f17c2f
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file layout/layout.h
* \brief Layout and Fragment definitions.
*
*/
#ifndef TVM_TL_LAYOUT_LAYOUT_H_
#define TVM_TL_LAYOUT_LAYOUT_H_
#include <tvm/arith/analyzer.h>
namespace tvm {
namespace tl {
using namespace tir;
class Layout;
class Fragment;
class LayoutNode : public Object {
public:
LayoutNode() = default;
LayoutNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index);
size_t InputDim() const { return input_size_.size(); }
size_t OutputDim() const { return forward_index_.size(); }
Array<PrimExpr> InputShape() const { return input_size_; }
Array<PrimExpr> OutputShape() const;
Array<PrimExpr> GetForwardIndex() const { return forward_index_; }
virtual Array<PrimExpr> Forward(const Array<PrimExpr>& vars) const;
virtual Layout Inverse() const;
virtual void DebugOutput() const;
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "tl.Layout";
bool SEqualReduce(const LayoutNode* other, SEqualReducer equal) const;
void VisitAttrs(tvm::AttrVisitor* v);
TVM_DECLARE_BASE_OBJECT_INFO(LayoutNode, Object);
protected:
virtual Map<Var, Range> getVarMap() const;
void UpdateAnalyzer(arith::Analyzer* analyzer) const;
Array<PrimExpr> forward_index_;
Array<PrimExpr> input_size_;
};
/*!
* \brief Layout reference class.
*/
class Layout : public ObjectRef {
public:
TVM_DLL Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index);
TVM_DLL Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index);
TVM_DEFINE_OBJECT_REF_METHODS(Layout, ObjectRef, LayoutNode);
};
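// Illustrative usage note (added for documentation; the values are assumptions, not
// taken from the original sources): with the constructors above, a plain row-major
// layout over an 8x8 tile can be expressed as
//   Layout({8, 8}, {InputPlaceholder(0) * 8 + InputPlaceholder(1)});
// i.e. the forward index maps a logical coordinate (i, j) to the flat offset i * 8 + j.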
class FragmentNode : public LayoutNode {
public:
FragmentNode() = default;
FragmentNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index, PrimExpr forward_thread,
PrimExpr replicate_size);
PrimExpr GetForwardThread() const { return forward_thread_; }
Layout Inverse() const final;
PrimExpr ThreadExtent() const;
PrimExpr ReplicateExtent() const { return replicate_size_; }
PrimExpr ForwardThread(const Array<PrimExpr>& vars, const Optional<PrimExpr>& rep_var) const;
Fragment Repeat(const Array<PrimExpr>& repeats, bool repeat_on_thread,
bool lower_dim_first = true) const;
Fragment Replicate(int repeats) const;
Fragment DeReplicate() const;
Fragment CondenseReplicateVar() const;
void DebugOutput() const final;
void VisitAttrs(tvm::AttrVisitor* v);
bool SEqualReduce(const FragmentNode* other, SEqualReducer equal) const;
static constexpr const char* _type_key = "tl.Fragment";
TVM_DECLARE_FINAL_OBJECT_INFO(FragmentNode, LayoutNode);
protected:
Map<Var, Range> getVarMap() const final;
PrimExpr forward_thread_;
PrimExpr replicate_size_;
};
/*!
* \brief Fragment reference class.
*/
class Fragment : public Layout {
public:
TVM_DLL Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
PrimExpr forward_thread, IterVar thread_replicate);
TVM_DLL Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
PrimExpr forward_thread, PrimExpr replicate_size, Optional<Var> replicate_var);
TVM_DEFINE_OBJECT_REF_METHODS(Fragment, Layout, FragmentNode);
};
Var InputPlaceholder(size_t idx);
Var ReplicationPlaceholder();
Fragment makeGemmFragment8x8();
Fragment makeGemmFragment8x8Transposed();
Fragment makeGemmFragmentC(const int block_m, const int block_n, const int warp_m, const int warp_n,
const int element_size);
Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n, const int warp_m, const int warp_n,
const int element_size);
Fragment makeGemmFragmentCHopper(const int block_m, const int block_n, const int warp_m,
const int warp_n, const int element_size);
Fragment makeGemmFragmentA(const int block_m, const int block_n, const int block_k,
const int warp_m, const int warp_n, const int element_size);
Fragment makeGemmFragmentB(const int block_m, const int block_n, const int block_k,
const int warp_m, const int warp_n);
Fragment makeGemmFragmentACDNA(const int block_m, const int block_n, const int block_k,
const int warp_m, const int warp_n, bool transposed = false);
// Default Memory Layout
Layout makeGemmLayoutLinear(int stride, int continuous);
Layout makeGemmABLayoutPadded(int stride, int continuous, int element_size);
Layout makeGemmABLayout(int stride, int continuous, int element_size, int kfactor);
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size, int kfactor);
Fragment makeGemmVoltaFragmentC(const int block_m, const int block_n, const int warp_m,
const int warp_n, const int element_size);
Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n, const int block_k,
const int warp_m, const int warp_n);
Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a, int kfactor);
Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size);
Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size);
namespace attr {
// BlockAttr, Containing the layout for all the buffers in the block
constexpr const char* kLayoutMap = "layout_map";
} // namespace attr
} // namespace tl
} // namespace tvm
#endif // TVM_TL_LAYOUT_LAYOUT_H_
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file layout/swizzle.cc
* \brief Define swizzled layout
*
*/
#include "swizzle.h"
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <cmath>
namespace tvm {
namespace tl {
SwizzlePattern::SwizzlePattern(int bits, int base, int shift)
: bits_(bits), base_(base), shift_(shift) {
ICHECK(bits >= 0);
ICHECK(base >= 0);
ICHECK(shift >= 0);
ICHECK(shift >= bits);
}
PrimExpr SwizzlePattern::swizzle(PrimExpr expr) const {
int base = (1 << base_);
int mask = ((1 << bits_) - 1) << shift_;
PrimExpr high = FloorDiv(expr, base);
PrimExpr low = FloorMod(expr, base);
high = bitwise_xor(high, right_shift(bitwise_and(high, mask), shift_));
return low + high * base;
}
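// Worked example (illustrative values, not from the original source): with bits_ = 3,
// base_ = 4, shift_ = 3 the code above computes base = 16 and mask = 0b111000.
// For expr = 200: high = 12, low = 8; (high & mask) >> shift_ == 1, so the xor turns
// high into 13 and the swizzled offset becomes 8 + 13 * 16 = 216.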
bool SwizzlePattern::operator==(const SwizzlePattern& other) const {
return std::tie(base_, bits_, shift_) == std::tie(other.base_, other.bits_, other.shift_);
}
SwizzledLayoutNode::SwizzledLayoutNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
SwizzlePattern pattern)
: pattern_(pattern) {
input_size_ = input_size;
arith::Analyzer analyzer;
UpdateAnalyzer(&analyzer);
forward_index_ = forward_index.Map([&](const PrimExpr& e) { return analyzer.Simplify(e); });
}
Array<PrimExpr> SwizzledLayoutNode::Forward(const Array<PrimExpr>& vars) const {
auto expr_list = LayoutNode::Forward(vars);
auto expr = expr_list.back();
expr_list.pop_back();
expr_list.push_back(pattern_.swizzle(expr));
return expr_list;
}
void SwizzledLayoutNode::DebugOutput() const {
LayoutNode::DebugOutput();
std::cout << "Layout Swizzle: " << pattern_.Base() << " " << pattern_.Bits() << " "
<< pattern_.Shift();
}
Layout SwizzledLayoutNode::Inverse() const {
ICHECK(0) << "Not Implemented.";
return {};
}
SwizzledLayout::SwizzledLayout(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
SwizzlePattern pattern) {
Map<Var, PrimExpr> vmap;
Array<PrimExpr> input_size;
for (size_t i = 0; i < forward_var.size(); i++) {
vmap.Set(forward_var[i]->var, InputPlaceholder(i));
CHECK(is_zero(forward_var[i]->dom->min));
input_size.push_back(forward_var[i]->dom->extent);
}
forward_index = forward_index.Map([&](const PrimExpr& e) { return Substitute(e, vmap); });
auto n = make_object<SwizzledLayoutNode>(input_size, forward_index, pattern);
data_ = std::move(n);
}
SwizzledLayout::SwizzledLayout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
SwizzlePattern pattern) {
auto n = make_object<SwizzledLayoutNode>(input_size, forward_index, pattern);
data_ = std::move(n);
}
void SwizzledLayoutNode::VisitAttrs(tvm::AttrVisitor* v) { LayoutNode::VisitAttrs(v); }
bool SwizzledLayoutNode::SEqualReduce(const SwizzledLayoutNode* other, SEqualReducer equal) const {
return equal(this->InputShape(), other->InputShape()) &&
equal(this->forward_index_, other->forward_index_) && pattern_ == other->pattern_;
}
TVM_REGISTER_NODE_TYPE(SwizzledLayoutNode);
} // namespace tl
} // namespace tvm
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file swizzle.h
* \brief Define swizzled layout
*
*/
#ifndef TVM_TL_LAYOUT_SWIZZLE_H_
#define TVM_TL_LAYOUT_SWIZZLE_H_
#include "layout.h"
namespace tvm {
namespace tl {
/*!
* \brief Swizzle pattern
*/
class SwizzlePattern {
public:
SwizzlePattern() = default;
SwizzlePattern(int bits, int base, int shift);
PrimExpr swizzle(PrimExpr expr) const;
int Bits() const { return bits_; }
int Base() const { return base_; }
int Shift() const { return shift_; }
bool operator==(const SwizzlePattern& other) const;
private:
int bits_;
int base_;
int shift_;
};
/*!
* \brief Layout with swizzle
*/
class SwizzledLayoutNode : public LayoutNode {
public:
SwizzledLayoutNode() = default;
SwizzledLayoutNode(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
SwizzlePattern pattern);
Array<PrimExpr> Forward(const Array<PrimExpr>& vars) const final;
Layout Inverse() const final;
void DebugOutput() const final;
static constexpr const char* _type_key = "tl.SwizzledLayout";
bool SEqualReduce(const SwizzledLayoutNode* other, SEqualReducer equal) const;
void VisitAttrs(tvm::AttrVisitor* v);
TVM_DECLARE_FINAL_OBJECT_INFO(SwizzledLayoutNode, LayoutNode);
private:
SwizzlePattern pattern_;
};
/*!
* \brief SwizzledLayout reference class.
*/
class SwizzledLayout : public Layout {
public:
TVM_DLL SwizzledLayout(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
SwizzlePattern pattern);
TVM_DLL SwizzledLayout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
SwizzlePattern pattern);
TVM_DEFINE_OBJECT_REF_METHODS(SwizzledLayout, Layout, SwizzledLayoutNode);
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_LAYOUT_SWIZZLE_H_
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file layout/utils.cc
* \brief Some arith tools for layout & fragment inference
*
*/
#include "utils.h"
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
namespace tvm {
namespace tl {
using namespace tir;
using namespace arith;
bool CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs) {
const auto* clhs = lhs.as<IntImmNode>();
const auto* crhs = rhs.as<IntImmNode>();
if (crhs && crhs->value == 0) {
return false;
} else if (clhs && crhs) {
return clhs->value % crhs->value == 0;
}
return false;
}
/*!
* \brief Collector that collects the outgoing split reference of each IterMark.
*
* These out-going splits can then be used to check if the iterators are independent.
*/
class IterMarkSplitCollector {
public:
// mark all IterMarks that are visited.
std::unordered_set<IterMark, ObjectPtrHash, ObjectPtrEqual> visited_;
// each iter mark to its outgoing splits that are referenced.
std::unordered_map<IterMark, std::vector<IterSplitExpr>, ObjectPtrHash, ObjectPtrEqual>
mark2splits_;
/*!
* \brief Collect all mark2splits recursively from indices.
* \param indices The iterator of interest.
*/
void Collect(const Array<IterSumExpr>& indices) {
for (IterSumExpr sum_expr : indices) {
for (IterSplitExpr split : sum_expr->args) {
this->CollectInternal(split->source);
mark2splits_[split->source].push_back(split);
}
}
}
void CollectInternal(const IterMark& mark) {
if (visited_.count(mark)) return;
visited_.insert(mark);
if (auto* op = mark->source.as<IterSumExprNode>()) {
for (IterSplitExpr split : op->args) {
this->CollectInternal(split->source);
mark2splits_[split->source].push_back(split);
}
}
}
};
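/*!
* \brief Return IterSplitExprs covering the parts of `mark` that `splits` leave unused.
*
* Illustrative trace (added for documentation, values assumed): for mark = x with
* extent 4 and splits = { x // 2 } (lower_factor 2, extent 2), no split starts at
* lower_factor 1, so the missing low part IterSplitExpr(x, lower_factor 1, extent 2,
* scale 1), i.e. x % 2, is emitted first; the existing split is then consumed and the
* accumulated factor 4 matches the full extent, so only x % 2 is returned.
*/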
Array<IterSplitExpr> get_unused_iters(const IterMark& mark,
const std::vector<IterSplitExpr>& splits,
Analyzer* analyzer) {
PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1);
std::vector<bool> used(splits.size(), false);
std::vector<IterSplitExpr> results;
size_t i = 0;
for (; i < splits.size();) {
size_t j = 0;
size_t lowest = splits.size();
for (; j < splits.size(); ++j) {
if (used[j]) continue;
if (!used[j] && analyzer->CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) {
break;
}
if (lowest == splits.size() ||
CanProveDivisible(splits[lowest]->lower_factor, splits[j]->lower_factor)) {
lowest = j;
}
}
if (j == splits.size()) {
ICHECK(lowest != splits.size());
ICHECK(CanProveDivisible(splits[lowest]->lower_factor, expected_lower_factor));
results.emplace_back(mark, expected_lower_factor,
FloorDiv(splits[lowest]->lower_factor, expected_lower_factor), 1);
expected_lower_factor = splits[lowest]->lower_factor;
} else {
used[j] = true;
i++;
expected_lower_factor = splits[j]->lower_factor * splits[j]->extent;
}
}
bool match_full_iter = analyzer->CanProveEqual(expected_lower_factor, mark->extent);
if (!match_full_iter) {
results.emplace_back(mark, expected_lower_factor, FloorDiv(mark->extent, expected_lower_factor),
1);
}
return results;
}
Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr>& exprs,
const Array<IterVar> input_iters, Analyzer* analyzer) {
auto iter_sum = exprs.Map(
[&](const auto& e) { return NormalizeToIterSum(e, ToVMap(input_iters), analyzer); });
IterMarkSplitCollector collector;
collector.Collect(iter_sum);
Array<IterSplitExpr> results;
for (const IterMark& mark : collector.visited_) {
ICHECK(mark->source.as<Var>()) << "Not a normalized iterator: " << mark;
}
for (const IterVar& iter : input_iters) {
IterMark iv_mark;
for (const IterMark& mark : collector.visited_) {
if (mark->source.as<Var>().same_as(iter->var)) {
iv_mark = mark;
break;
}
}
if (iv_mark.defined()) {
auto splits = get_unused_iters(iv_mark, collector.mark2splits_[iv_mark], analyzer);
// Put the small axis last
results.insert(results.end(), splits.rbegin(), splits.rend());
} else if (!is_one(iter->dom->extent)) {
auto mark = IterMark(iter->var, iter->dom->extent);
auto split = IterSplitExpr(mark, 1, iter->dom->extent, 1);
results.push_back(split);
}
}
return results;
}
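/*!
* \brief Fuse the given splits into a single flattened index expression.
*
* Sketch of the behaviour (comment added for documentation): the last split receives
* scale 1 and every earlier split is scaled by the product of the extents that follow
* it, so splits {s0 (extent 4), s1 (extent 2)} flatten to roughly s0 * 2 + s1.
*/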
PrimExpr MakeFlattenedExpression(const Array<arith::IterSplitExpr>& splits) {
Array<arith::IterSplitExpr> lists;
PrimExpr scale = 1;
for (int i = splits.size() - 1; i >= 0; i--) {
auto scaled_split =
arith::IterSplitExpr(splits[i]->source, splits[i]->lower_factor, splits[i]->extent, scale);
lists.push_back(scaled_split);
scale *= splits[i]->extent;
}
return arith::NormalizeIterMapToExpr(arith::IterSumExpr(lists, 0));
}
class IterSumMutator {
public:
IterSumMutator(const Map<IterSplitExpr, IterSplitExpr>& replace_map)
: replace_map_(replace_map) {}
// override the original mutate function.
IterSumExpr Mutate(const IterSumExpr& iter_sum) {
Array<IterSplitExpr> args;
for (const auto& split : iter_sum->args) {
if (replace_map_.count(split)) {
args.push_back(replace_map_[split]);
} else {
auto split_ =
IterSplitExpr(Mutate(split->source), split->lower_factor, split->extent, split->scale);
args.push_back(split_);
}
}
return IterSumExpr(args, iter_sum->base);
}
IterMark Mutate(const IterMark& mark) {
if (auto* op = mark->source.as<IterSumExprNode>()) {
return IterMark(Mutate(GetRef<IterSumExpr>(op)), mark->extent);
} else {
return mark;
}
}
private:
Map<IterSplitExpr, IterSplitExpr> replace_map_;
};
std::pair<PrimExpr, IterVar> CompressIterator(const PrimExpr& expr,
const Array<IterVar> input_iters, const Var& var,
arith::Analyzer* analyzer) {
auto iter_sum = arith::NormalizeToIterSum(expr, ToVMap(input_iters), analyzer);
IterMarkSplitCollector collector;
collector.Collect({iter_sum});
IterMark mark;
for (const IterMark& m : collector.visited_) {
ICHECK(m->source.as<Var>()) << "Not a normalized iterator: " << m;
if (m->source.as<Var>().value().same_as(var)) {
mark = m;
break;
}
}
std::vector<tvm::arith::IterSplitExpr> splits;
if (mark.defined()) {
splits = collector.mark2splits_[mark];
}
PrimExpr extent = 1;
for (const auto& split : splits) {
extent *= split->extent;
}
extent = analyzer->Simplify(extent);
auto new_var = Var(var->name_hint, var->type_annotation);
auto new_iter_var = IterVar(Range(0, extent), new_var, IterVarType::kDataPar);
auto new_mark = IterMark(new_var, extent);
PrimExpr scale = 1;
Map<IterSplitExpr, IterSplitExpr> replace_map;
for (const auto& split : splits) {
auto rescaled = arith::IterSplitExpr(new_mark, scale, split->extent, split->scale);
replace_map.Set(split, rescaled);
scale *= split->extent;
}
IterSumMutator mutator(replace_map);
PrimExpr replaced = analyzer->Simplify(NormalizeIterMapToExpr(mutator.Mutate(iter_sum)));
return {replaced, new_iter_var};
}
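// Illustrative note (added for documentation, values assumed): if `expr` only
// references x // 8 for x in Range(64), the splits touching x have a combined extent
// of 8, so CompressIterator returns `expr` rewritten over a fresh iteration variable
// of extent 8 together with that compressed IterVar.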
Array<IterVar> ToIterVars(const Map<Var, Range>& vmap) {
Array<IterVar> result;
for (const auto& [var, range] : vmap) {
result.push_back(IterVar(range, var, IterVarType::kDataPar));
}
return result;
}
Map<Var, Range> ToVMap(const Array<IterVar>& ivs) {
Map<Var, Range> result;
for (const auto& iv : ivs) {
result.Set(iv->var, iv->dom);
}
return result;
}
} // namespace tl
} // namespace tvm
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file layout/utils.h
* \brief Some arith tools for layout & fragment inference
*
*/
#ifndef TVM_TL_LAYOUT_UTILS_H_
#define TVM_TL_LAYOUT_UTILS_H_
#include <tvm/arith/iter_affine_map.h>
namespace tvm {
namespace tl {
using namespace tir;
/*!
* \brief Collect the IterSplits that are not used in the exprs.
*
* If an expr is (x // 2) and x is in Range(4),
* then the result should contain (x % 2).
*/
Array<arith::IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr>& exprs,
const Array<IterVar> input_iters,
arith::Analyzer* analyzer);
/*!
* \brief Compress the iterator var, removing the parts of the var that are not used in the expr.
*
* Returns the rewritten expression together with the compressed IterVar.
*/
std::pair<PrimExpr, IterVar> CompressIterator(const PrimExpr& expr,
const Array<IterVar> input_iters, const Var& var,
arith::Analyzer* analyzer);
/*!
* \brief Convert the iter splits returned by DivideUnusedIterators into a flattened expression
*
*/
PrimExpr MakeFlattenedExpression(const Array<arith::IterSplitExpr>& splits);
/*!
* \brief Convert an Array of IterVar to a Map object
*
*/
Map<Var, Range> ToVMap(const Array<IterVar>& ivs);
/*!
* \brief Convert a Map object to an Array of IterVar
*
*/
Array<IterVar> ToIterVars(const Map<Var, Range>& vmap);
} // namespace tl
} // namespace tvm
#endif // TVM_TL_LAYOUT_UTILS_H_
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file tl/op/builtin.cc
* \brief Builtin intrinsics.
*
*/
#include "builtin.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include "../target/utils.h"
#include "../target/cuda.h"
namespace tvm {
namespace tl {
#define TIR_DEFINE_TL_BUILTIN(OpName) \
const Op& OpName() { \
static const Op& op = Op::Get("tl." #OpName); \
return op; \
} \
TVM_REGISTER_OP("tl." #OpName).set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName)
TIR_DEFINE_TL_BUILTIN(CreateListofMBarrierOp)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(CreateTMADescriptorOp)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
TIR_DEFINE_TL_BUILTIN(CreateTMAIm2ColDescriptorOp)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
TIR_DEFINE_TL_BUILTIN(GetMBarrierOp)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
TIR_DEFINE_TL_BUILTIN(TMALoadOp).set_num_inputs(-1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(TMALoadIm2ColOp).set_num_inputs(-1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(TMAStoreOp)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(MBarrierWaitParity)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(MBarrierExpectTX)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(LDMatrixOp)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(STMatrixOp)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(SyncThreadsPartialOp)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(FenceProxyAsyncOp)
.set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(SetMaxNReg)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(WaitWgmma)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(PackB16Op).set_num_inputs(2).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kPure));
} // namespace tl
} // namespace tvm
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file tl/op/builtin.h
* \brief Builtin intrinsics.
*
*/
#ifndef TVM_TL_OP_BUILTIN_H_
#define TVM_TL_OP_BUILTIN_H_
#include "op.h"
namespace tvm {
namespace tl {
/*!
* \brief tvm intrinsics for TMADescriptor creation for tiled load
*
* CuTensorMap* CreateTMADescriptorOp(data_type, rank, global_addr, global_shape...,
* global_stride..., smem_box..., smem_stride..., interleave, swizzle, l2_promotion, oob_fill)
*
*/
const Op& CreateTMADescriptorOp();
/*!
* \brief tvm intrinsics for TMADescriptor creation for image to column load
*
* CuTensorMap* CreateTMAIm2ColDescriptorOp(data_type, rank, global_addr, global_shape...,
* global_stride..., elem_stride..., lower_corner..., upper_corner..., smem_box_pixel, smem_box_channel,
* interleave, swizzle, l2_promotion, oob_fill)
*
*/
const Op& CreateTMAIm2ColDescriptorOp();
/*!
* \brief Create a list of mbarriers with the given thread counts
*
* CreateListofMBarrierOp(num_threads0, num_threads1, ...)
*
*/
const Op& CreateListofMBarrierOp();
/*!
* \brief Get the mbarrier with barrier_id
*
* int64_t* GetMBarrier(barrier_id)
*
*/
const Op& GetMBarrierOp();
/*!
* \brief tvm intrinsics for loading data from global tensor descriptor to shared memory
*
* TMALoadOp(descriptor, mbarrier, smem_data, coord_0, coord_1, ...)
*
*/
const Op& TMALoadOp();
/*!
* \brief tvm intrinsics for loading image from global tensor to columns in shared memory
*
* TMALoadIm2ColOp(descriptor, mbarrier, smem_data, coord_0, coord_1, ..., image_offset, ...)
*
*/
const Op& TMALoadIm2ColOp();
/*!
* \brief tvm intrinsics for storing data from shared memory to global tensor descriptor
*
* TMAStoreOp(descriptor, smem_data, coord_0, coord_1, ...)
*
*/
const Op& TMAStoreOp();
/*!
* \brief tvm intrinsics for mbarrier wait with parity bit
*
* MBarrierWaitParity(mbarrier, parity)
*
*/
const Op& MBarrierWaitParity();
/*!
* \brief tvm intrinsics for mbarrier expect tx
*
* MBarrierExpectTX(mbarrier, transaction_bytes)
*
*/
const Op& MBarrierExpectTX();
/*!
* \brief tvm intrinsics for ldmatrix
*
* LDMatrixOp(transposed, num, shared_addr, local_addr)
*
*/
const Op& LDMatrixOp();
/*!
* \brief tvm intrinsics for stmatrix
*
* STMatrixOp(transposed, num, shared_addr, int32_values...)
*
*/
const Op& STMatrixOp();
/*!
* \brief Pack two b16 values into a b32 value
*
* int32 PackB16Op(b16_value, b16_value)
*
*/
const Op& PackB16Op();
/*!
* \brief Similar to __syncthreads(), but synchronizes only a subset of the threads
*
* SyncThreadsPartialOp(num_partial_threads or mbarrier)
*
*/
const Op& SyncThreadsPartialOp();
/*!
* \brief Issue a shared memory fence for async operations
*
* FenceProxyAsync()
*
*/
const Op& FenceProxyAsyncOp();
/*!
* \brief Set a register-count hint for warp-specialized branches
*
* SetMaxNReg(num_reg, is_inc)
*
*/
const Op& SetMaxNReg();
/*!
* \brief Wait for the previous wgmma operations to finish
*
* WaitWgmma(num_mma)
*
*/
const Op& WaitWgmma();
/*!
* \brief tvm intrinsic for amd matrix core mfma instructions.
*
* void tvm_mfma(StringImm shape, StringImm A_layout, StringImm B_layout,
* StringImm A_dtype, StringImm B_dtype, StringImm C_dtype,
* Var multiplicand_a, Expr a_index,
* Var multiplicand_b, Expr b_index,
* Var accumulator, Expr c_index);
*/
TVM_DLL const Op &tvm_mfma();
/*!
* \brief tvm intrinsic for storing the result of AMD MFMA into a destination
* pointer.
*
* There is no real instruction that does that, but we want to hide
* details of complex index manipulation behind this intrinsic to simplify TIR
* lowering passes (e.g. LowerWarpMemory) like cuda ptx backend does.
*
* void tvm_mfma_store(IntImm m, IntImm n, Var dst_ptr, Var src_ptr, Expr
* src_offset, Var dst_stride);
*/
TVM_DLL const Op &tvm_mfma_store();
/*!
* \brief tvm intrinsic for amd rdna matrix core instructions.
*
* void tvm_rdna_wmma(StringImm shape, StringImm A_layout, StringImm B_layout,
* StringImm A_dtype, StringImm B_dtype, StringImm C_dtype,
* Var multiplicand_a, Expr a_index,
* Var multiplicand_b, Expr b_index,
* Var accumulator, Expr c_index);
*/
TVM_DLL const Op &tvm_rdna_wmma();
/*!
* \brief tvm intrinsic for storing the result of AMD RDNA WMMA into a
* destination pointer.
*
* There is no real instruction that does that, but we want to hide
* details of complex index manipulation behind this intrinsic to simplify TIR
* lowering passes (e.g. LowerWarpMemory) like cuda ptx backend does.
*
* void tvm_rdna_wmma_store(IntImm m, IntImm n, Var dst_ptr, Var src_ptr, Expr
* src_offset, Var dst_stride);
*/
TVM_DLL const Op &tvm_rdna_wmma_store();
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_BUILTIN_H_
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file tl/op/bulk_copy.cc
* \brief Bulk copy operator.
*
*/
#include "bulk_copy.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include "../target/utils.h"
#include "../target/cuda.h"
#include "builtin.h"
namespace tvm {
namespace tl {
using namespace tir;
static int to_CUtensorMapDataType(DataType dtype) {
CUtensorMapDataType tp;
if (dtype.is_float()) {
switch (dtype.bits()) {
case 64:
tp = CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
break;
case 32:
tp = CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
break;
case 16:
tp = CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
break;
case 8:
tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
break;
default:
ICHECK(0) << dtype;
}
} else if (dtype.is_bfloat16()) {
tp = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
} else if (dtype.is_int()) {
switch (dtype.bits()) {
case 64:
tp = CU_TENSOR_MAP_DATA_TYPE_INT64;
break;
case 32:
tp = CU_TENSOR_MAP_DATA_TYPE_INT32;
break;
case 16:
tp = CU_TENSOR_MAP_DATA_TYPE_UINT16;
break;
case 8:
tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
break;
default:
ICHECK(0) << dtype;
}
} else if (dtype.is_uint()) {
switch (dtype.bits()) {
case 64:
tp = CU_TENSOR_MAP_DATA_TYPE_UINT64;
break;
case 32:
tp = CU_TENSOR_MAP_DATA_TYPE_UINT32;
break;
case 16:
tp = CU_TENSOR_MAP_DATA_TYPE_UINT16;
break;
case 8:
tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
break;
default:
ICHECK(0) << dtype;
}
} else {
ICHECK(0) << dtype;
}
return static_cast<int>(tp);
}
template <typename T>
static Array<T> ReverseArray(Array<T> array) {
return Array<T>{array.rbegin(), array.rend()};
}
Stmt Copy::LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
if (!TargetIsHopper(T.target)) return Stmt();
bool is_load;
if (src.scope() == "global" && (dst.scope() == "shared.dyn" || dst.scope() == "shared")) {
// Use the Hopper TMA bulk copy instructions
is_load = true;
} else if (dst.scope() == "global" && (src.scope() == "shared.dyn" || src.scope() == "shared")) {
is_load = false;
} else {
return Stmt();
}
Buffer global_tensor = is_load ? src : dst;
Buffer shared_tensor = is_load ? dst : src;
Layout shared_layout;
if (T.layout_map.count(shared_tensor)) {
shared_layout = T.layout_map[shared_tensor];
shared_tensor = T.buffer_remap[shared_tensor];
}
if (T.layout_map.count(global_tensor)) {
ICHECK(T.layout_map.count(global_tensor) == 0) << "Cannot support global layout.";
}
TMADesc desc;
// Verify copy rank
desc.rank = global_tensor->shape.size();
ICHECK(desc.rank >= 1 && desc.rank <= 5) << desc.rank;
// Verify datatype
ICHECK(global_tensor->dtype == shared_tensor->dtype);
desc.data_type = to_CUtensorMapDataType(global_tensor->dtype);
// Global Tensor Shape and Stride
auto global_range = is_load ? src_range : dst_range;
desc.global_addr = global_tensor->data;
desc.global_shape = ReverseArray(global_tensor->shape);
Array<PrimExpr> global_coords = ReverseArray(global_range.Map([](Range r) { return r->min; }));
if (!global_tensor->strides.empty()) {
desc.global_stride = ReverseArray(global_tensor->strides);
} else {
// Create stride from shape
PrimExpr stride = 1;
desc.global_stride.reserve(desc.rank);
for (size_t i = 0; i < desc.rank; i++) {
desc.global_stride.push_back(stride);
stride *= desc.global_shape[i];
}
}
// The first stride element should be 1
ICHECK(is_one(desc.global_stride[0])) << desc.global_stride;
// Make global stride in bytes
desc.global_stride =
desc.global_stride.Map([&](PrimExpr e) { return e * global_tensor->dtype.bytes(); });
// Smem Box
desc.smem_box = ReverseArray(global_range.Map([](Range r) { return r->extent; }));
desc.smem_stride = Array<PrimExpr>(desc.rank, PrimExpr(1));
// L2 & OOB
desc.l2_promotion = static_cast<int>(CU_TENSOR_MAP_L2_PROMOTION_L2_128B);
desc.oob_fill = static_cast<int>(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
// Detect smem layout
desc.interleave = static_cast<int>(CU_TENSOR_MAP_INTERLEAVE_NONE);
if (!shared_layout.defined()) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
} else {
ICHECK(shared_layout->InputDim() == 2) << "Cannot detect TMA layout.";
auto stride = as_const_int(shared_layout->InputShape()[0]);
auto continuous = as_const_int(shared_layout->InputShape()[1]);
ICHECK(stride != nullptr && continuous != nullptr);
if (StructuralEqual()(shared_layout, makeHalfBankSwizzleLayout(*stride, *continuous,
shared_tensor->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B);
} else if (StructuralEqual()(
shared_layout,
makeFullBankSwizzleLayout(*stride, *continuous, shared_tensor->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B);
} else {
ICHECK(0) << "Cannot detect TMA layout.";
}
}
auto inner_box_dim = as_const_int(desc.smem_box[0]);
ICHECK(inner_box_dim != nullptr);
int instruction_dim = *inner_box_dim;
if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B)) {
instruction_dim = 64 / src->dtype.bytes();
} else if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B)) {
instruction_dim = 128 / src->dtype.bytes();
}
ICHECK((*inner_box_dim) % instruction_dim == 0);
desc.smem_box.Set(0, PrimExpr(instruction_dim));
Call create_descriptor = Call(DataType::Handle(), CreateTMADescriptorOp(), desc.EncodeCallArgs());
Array<PrimExpr> args;
args.reserve(desc.rank + 3);
args.push_back(create_descriptor);
if (is_load) args.push_back(0); // mbarrier id placeholder
auto op = is_load ? TMALoadOp() : TMAStoreOp();
Stmt tma_copy;
if ((*inner_box_dim) != instruction_dim) {
Var loop_var("i");
int loop_extent = (*inner_box_dim) / instruction_dim;
PrimExpr total_elements = 1;
for (auto e : desc.smem_box) total_elements *= e;
PrimExpr shared_addr = shared_tensor.access_ptr(is_load ? 2 : 1, DataType::Handle(), 1,
total_elements * loop_var, total_elements);
args.push_back(shared_addr);
global_coords.Set(0, global_coords[0] + instruction_dim * loop_var);
for (auto coord : global_coords) args.push_back(coord);
tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled,
Evaluate(Call(DataType::Handle(), op, args)));
} else {
PrimExpr shared_addr = shared_tensor.access_ptr(is_load ? 2 : 1);
args.push_back(shared_addr);
for (auto coord : global_coords) args.push_back(coord);
tma_copy = Evaluate(Call(DataType::Handle(), op, args));
}
tma_copy = IfThenElse(EQ(T.thread_var, 0), tma_copy);
return tma_copy;
}
Array<PrimExpr> TMADesc::EncodeCallArgs() const {
Array<PrimExpr> args;
args.reserve(rank * 4 + 7);
args.push_back(data_type);
args.push_back(static_cast<int>(rank));
args.push_back(global_addr);
for (auto e : global_shape) args.push_back(e);
for (auto e : global_stride) args.push_back(e);
for (auto e : smem_box) args.push_back(e);
for (auto e : smem_stride) args.push_back(e);
args.push_back(interleave);
args.push_back(swizzle);
args.push_back(l2_promotion);
args.push_back(oob_fill);
return args;
}
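// Note: the argument order above follows the CreateTMADescriptorOp signature documented
// in op/builtin.h: data_type, rank, global_addr, global_shape..., global_stride...,
// smem_box..., smem_stride..., interleave, swizzle, l2_promotion, oob_fill.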
DataType cuTensorMapType() { return DataType::UInt(8, 128); }
Conv2DIm2ColOp::Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap) {
src = vmap[GetVarFromAccessPtr(args[0])];
dst = vmap[GetVarFromAccessPtr(args[1])];
nhw_step = args[2];
c_step = args[3];
kernel = args[4].as<IntImm>().value()->value;
stride = args[5].as<IntImm>().value()->value;
dilation = args[6].as<IntImm>().value()->value;
padding = args[7].as<IntImm>().value()->value;
}
Stmt Conv2DIm2ColOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
ICHECK(TargetIsHopper(T.target));
ICHECK(src.scope() == "global" && (dst.scope() == "shared.dyn" || dst.scope() == "shared"));
ICHECK(src->shape.size() == 4);
ICHECK(dst->shape.size() == 2);
ICHECK(src->dtype == dst->dtype);
Layout shared_layout;
if (T.layout_map.count(dst)) {
shared_layout = T.layout_map[dst];
}
TMAIm2ColDesc desc;
desc.rank = src->shape.size();
desc.data_type = to_CUtensorMapDataType(src->dtype);
desc.global_addr = src->data;
desc.global_shape = ReverseArray(src->shape);
if (!src->strides.empty()) {
desc.global_stride = ReverseArray(src->strides);
} else {
// Create stride from shape
PrimExpr stride = 1;
desc.global_stride.reserve(desc.rank);
for (size_t i = 0; i < desc.rank; i++) {
desc.global_stride.push_back(stride);
stride *= desc.global_shape[i];
}
}
// The first stride element should be 1
ICHECK(is_one(desc.global_stride[0])) << desc.global_stride;
// Make global stride in bytes
desc.global_stride = desc.global_stride.Map([&](PrimExpr e) { return e * src->dtype.bytes(); });
desc.elem_stride = {1, stride, stride, 1};
desc.lower_corner = {-padding, -padding};
desc.upper_corner = {-padding, -padding};
desc.smem_box_pixel = Downcast<IntImm>(dst->shape[0])->value;
desc.smem_box_channel = Downcast<IntImm>(dst->shape[1])->value;
desc.l2_promotion = static_cast<int>(CU_TENSOR_MAP_L2_PROMOTION_L2_128B);
desc.oob_fill = static_cast<int>(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
desc.interleave = static_cast<int>(CU_TENSOR_MAP_INTERLEAVE_NONE);
if (!shared_layout.defined()) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
} else {
ICHECK(shared_layout->InputDim() == 2) << "Cannot detect TMA layout.";
auto stride = as_const_int(shared_layout->InputShape()[0]);
auto continuous = as_const_int(shared_layout->InputShape()[1]);
ICHECK(stride != nullptr && continuous != nullptr);
if (StructuralEqual()(shared_layout,
makeHalfBankSwizzleLayout(*stride, *continuous, dst->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B);
} else if (StructuralEqual()(shared_layout, makeFullBankSwizzleLayout(*stride, *continuous,
dst->dtype.bits()))) {
desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B);
} else {
ICHECK(0) << "Cannot detect TMA layout.";
}
}
Call create_desc = Call(DataType::Handle(), CreateTMAIm2ColDescriptorOp(), desc.EncodeCallArgs());
Array<PrimExpr> global_coords; // c, w, h, n
Array<PrimExpr> image_offset; // w, h
global_coords.reserve(desc.rank);
ICHECK(analyzer->CanProveEqual(FloorMod(desc.global_shape[0], desc.smem_box_channel), 0))
<< "Currently only the divisible-channel case is supported";
global_coords.push_back(FloorMod(c_step * desc.smem_box_channel, desc.global_shape[0]));
image_offset.push_back(
dilation * FloorMod(FloorDiv(c_step * desc.smem_box_channel, desc.global_shape[0]), kernel));
image_offset.push_back(dilation *
FloorDiv(c_step * desc.smem_box_channel, desc.global_shape[0] * kernel));
PrimExpr h_dim = FloorDiv(src->shape[1] + 2 * padding - (kernel - 1) * dilation - 1, stride) + 1;
PrimExpr w_dim = FloorDiv(src->shape[2] + 2 * padding - (kernel - 1) * dilation - 1, stride) + 1;
global_coords.push_back(stride * FloorMod(nhw_step * desc.smem_box_pixel, w_dim) - padding);
global_coords.push_back(
stride * FloorMod(FloorDiv(nhw_step * desc.smem_box_pixel, w_dim), h_dim) - padding);
global_coords.push_back(FloorDiv(nhw_step * desc.smem_box_pixel, w_dim * h_dim));
Array<PrimExpr> args;
args.reserve(desc.rank * 2 + 1);
args.push_back(create_desc);
args.push_back(0); // mbar placeholder
auto dst_buffer = T.buffer_remap.count(dst) ? T.buffer_remap[dst] : dst;
auto shared_addr = dst_buffer.access_ptr(2);
args.push_back(shared_addr);
for (auto coord : global_coords) args.push_back(coord);
for (auto offset : image_offset) args.push_back(offset);
Stmt tma_copy =
IfThenElse(EQ(T.thread_var, 0), Evaluate(Call(DataType::Handle(), TMALoadIm2ColOp(), args)));
return tma_copy;
}
Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const {
Array<PrimExpr> args;
args.reserve(rank * 5 + 5);
args.push_back(data_type);
args.push_back(static_cast<int>(rank));
args.push_back(global_addr);
for (auto e : global_shape) args.push_back(e);
for (auto e : global_stride) args.push_back(e);
for (auto e : elem_stride) args.push_back(e);
for (auto e : lower_corner) args.push_back(e);
for (auto e : upper_corner) args.push_back(e);
args.push_back(smem_box_pixel);
args.push_back(smem_box_channel);
args.push_back(interleave);
args.push_back(swizzle);
args.push_back(l2_promotion);
args.push_back(oob_fill);
return args;
}
TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col)
.set_num_inputs(8)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file tl/op/bulk_copy.h
* \brief Bulk copy operator.
*
*/
#ifndef TVM_TL_OP_BULK_COPY_H_
#define TVM_TL_OP_BULK_COPY_H_
#include "elem.h"
namespace tvm {
namespace tl {
using namespace tir;
struct TMADesc {
size_t rank;
int data_type;
Array<PrimExpr> global_shape, global_stride;
Array<PrimExpr> smem_box, smem_stride;
PrimExpr global_addr;
int swizzle;
int interleave;
int oob_fill;
int l2_promotion;
Array<PrimExpr> EncodeCallArgs() const;
};
DataType cuTensorMapType();
struct TMAIm2ColDesc {
size_t rank;
int data_type;
Array<PrimExpr> global_shape, global_stride, elem_stride; // rank
Array<PrimExpr> lower_corner, upper_corner; // rank - 2
PrimExpr global_addr;
int smem_box_pixel, smem_box_channel;
int swizzle;
int interleave;
int oob_fill;
int l2_promotion;
Array<PrimExpr> EncodeCallArgs() const;
};
class Conv2DIm2ColOp : public Operator {
public:
Conv2DIm2ColOp(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const final;
static const Op& Get();
private:
Buffer src, dst;
int stride, padding, dilation, kernel;
PrimExpr nhw_step, c_step;
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_BULK_COPY_H_
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file tl/op/elem.cc
*
* Define element-wise operators.
*/
#include "elem.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include "../target/utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"
#include "../transform/common/loop_fusion_utils.h"
#include "builtin.h"
namespace tvm {
namespace tl {
using namespace tir;
Copy::Copy(Array<PrimExpr> args, BufferMap vmap) : args_(args) {
Array<Range> rgs[2];
Buffer bf[2];
for (int i = 0; i < 2; i++) {
auto expr = args[i];
auto call = expr.as<CallNode>();
ICHECK(call);
auto region = RegionOp(call->args, vmap);
rgs[i] = region.GetRanges();
bf[i] = region.GetBuffer();
}
std::tie(this->src, this->dst) = std::tie(bf[0], bf[1]);
std::tie(this->src_range, this->dst_range) = std::tie(rgs[0], rgs[1]);
if (args.size() >= 3) {
coalesced_width = Downcast<IntImm>(args[2]);
}
}
Array<IterVar> Copy::MakeIterVars() const {
Array<IterVar> loop_vars;
size_t idx = 0;
for (size_t i = 0; i < src_range.size(); i++) {
if (is_one(src_range[i]->extent)) continue;
Var var = Var(std::string{char('i' + idx)});
idx++;
loop_vars.push_back({Range(0, src_range[i]->extent), var, IterVarType::kDataPar});
}
return loop_vars;
}
// ivs: itervars returned by MakeIterVars()
// src_dst: 0 for src_indices, 1 for dst_indices
Array<PrimExpr> Copy::MakeIndices(const Array<IterVar>& ivs, int src_dst) const {
Array<PrimExpr> indices;
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
size_t idx = 0;
for (size_t i = 0; i < ranges.size(); i++) {
if (is_one(ranges[i]->extent))
indices.push_back(ranges[i]->min);
else {
indices.push_back(ranges[i]->min + ivs[idx]->var);
idx++;
}
}
ICHECK(idx == ivs.size());
return indices;
}
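// Illustrative example (values assumed for documentation): for a copy whose src_range
// is {Range(0, 1), Range(0, 64), Range(k, 32)}, MakeIterVars() creates two loop vars
// i and j (the unit-extent leading dim is skipped), and MakeIndices(ivs, 0) yields
// {0, 0 + i, k + j}: unit-extent dims contribute their range min, the others min + var.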
PrimExpr Copy::MakePredicate(arith::Analyzer* analyzer, const Array<IterVar>& ivs,
Array<PrimExpr> extents, int src_dst) const {
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
Array<PrimExpr> cond_list;
ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
size_t idx = 0;
for (size_t i = 0; i < ranges.size(); i++) {
if (is_one(ranges[i]->extent)) continue;
PrimExpr cond = ranges[i]->min + ivs[idx]->var < extents[i];
if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
cond_list.push_back(cond);
}
cond = ranges[i]->min + ivs[idx]->var >= 0;
if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
cond_list.push_back(cond);
}
idx++;
}
if (cond_list.empty())
return {};
else {
PrimExpr cond = cond_list[0];
for (size_t i = 1; i < cond_list.size(); i++) cond = And(cond, cond_list[i]);
return cond;
}
}
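// Illustrative example (values assumed for documentation): with ranges
// {Range(0, 64), Range(k, 32)} checked against extents {64, N}, the condition
// 0 + i < 64 is provable from the bound loop domain and dropped, while k + j < N and
// k + j >= 0 are kept and ANDed together unless the analyzer can prove them as well.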
For Copy::MakeSIMTLoop(arith::Analyzer* analyzer) const {
Array<IterVar> loop_vars = MakeIterVars();
for (const auto& iv : loop_vars) analyzer->Bind(iv->var, iv->dom);
Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0);
Array<PrimExpr> dst_indices = MakeIndices(loop_vars, 1);
PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);
PrimExpr value = BufferLoad(src, src_indices);
if (src->dtype != dst->dtype) value = Cast(dst->dtype, value);
if (src_predicate.defined()) value = if_then_else(src_predicate, value, make_zero(dst->dtype));
Stmt body = BufferStore(dst, value, dst_indices);
if (dst_predicate.defined()) body = IfThenElse(dst_predicate, body);
for (int i = loop_vars.size() - 1; i >= 0; i--) {
Map<String, ObjectRef> annotations = {};
if (coalesced_width.defined()) {
annotations.Set("coalesced_width", coalesced_width);
}
body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent, ForKind::kParallel, body, NullOpt, annotations);
}
return Downcast<For>(body);
}
Stmt Copy::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
Stmt ldsm_stmt = LowerLDSMCopy(T, analyzer);
if (ldsm_stmt.defined()) return ldsm_stmt;
Stmt bulk_copy_stmt = LowerBulkCopy(T, analyzer);
if (bulk_copy_stmt.defined()) return bulk_copy_stmt;
auto simt_loop = MakeSIMTLoop(analyzer);
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
auto par_op = std::make_unique<ParallelOp>(fused_loop);
par_op->InferLayout({T.target, T.block_size, T.layout_map, T.buffer_remap}, InferLevel::kFree);
auto thread_loop =
PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, par_op->GetLoopLayout());
auto vectorized_thread_loop = VectorizeLoop(thread_loop);
if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(), vectorized_thread_loop);
}
return vectorized_thread_loop;
}
Stmt Copy::LowerLDSMCopy(const LowerArgs& T, arith::Analyzer* analyzer) const {
// Check buffer scope
bool is_ldmatrix;
if (TargetHasLdmatrix(T.target) && src.scope() == "shared.dyn" &&
dst.scope() == "local.fragment") {
is_ldmatrix = true;
} else if (TargetHasStmatrix(T.target) && dst.scope() == "shared.dyn" &&
src.scope() == "local.fragment") {
is_ldmatrix = false;
} else {
return Stmt();
}
// Check no predicates
Array<IterVar> loop_vars = MakeIterVars();
if (loop_vars.size() < 2) return Stmt();
for (const auto& iv : loop_vars) analyzer->Bind(iv->var, iv->dom);
PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);
if (src_predicate.defined() || dst_predicate.defined()) return Stmt();
Buffer shared_tensor = is_ldmatrix ? src : dst;
Buffer local_tensor = is_ldmatrix ? dst : src;
Array<PrimExpr> local_indices = MakeIndices(loop_vars, is_ldmatrix ? 1 : 0);
Fragment local_layout = Downcast<Fragment>(T.layout_map[local_tensor]);
Array<PrimExpr> local_indices_transformed = local_layout->Forward(local_indices);
local_tensor = T.buffer_remap[local_tensor];
// currently only the 1-D case is supported
if (local_layout->OutputDim() != 1) return Stmt();
Array<PrimExpr> shared_indices = MakeIndices(loop_vars, is_ldmatrix ? 0 : 1);
Array<PrimExpr> shared_indices_transformed = shared_indices;
Layout shared_layout;
if (T.buffer_remap.count(shared_tensor)) {
shared_layout = T.layout_map[shared_tensor];
shared_tensor = T.buffer_remap[shared_tensor];
shared_indices_transformed = shared_layout->Forward(shared_indices);
}
// Check local_layout follows 8x8 layout
bool is_transposed;
IterVar col_var = loop_vars[loop_vars.size() - 1];
IterVar row_var = loop_vars[loop_vars.size() - 2];
PrimExpr local_layout_thread_map =
FloorMod(local_layout->ForwardThread(local_indices, NullOpt), 32);
PrimExpr matrix_8x8_thread_map =
makeGemmFragment8x8()->ForwardThread({FloorMod(row_var, 8), FloorMod(col_var, 8)}, NullOpt);
PrimExpr matrix_8x8_thread_map_trans = makeGemmFragment8x8Transposed()->ForwardThread(
{FloorMod(row_var, 8), FloorMod(col_var, 8)}, NullOpt);
PrimExpr local_indices_flattened = local_tensor.OffsetOf(local_indices_transformed).back();
if (analyzer->CanProveEqual(matrix_8x8_thread_map, local_layout_thread_map) &&
IndiceCanVectorize(local_indices_flattened, col_var->var, col_var->dom->extent, 2,
analyzer)) {
is_transposed = false;
} else if (analyzer->CanProveEqual(matrix_8x8_thread_map_trans, local_layout_thread_map) &&
IndiceCanVectorize(local_indices_flattened, row_var->var, row_var->dom->extent, 2,
analyzer)) {
is_transposed = true;
} else {
return Stmt();
}
// Check shared_layout is 16 bytes continuous
if (shared_tensor->dtype.bytes() != 2) return Stmt();
PrimExpr flattened_indice = shared_tensor.OffsetOf(shared_indices_transformed).back();
if (!IndiceCanVectorize(flattened_indice, loop_vars.back()->var, loop_vars.back()->dom->extent, 8,
analyzer))
return Stmt();
// The destination range must cover the full buffer
for (size_t i = 0; i < dst_range.size(); i++) {
if (!is_zero(dst_range[i]->min) ||
!analyzer->CanProveEqual(dst_range[i]->extent, dst->shape[i]))
return Stmt();
}
// Do the lowering here, try vectorized ldmatrix/stmatrix by 4/2/1
PrimExpr extent = local_tensor->shape[0];
int num = 1;
if (analyzer->CanProveEqual(FloorMod(extent, 8), 0))
num = 4;
else if (analyzer->CanProveEqual(FloorMod(extent, 4), 0))
num = 2;
Array<PrimExpr> args;
const Op& op = is_ldmatrix ? tl::LDMatrixOp() : tl::STMatrixOp();
args.push_back(static_cast<int>(is_transposed));
args.push_back(num);
// Create shared address with regard to local address
// if not transpose
// coords = Inverse(base + 2 * (thread / 8) % num, warp + (thread % 8) * 4))
// if transpose
// coords = Inverse(base + 2 * (thread / 8) % num + thread % 2, warp + thread % 8 / 2)
Var local_iter("i");
Layout inv = local_layout->Inverse();
Array<PrimExpr> shared_coords;
PrimExpr warp = FloorDiv(T.thread_var, 32) * 32;
if (!is_transposed)
shared_coords =
inv->Forward({local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num),
warp + FloorMod(T.thread_var, 8) * 4});
else
shared_coords =
inv->Forward({local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num) +
FloorMod(T.thread_var, 2),
warp + FloorDiv(FloorMod(T.thread_var, 8), 2)});
shared_coords.pop_back(); // remove rep
if (shared_layout.defined()) shared_coords = shared_layout->Forward(shared_coords);
PrimExpr shared_addr = shared_tensor.access_ptr(
is_ldmatrix ? 1 : 2, DataType::Handle(), 1, shared_tensor.OffsetOf(shared_coords).back(), PrimExpr(2 * num));
args.push_back(shared_addr);
if (is_ldmatrix) {
// Only matching dtypes are supported for ldmatrix
if (local_tensor->dtype != shared_tensor->dtype) return Stmt();
PrimExpr local_addr =
local_tensor.access_ptr(2, DataType::Handle(), 1, local_iter * 2 * num, PrimExpr(2 * num));
args.push_back(local_addr);
} else {
for (int i = 0; i < num; i++) {
PrimExpr value0 = BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i});
PrimExpr value1 = BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i + 1});
if (local_tensor->dtype != shared_tensor->dtype) {
value0 = Cast(shared_tensor->dtype, value0);
value1 = Cast(shared_tensor->dtype, value1);
}
PrimExpr value_packed = Call(DataType::Int(32), PackB16Op(), {value0, value1});
args.push_back(value_packed);
}
}
auto body = Evaluate(Call(DataType::Handle(), op, args));
For for_node = For(local_iter, 0, FloorDiv(extent, 2 * num), ForKind::kSerial, body);
for_node = LoopPragmaUnroll(for_node);
return for_node;
}
LayoutMap Copy::InferLayout(const LayoutInferArgs& T, InferLevel level) {
// Use parallel op to infer the layout
if (par_op_ == nullptr) {
arith::Analyzer analyzer;
par_op_ = std::make_unique<ParallelOp>(MakeSIMTLoop(&analyzer));
}
return par_op_->InferLayout(T, level);
}
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
dst = vmap[GetVarFromAccessPtr(args[0])];
if (args[1]->dtype != dst->dtype) {
value = Cast(dst->dtype, args[1]);
} else {
value = args[1];
}
}
For Fill::MakeSIMTLoop(arith::Analyzer* analyzer) const {
int ndim = dst->shape.size();
Array<IterVar> loop_vars;
Array<PrimExpr> dst_indices;
for (int i = 0; i < ndim; i++) {
Var var = Var(std::string{char('i' + i)});
loop_vars.push_back({Range(0, dst->shape[i]), var, IterVarType::kDataPar});
dst_indices.push_back(var);
}
Stmt body = BufferStore(dst, value, dst_indices);
for (int i = ndim - 1; i >= 0; i--) {
body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent, ForKind::kParallel, body);
}
return Downcast<For>(body);
}
Stmt Fill::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
if (dst.scope() == "local.fragment") {
auto par_op = std::make_unique<ParallelOp>(MakeSIMTLoop(analyzer));
par_op->InferLayout({T.target, T.block_size, T.layout_map}, InferLevel::kFree);
par_op->InferLayout({T.target, T.block_size, T.layout_map}, InferLevel::kFree);
auto thread_loop =
PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, par_op->GetLoopLayout());
auto vectorized_thread_loop = VectorizeLoop(thread_loop);
if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(), vectorized_thread_loop);
}
return vectorized_thread_loop;
} else if (dst.scope() == "local") {
auto init_loop = MakeSIMTLoop(analyzer);
auto vectorized_thread_loop = VectorizeLoop(init_loop);
return vectorized_thread_loop;
} else {
LOG(FATAL) << "Unsupported scope " << dst.scope();
}
}
TIR_REGISTER_TL_OP(Copy, copy)
.set_num_inputs(3)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_REGISTER_TL_OP(Fill, fill)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file tl/op/elem.h
* \brief Define element-wise operators.
*
*/
#ifndef TVM_TL_OP_ELEM_H_
#define TVM_TL_OP_ELEM_H_
#include "op.h"
#include "parallel.h"
namespace tvm {
namespace tl {
using namespace tir;
class Copy : public Operator {
public:
Copy(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level) final;
static const Op& Get();
protected:
Stmt LowerBulkCopy(const LowerArgs& T, arith::Analyzer* analyzer) const;
Stmt LowerLDSMCopy(const LowerArgs& T, arith::Analyzer* analyzer) const;
For MakeSIMTLoop(arith::Analyzer* analyzer) const;
Array<IterVar> MakeIterVars() const;
// ivs: itervars returned by MakeIterVars()
// src_dst: 0 for src_indices, 1 for dst_indices
Array<PrimExpr> MakeIndices(const Array<IterVar>& ivs, int src_dst) const;
PrimExpr MakePredicate(arith::Analyzer* analyzer, const Array<IterVar>& ivs,
Array<PrimExpr> extents, int src_dst) const;
Array<PrimExpr> args_;
Buffer src, dst;
Array<Range> src_range, dst_range;
IntImm coalesced_width;
std::unique_ptr<ParallelOp> par_op_;
};
class Fill : public Operator {
public:
Fill(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const final;
static const Op& Get();
private:
For MakeSIMTLoop(arith::Analyzer* analyzer) const;
tir::Buffer dst;
PrimExpr value;
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_ELEM_H_
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file tl/op/gemm.cc
*
* Define gemm operator.
*/
#include "gemm.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include "../target/utils.h"
namespace tvm {
namespace tl {
using namespace tir;
static std::vector<int> toPrimeFactors(int x) {
int i = 2;
std::vector<int> result;
while (x > 1) {
if (x % i == 0) {
x /= i;
result.push_back(i);
} else {
i++;
}
}
return result;
}
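// Example: toPrimeFactors(12) returns {2, 2, 3} (trial division, ascending order).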
Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
A = vmap[GetVarFromAccessPtr(args[0])];
B = vmap[GetVarFromAccessPtr(args[1])];
C = vmap[GetVarFromAccessPtr(args[2])];
trans_A = args[3].as<Bool>().value();
trans_B = args[4].as<Bool>().value();
M = args[5].as<IntImm>().value()->value;
N = args[6].as<IntImm>().value()->value;
K = args[7].as<IntImm>().value()->value;
policy = static_cast<GemmWarpPolicy>(args[8].as<IntImm>().value()->value);
if (args.size() > 9) {
kPack = args[9].as<IntImm>().value()->value;
if (kPack != 1 && kPack != 2) {
ICHECK(false) << "kPack must be 1 or 2";
}
}
}
std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target) const {
int m_warp = 1, n_warp = 1;
if (TargetIsHopper(target)) {
ICHECK(num_warps % 4 == 0) << "Warp Group MMA requires the number of warps to be a multiple of 4 (128*N threads).";
if (this->policy == GemmWarpPolicy::kFullRow || this->policy == GemmWarpPolicy::kSquare) {
m_warp = num_warps;
ICHECK(this->M % num_warps == 0);
} else if (this->policy == GemmWarpPolicy::kFullCol) {
m_warp = 4;
n_warp = num_warps / 4;
ICHECK(this->N % n_warp == 0);
} else {
ICHECK(0) << "Unknown GemmWarpPolicy";
}
return {m_warp, n_warp};
}
if (this->policy == GemmWarpPolicy::kFullRow) {
m_warp = num_warps;
ICHECK(this->M % num_warps == 0);
} else if (this->policy == GemmWarpPolicy::kFullCol) {
n_warp = num_warps;
ICHECK(this->N % num_warps == 0);
} else if (this->policy == GemmWarpPolicy::kSquare) {
auto factors = toPrimeFactors(num_warps);
for (int factor : factors) {
bool M_divisible = (this->M % (factor * m_warp)) == 0;
bool N_divisible = (this->N % (factor * n_warp)) == 0;
if (M_divisible && N_divisible) {
if (this->M / m_warp >= this->N / n_warp)
m_warp *= factor;
else
n_warp *= factor;
} else if (M_divisible) {
m_warp *= factor;
} else if (N_divisible) {
n_warp *= factor;
} else {
ICHECK(0) << "Cannot compute warp partition for shape" << M << " " << N
<< " with num_warps " << num_warps;
}
}
} else {
ICHECK(0) << "Unknown GemmWarpPolicy";
}
// TODO: perform more checks here
return {m_warp, n_warp};
}
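// Worked example (kSquare, non-Hopper): with num_warps = 4 and M = N = 128 the
// prime factors {2, 2} are split between m and n, giving m_warp = 2 and
// n_warp = 2, i.e. each warp covers a 64x64 tile of C.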
Stmt Gemm::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
int warp_size = 32;
if (TargetIsCDNA(T.target)) {
warp_size = 64;
}
ICHECK(T.block_size % warp_size == 0);
auto [warp_m, warp_n] = ComputeWarpPartition(T.block_size / warp_size, T.target);
std::stringstream ss;
std::string op_name = "tl::gemm_ss";
if (A.scope() == "local.fragment") {
ICHECK(B.scope() != "local.fragment");
op_name = "tl::gemm_rs";
} else if (B.scope() == "local.fragment") {
op_name = "tl::gemm_sr";
}
ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
ss << warp_m << ", " << warp_n << ", ";
ss << trans_A << ", " << trans_B;
if (TargetIsCDNA(T.target)) {
// for cdna gemm, we need to specify kPack
ss << ", " << kPack;
}
ss << ">";
auto A_buffer = T.buffer_remap.count(A) ? T.buffer_remap[A] : A;
auto B_buffer = T.buffer_remap.count(B) ? T.buffer_remap[B] : B;
auto C_buffer = T.buffer_remap[C];
Array<PrimExpr> new_args;
new_args.push_back(StringImm(ss.str()));
new_args.push_back(A_buffer.access_ptr(1));
new_args.push_back(B_buffer.access_ptr(1));
new_args.push_back(C_buffer.access_ptr(3));
auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);
return Evaluate(new_call);
}
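// Illustrative extern call emitted by Lower (the numbers are an assumed
// example, not fixed values): with M = N = 128, K = 32, a 2x2 warp partition,
// A not transposed and B transposed on a non-CDNA target, the call string is
// "tl::gemm_ss<128, 128, 32, 2, 2, 0, 1>", applied to the (possibly remapped)
// A/B/C access pointers.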
LayoutMap Gemm::InferLayout(const LayoutInferArgs& T, InferLevel level) {
if (completed_) return {};
LayoutMap results;
ICHECK(C.scope() == "local.fragment");
if (TargetIsVolta(T.target)) {
const int warp_size = 32;
auto [warp_m, warp_n] = ComputeWarpPartition(T.block_size / warp_size, T.target);
auto fragment = makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment);
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
results.Set(A, makeGemmVoltaABLayout(*as_const_int(A->shape[0]), *as_const_int(A->shape[1]),
true, trans_A ? 1 : 2));
} else if (A.scope() == "local.fragment") {
ICHECK(trans_A == false);
results.Set(A, makeGemmVoltaFragmentA(M, N, K, M / warp_m, N / warp_n));
} else {
ICHECK(0);
}
ICHECK(B.scope() == "shared" || B.scope() == "shared.dyn");
results.Set(B, makeGemmVoltaABLayout(*as_const_int(B->shape[0]), *as_const_int(B->shape[1]),
false, trans_B ? 2 : 1));
} else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) {
const int warp_size = 32;
auto [warp_m, warp_n] = ComputeWarpPartition(T.block_size / warp_size, T.target);
auto fragment = makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment);
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
results.Set(A, makeGemmABLayout(*as_const_int(A->shape[0]), *as_const_int(A->shape[1]),
A->dtype.bits(), trans_A ? 1 : 2));
} else if (A.scope() == "local.fragment") {
ICHECK(trans_A == false);
results.Set(A, makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, A->dtype.bits()));
} else {
ICHECK(0);
}
if (B.scope() == "shared" || B.scope() == "shared.dyn") {
results.Set(B, makeGemmABLayout(*as_const_int(B->shape[0]), *as_const_int(B->shape[1]),
B->dtype.bits(), trans_B ? 2 : 1));
} else if (B.scope() == "local.fragment") {
ICHECK(trans_B == false);
results.Set(B, makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n));
} else {
ICHECK(0);
}
} else if (TargetIsHopper(T.target)) {
const int warp_size = 32;
auto [warp_m, warp_n] = ComputeWarpPartition(T.block_size / warp_size, T.target);
auto fragment = makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment);
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
results.Set(A, makeGemmABLayout(*as_const_int(A->shape[0]), *as_const_int(A->shape[1]),
A->dtype.bits(), trans_A ? 1 : 2));
} else {
ICHECK(trans_A == false);
results.Set(A, makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, A->dtype.bits()));
}
if (B.scope() == "shared" || B.scope() == "shared.dyn") {
results.Set(B, makeGemmABLayout(*as_const_int(B->shape[0]), *as_const_int(B->shape[1]),
B->dtype.bits(), trans_B ? 2 : 1));
} else {
ICHECK(0) << "WGMMA only support B in shared.";
}
} else if (TargetIsCDNA(T.target)) {
ICHECK(trans_B == true) << "Currently only supports transposed B for CDNA";
const int warp_size = 64;
auto [warp_m, warp_n] = ComputeWarpPartition(T.block_size / warp_size, T.target);
auto fragment = makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment);
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
// Make Linear Memory Access Layout
// auto shared_layout =
// makeGemmLayoutLinear(*as_const_int(A->shape[0]), *as_const_int(A->shape[1]));
// Make Swizzle or Pad Layout
auto shared_layout = makeGemmABLayoutCDNA(*as_const_int(A->shape[0]), *as_const_int(A->shape[1]),
A->dtype.bits(), kPack);
results.Set(A, shared_layout);
} else if (A.scope() == "local.fragment") {
results.Set(A, makeGemmFragmentACDNA(M, N, K, M / warp_m, N / warp_n, trans_A));
} else {
ICHECK(0);
}
if (B.scope() == "shared" || B.scope() == "shared.dyn") {
// Make Linear Memory Access Layout
// auto shared_layout =
// makeGemmLayoutLinear(*as_const_int(B->shape[0]), *as_const_int(B->shape[1]));
// Make Swizzle or Pad Layout
auto shared_layout = makeGemmABLayoutCDNA(*as_const_int(B->shape[0]), *as_const_int(B->shape[1]),
B->dtype.bits(), kPack);
results.Set(B, shared_layout);
} else if (B.scope() == "local.fragment") {
results.Set(B, makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n));
} else {
ICHECK(0);
}
} else {
ICHECK(0) << "Not supported " << T.target->str();
}
completed_ = true;
return results;
}
TIR_REGISTER_TL_OP(Gemm, gemm)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file tl/op/gemm.h
* \brief Define gemm operator.
*
*/
#ifndef TVM_TL_OP_GEMM_H_
#define TVM_TL_OP_GEMM_H_
#include "op.h"
namespace tvm {
namespace tl {
using namespace tir;
class Gemm : public Operator {
public:
Gemm(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level) final;
static const Op& Get();
enum class GemmWarpPolicy {
kSquare = 0,
kFullRow = 1,
kFullCol = 2,
} policy;
private:
std::pair<int, int> ComputeWarpPartition(int num_warps, Target target) const;
Array<PrimExpr> call_args;
tir::Buffer A, B, C;
bool trans_A, trans_B;
int M, N, K;
// kPack: see bitblas/tl/mfma_macro_generator.py::k_pack;
// it only takes effect for CDNA MFMA instructions
int kPack = 1;
bool completed_ = false;
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_GEMM_H_
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file tl/op/op.cc
*
 * Define operators used in the tile library.
*/
#include "op.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
namespace tvm {
namespace tl {
using namespace tir;
TIR_REGISTER_TL_OP(RegionOp, region)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
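// ParseOperator dispatches through the "TLOpBuilder" attribute registered by
// TIR_REGISTER_TL_OP: e.g. an Evaluate(Call(tl.fill, ...)) statement yields a
// Fill instance, while calls without a registered builder return nullptr.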
std::unique_ptr<Operator> ParseOperator(Call call, BufferMap vmap) {
auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder");
Op op = call->op.as<Op>().value();
if (op_map.count(op)) {
Operator* ptr = static_cast<Operator*>(op_map[op](call->args, vmap));
ICHECK(ptr != nullptr);
return std::unique_ptr<Operator>(ptr);
}
return nullptr;
}
std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap) {
if (stmt.as<Evaluate>() && stmt.as<EvaluateNode>()->value.as<CallNode>()) {
auto call = stmt.as<EvaluateNode>()->value.as<CallNode>();
return ParseOperator(GetRef<Call>(call), vmap);
}
return nullptr;
}
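// tvm_access_ptr packs its arguments as (type_annotation, data_var, offset,
// extent, rw_mask), so args[1] below is the underlying buffer variable.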
Var GetVarFromAccessPtr(const PrimExpr& expr) {
auto call = expr.as<CallNode>();
ICHECK(call);
ICHECK(call->op.same_as(builtin::tvm_access_ptr()));
auto var = call->args[1].as<VarNode>();
ICHECK(var);
return GetRef<Var>(var);
}
RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
size_t n = args.size();
size_t ndim = n - 2;
auto load = args[0].as<BufferLoadNode>();
ICHECK(load);
ICHECK(load->indices.size() == ndim);
buffer_ = load->buffer;
access_mask_ = static_cast<int>(*as_const_int(args[1]));
for (size_t i = 0; i < ndim; i++) {
PrimExpr min = load->indices[i];
PrimExpr extent = args[2 + i];
ranges_.push_back(Range::FromMinExtent(min, extent));
}
}
bool RegionOp::IsFullRegion() const {
for (size_t i = 0; i < ranges_.size(); i++) {
if (!is_zero(ranges_[i]->min)) return false;
if (!StructuralEqual()(ranges_[i]->extent, buffer_->shape[i])) return false;
}
return true;
}
Stmt Operator::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
ICHECK(0) << "Not Implemented Lower method.";
return Evaluate(0);
}
Stmt Operator::Canonialize(const CanonializeArgs& T, arith::Analyzer* analyzer) const { return {}; }
LayoutMap Operator::InferLayout(const LayoutInferArgs& T, InferLevel level) { return {}; }
} // namespace tl
} // namespace tvm
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file tl/op/op.h
* \brief Tile library operations.
*
*/
#ifndef TVM_TL_OP_OP_H_
#define TVM_TL_OP_OP_H_
#include <tvm/arith/analyzer.h>
#include <tvm/ir/op.h>
#include <tvm/target/target.h>
#include <tvm/tir/buffer.h>
#include "../layout/layout.h"
namespace tvm {
namespace tl {
using namespace tir;
using AddWorkspaceCallback = std::function<PrimExpr(int, DataType)>;
using LayoutMap = Map<Buffer, Layout>;
using BufferMap = Map<Var, Buffer>;
using OpBuilderFunc = TypedPackedFunc<void*(Array<PrimExpr>, BufferMap)>;
#define TIR_REGISTER_TL_OP(Entry, OpName) \
const Op& Entry::Get() { \
static const Op& op = Op::Get("tl." #OpName); \
return op; \
} \
TVM_REGISTER_OP("tl." #OpName) \
.set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName) \
.set_attr<OpBuilderFunc>( \
"TLOpBuilder", [](Array<PrimExpr> a, BufferMap b) { return (void*)(new Entry(a, b)); })
enum class InferLevel {
kFree = 0,
kCommon = 1,
kStrict = 2,
};
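// The levels order how aggressively an op may fix layouts: ParallelOp, for
// instance, only plans a fresh loop partition at kFree and reports nothing at
// kStrict (see parallel.cc). This reading is inferred from the uses in this
// patch rather than stated elsewhere.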
struct LowerArgs {
Target target;
size_t block_size;
Var thread_var;
AddWorkspaceCallback AddWorkspace;
LayoutMap layout_map;
Map<Buffer, Buffer> buffer_remap;
};
struct LayoutInferArgs {
Target target;
size_t block_size;
LayoutMap layout_map;
Map<Buffer, Buffer> buffer_remap;
};
struct CanonializeArgs {
Target target;
};
class Operator {
public:
virtual Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const;
virtual Stmt Canonialize(const CanonializeArgs& T, arith::Analyzer* analyzer) const;
virtual LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level);
virtual ~Operator() = default;
};
class RegionOp : public Operator {
public:
RegionOp(Array<PrimExpr> args, BufferMap vmap);
static const Op& Get();
const Buffer& GetBuffer() const { return buffer_; }
const Array<Range>& GetRanges() const { return ranges_; }
int GetAccessMask() const { return access_mask_; }
bool IsFullRegion() const;
private:
Buffer buffer_;
Array<Range> ranges_;
int access_mask_;
};
Var GetVarFromAccessPtr(const PrimExpr& expr);
std::unique_ptr<Operator> ParseOperator(Call call, BufferMap vmap);
std::unique_ptr<Operator> ParseOperator(Stmt stmt, BufferMap vmap);
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_OP_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file op/parallel.cc
 * \brief Define the parallel-for operator
*/
#include "parallel.h"
#include <tvm/tir/op.h>
#include "../layout/utils.h"
#include "../target/utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"
namespace tvm {
namespace tl {
using namespace tir;
namespace attr {
/*! \brief Annotation key marking the coalesced width used when vectorizing the loop. */
constexpr const char *coalesced_width = "coalesced_width";
}  // namespace attr
class IfBufferRemapLoopGenerator : public StmtExprMutator {
public:
static For run(Stmt stmt, Map<Buffer, Buffer> buffer_remap,
Map<Buffer, Layout> layout_map) {
IfBufferRemapLoopGenerator generator(buffer_remap, layout_map);
return Downcast<For>(generator(std::move(stmt)));
}
private:
IfBufferRemapLoopGenerator(Map<Buffer, Buffer> buffer_remap, Map<Buffer, Layout> layout_map)
: buffer_remap_(buffer_remap), layout_map_(layout_map) {}
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
if (buffer_remap_.count(load->buffer)) {
auto new_indices = layout_map_[load->buffer]->Forward(load->indices);
auto new_buffer = buffer_remap_[load->buffer];
return BufferLoad(new_buffer, new_indices);
}
return load;
}
Stmt VisitStmt_(const BufferStoreNode* op) final {
auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
if (buffer_remap_.count(store->buffer)) {
auto new_indices = layout_map_[store->buffer]->Forward(store->indices);
auto new_buffer = buffer_remap_[store->buffer];
return BufferStore(new_buffer, store->value, new_indices);
}
return store;
}
Map<Buffer, Buffer> buffer_remap_;
Map<Buffer, Layout> layout_map_;
};
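// ParallelLoopNestVisitor walks a parallel loop nest, recording each loop var
// in loop_vars_ and, for every "local.fragment" buffer, the index expressions
// used to access it; mismatching accesses to the same fragment trigger an
// ICHECK failure.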
void ParallelLoopNestVisitor::VisitStmt_(const ForNode* op) {
ICHECK(op->kind == ForKind::kParallel);
p->loop_vars_.push_back(IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar));
p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
StmtExprVisitor::VisitStmt_(op);
}
void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode* op) {
if (op->buffer.scope() == "local.fragment") {
if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
<< op->buffer << ": " << op->indices << " and " << p->indice_map_.at(op->buffer);
} else {
p->indice_map_.Set(op->buffer, op->indices);
}
p->buffer_is_write_.insert(op->buffer);
}
StmtExprVisitor::VisitStmt_(op);
}
void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode* op) {
if (op->buffer.scope() == "local.fragment") {
if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
<< op->buffer << ": " << op->indices << " and " << p->indice_map_.at(op->buffer);
} else {
p->indice_map_.Set(op->buffer, op->indices);
}
}
StmtExprVisitor::VisitExpr_(op);
}
ParallelOp::ParallelOp(For root) : root_(root), V(this) { V.VisitStmt(root); }
bool ParallelOp::IsCommonAccessIndice(const Buffer& buffer) const {
auto common_indice = loop_vars_.Map([](const auto& iv) { return iv->var; });
return StructuralEqual()(indice_map_[buffer], common_indice);
}
LayoutMap ParallelOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
if (loop_layout_.defined()) return {};
if (level == InferLevel::kStrict) return {};
// Step 1: try to infer loop's partition from a source fragment
Buffer source_buffer, read_source_buffer;
for (const auto& [buffer, _] : indice_map_) {
if (T.layout_map.count(buffer)) {
auto frag = T.layout_map[buffer].as<Fragment>().value();
if (buffer_is_write_.count(buffer))
source_buffer = buffer;
else
read_source_buffer = buffer;
}
}
auto compute_loop_layout_from_buffer = [&](const Buffer& buffer) {
Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
if (IsCommonAccessIndice(buffer)) {
return src_layout;
} else {
Var rep;
auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep, IterVarType::kDataPar);
PrimExpr loop_var_to_thread = src_layout->ForwardThread(indice_map_[buffer], rep);
return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter);
}
};
if (source_buffer.defined()) {
loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
} else if (level == InferLevel::kFree) {
if (read_source_buffer.defined()) {
loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
// The loop doesn't need to be replicated.
if (!is_one(loop_layout_->ReplicateExtent())) loop_layout_ = loop_layout_->DeReplicate();
// If replication remains, add a predicate so only one replica executes.
if (!is_one(loop_layout_->ReplicateExtent())) {
auto inv = loop_layout_->Inverse();
Array<PrimExpr> fwd;
for (size_t i = 0; i < loop_layout_->OutputDim(); i++) fwd.push_back(0);
fwd.push_back(InputPlaceholder(0));
auto rep = inv->Forward(fwd).back();
AddPredicate(EQ(rep, 0));
}
} else {
// The vectorize size must take buffer_remap into account,
// since later passes post-process the layout.
auto maybe_remapped_root_ = IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
int vector_size = GetVectorizeSize(maybe_remapped_root_);
// Check if coalesced_width is defined
if (auto coalesced_width = root_->annotations.Get(tl::attr::coalesced_width)) {
if (const auto* imm = coalesced_width.as<IntImmNode>()) {
int expected = imm->value;
// Verify that vector_size is divisible by expected
if (vector_size % expected != 0) {
LOG(FATAL) << "Vector size " << vector_size << " is not divisible by coalesced width "
<< expected;
}
vector_size = expected;
} else {
LOG(FATAL) << "coalesced_width should be an IntImmNode.";
}
}
loop_layout_ = PlanLoopPartition(root_, T.block_size, vector_size);
}
PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();
if (!analyzer_.CanProveEqual(loop_thread_extent, static_cast<int>(T.block_size)))
AddPredicate(LT(InputPlaceholder(0), loop_thread_extent));
} else {
return {};
}
// Step 2: Check that the loop's partition can correctly align with all source fragment
for (const auto& [buffer, _] : indice_map_) {
if (T.layout_map.count(buffer)) {
auto fragment = T.layout_map[buffer].as<Fragment>().value();
// TODO: Add thread checks for replicated cases
// need to wildcard match the rhs with lhs
if (!is_one(loop_layout_->ReplicateExtent()) || !is_one(fragment->ReplicateExtent()))
continue;
auto vars = loop_vars_.Map([](const IterVar& iv) { return PrimExpr(iv->var); });
auto lhs = loop_layout_->ForwardThread(vars, NullOpt);
auto rhs = fragment->ForwardThread(indice_map_[buffer], NullOpt);
auto diff = analyzer_.Simplify(lhs - rhs);
ICHECK(is_zero(diff)) << "Layout infer conflict for " << buffer << " " << source_buffer
<< "\nLHS = " << lhs << "\nRHS = " << rhs;
}
}
// Step 3: Infer other fragment's layout from the loop's partition
LayoutMap results;
for (const auto& [buffer, _] : indice_map_) {
if (!T.layout_map.count(buffer)) results.Set(buffer, CompleteBufferFragment(buffer));
}
return results;
}
Optional<PrimExpr> ParallelOp::GetPredicate(Var thread_var) const {
if (predicate_.defined()) {
return Substitute(predicate_.value(), {{InputPlaceholder(0), thread_var}});
} else {
return NullOpt;
}
}
Fragment ParallelOp::CompleteBufferFragment(const Buffer& buffer) {
ICHECK(loop_layout_.defined());
if (IsCommonAccessIndice(buffer)) return loop_layout_;
PrimExpr rep_b =
MakeFlattenedExpression(DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
auto bijective_indice = indice_map_[buffer];
bijective_indice.push_back(rep_b);
Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();
PrimExpr indice_rep_extent = ind_inv->InputShape().back(); // this is the size of rep_b
PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
Array<PrimExpr> fwd;
for (size_t i = 0; i < buffer->shape.size(); i++) {
fwd.push_back(InputPlaceholder(i));
}
fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
PrimExpr thd_b = loop_layout_->ForwardThread(
ind_inv->Forward(fwd), FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent, NullOpt)
->CondenseReplicateVar();
}
} // namespace tl
} // namespace tvm
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file tl/op/parallel.h
 * \brief Infer layouts from ops and parallel-for loops
*/
#ifndef TVM_TL_OP_PARALLEL_H_
#define TVM_TL_OP_PARALLEL_H_
#include <tvm/target/target.h>
#include <tvm/tir/stmt_functor.h>
#include "../layout/layout.h"
#include "op.h"
namespace tvm {
namespace tl {
using namespace tir;
class ParallelOp;
class ParallelLoopNestVisitor : public StmtExprVisitor {
private:
ParallelLoopNestVisitor(ParallelOp* op) : p(op){};
void VisitStmt_(const ForNode* op) final;
void VisitStmt_(const BufferStoreNode* op) final;
void VisitExpr_(const BufferLoadNode* op) final;
ParallelOp* p;
friend class ParallelOp;
};
class ParallelOp : public Operator {
public:
ParallelOp(For root);
LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level) final;
Fragment GetLoopLayout() const { return loop_layout_; }
For GetRoot() const { return root_; }
Map<Buffer, Array<PrimExpr>> GetIndiceMap() const { return indice_map_; }
Optional<PrimExpr> GetPredicate(Var thread_var) const;
private:
Fragment CompleteBufferFragment(const Buffer& buffer);
bool IsCommonAccessIndice(const Buffer& buffer) const;
void AddPredicate(PrimExpr expr) {
predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
}
For root_;
ParallelLoopNestVisitor V;
Map<Buffer, Array<PrimExpr>> indice_map_;
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_is_write_;
Array<IterVar> loop_vars_;
Fragment loop_layout_;
mutable arith::Analyzer analyzer_;
Optional<PrimExpr> predicate_;
friend class ParallelLoopNestVisitor;
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_PARALLEL_H_
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file tl/op/reduce.cc
*
* Define reduce operator.
*/
#include "reduce.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include "../layout/utils.h"
#include "../transform/loop_partition.h"
namespace tvm {
namespace tl {
using namespace tir;
ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
src = vmap[GetVarFromAccessPtr(args[0])];
dst = vmap[GetVarFromAccessPtr(args[1])];
String reduce_type = args[2].as<StringImm>().value()->value;
dim = args[3].as<IntImm>().value()->value;
if (reduce_type == "sum")
type = ReduceType::kSum;
else if (reduce_type == "abssum")
type = ReduceType::kAbsSum;
else if (reduce_type == "max")
type = ReduceType::kMax;
else if (reduce_type == "min")
type = ReduceType::kMin;
else
ICHECK(0) << "Unknown reduce type: " << reduce_type;
clear = args[4].as<Bool>().value();
}
PrimExpr ReduceOp::MakeInitValue() const {
switch (type) {
case ReduceType::kSum:
return make_zero(dst->dtype);
case ReduceType::kAbsSum:
return make_zero(dst->dtype);
case ReduceType::kMax:
return make_const(dst->dtype, -INFINITY);
case ReduceType::kMin:
return make_const(dst->dtype, INFINITY);
default:
ICHECK(0);
return PrimExpr(0);
}
}
PrimExpr ReduceOp::MakeReduce(const PrimExpr& a, const PrimExpr& b) const {
PrimExpr lhs = a, rhs = b;
if (lhs->dtype != rhs->dtype) {
rhs = Cast(lhs->dtype, rhs);
}
switch (type) {
case ReduceType::kSum:
return lhs + rhs;
case ReduceType::kAbsSum:
return lhs + Max(rhs, -rhs);
case ReduceType::kMax:
return Max(lhs, rhs);
case ReduceType::kMin:
return Min(lhs, rhs);
default:
ICHECK(0);
return PrimExpr(0);
}
}
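// Note: for kAbsSum the update is lhs + max(rhs, -rhs), i.e. the absolute
// value is applied per element during the thread-local pass, which is
// presumably why MakeCodegenReducer below can reuse tl::SumOp for the
// cross-thread step.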
std::string ReduceOp::MakeCodegenReducer() const {
switch (type) {
case ReduceType::kSum:
return "tl::SumOp";
case ReduceType::kAbsSum:
return "tl::SumOp";
case ReduceType::kMax:
return "tl::MaxOp";
case ReduceType::kMin:
return "tl::MinOp";
default:
ICHECK(0);
return "";
}
}
Stmt ReduceOp::Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
ICHECK(this->src.scope() == "local.fragment" && this->dst.scope() == "local.fragment")
<< "Reduce for shared memory not implemented.";
auto src_buffer = T.buffer_remap[this->src];
auto dst_buffer = T.buffer_remap[this->dst];
Fragment src_layout = T.layout_map[this->src].as<Fragment>().value();
Fragment dst_layout = T.layout_map[this->dst].as<Fragment>().value();
ICHECK(src_layout->InputDim() == dst_layout->InputDim() + 1);
Array<IterVar> dst_vars;
for (size_t i = 0; i < dst_layout->InputDim(); i++) {
Var var = Var(std::string{char('i' + i)});
dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var, IterVarType::kDataPar));
}
Array<IterVar> src_vars = dst_vars;
src_vars.insert(src_vars.begin() + this->dim, {Range(0, src_layout->InputShape()[this->dim]),
Var("rv"), IterVarType::kDataPar});
Array<PrimExpr> src_indices =
src_layout->Forward(src_vars.Map([](const auto& iv) { return PrimExpr(iv->var); }));
Array<PrimExpr> dst_indices =
dst_layout->Forward(dst_vars.Map([](const auto& iv) { return PrimExpr(iv->var); }));
Array<Stmt> stmts;
// make reduce-init stmt
if (this->clear) stmts.push_back(BufferStore(dst_buffer, this->MakeInitValue(), dst_indices));
// make thread-local reduce
Array<PrimExpr> src_indice_compressed;
Array<IterVar> src_var_compressed;
for (size_t i = 0; i < src_layout->OutputDim(); i++) {
PrimExpr expr;
IterVar var;
std::tie(expr, var) =
CompressIterator(src_indices[i], src_vars, src_vars[this->dim]->var, analyzer);
src_indice_compressed.push_back(expr);
src_var_compressed.push_back(var);
}
Stmt reduce_local = BufferStore(dst_buffer,
this->MakeReduce(BufferLoad(dst_buffer, dst_indices),
BufferLoad(src_buffer, src_indice_compressed)),
dst_indices);
for (int i = src_layout->OutputDim() - 1; i >= 0; i--) {
reduce_local =
For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent, ForKind::kUnrolled,
reduce_local, NullOpt, {{tir::attr::pragma_unroll_explicit, Bool(false)}});
}
stmts.push_back(reduce_local);
// make inter-thread reduce
PrimExpr src_thread =
src_layout->ForwardThread(src_vars.Map([](const auto& iv) { return PrimExpr(iv->var); }), {});
auto iter_sum = arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer);
for (const auto& iter_split : iter_sum->args) {
auto mark = iter_split->source->source.as<Var>();
ICHECK(mark.defined());
if (mark.value().same_as(src_vars[this->dim]->var)) {
auto scale = as_const_int(iter_split->scale);
auto extent = as_const_int(iter_split->extent);
ICHECK(scale != nullptr && extent != nullptr);
if (*extent == 1) continue;
int reducing_threads = (*extent) * (*scale);
std::stringstream ss;
ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", " << reducing_threads << ", "
<< (*scale) << ">::run";
Array<PrimExpr> thread_reduce_args = {StringImm(ss.str()),
BufferLoad(dst_buffer, dst_indices)};
if (reducing_threads >= 32) {
PrimExpr workspace = T.AddWorkspace(T.block_size, dst_buffer->dtype);
thread_reduce_args.push_back(workspace);
}
auto call = Call(dst_buffer->dtype, builtin::call_extern(), thread_reduce_args);
stmts.push_back(BufferStore(dst_buffer, call, dst_indices));
}
}
Stmt reduce_interthread =
BufferStore(dst_buffer, BufferLoad(dst_buffer, dst_indices), dst_indices);
// make the outer spatial loop
Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0];
for (int i = dst_layout->InputDim() - 1; i >= 0; i--) {
body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent, ForKind::kParallel, body);
}
body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer, dst_layout);
return body;
}
LayoutMap ReduceOp::InferLayout(const LayoutInferArgs& T, InferLevel level) {
if (level >= InferLevel::kStrict) return {};
if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
T.layout_map.count(src) && !T.layout_map.count(dst)) {
auto src_layout = T.layout_map[src].as<Fragment>().value();
PrimExpr indice_rep_extent = src->shape[dim];
PrimExpr src_rep_extent = src_layout->ReplicateExtent();
PrimExpr dest_buffer_rep_extent = indice_rep_extent * src_rep_extent;
Array<PrimExpr> fwd;
for (int i = 0; i < static_cast<int>(src->shape.size()); i++) {
if (i == dim) {
fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent));
} else if (i < dim) {
fwd.push_back(InputPlaceholder(i));
} else if (i > dim) {
fwd.push_back(InputPlaceholder(i - 1));
}
}
auto thd =
src_layout->ForwardThread(fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
Fragment dst_layout =
Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, NullOpt)->CondenseReplicateVar();
return {{dst, dst_layout}};
}
return {};
}
TIR_REGISTER_TL_OP(ReduceOp, reduce)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file tl/op/reduce.h
* \brief Define reduce operator.
*
*/
#ifndef TVM_TL_OP_REDUCE_H_
#define TVM_TL_OP_REDUCE_H_
#include "op.h"
namespace tvm {
namespace tl {
using namespace tir;
class ReduceOp : public Operator {
public:
ReduceOp(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level) final;
static const Op& Get();
private:
tir::Buffer src, dst;
int dim;
enum class ReduceType {
kSum,
kAbsSum,
kMax,
kMin,
} type;
bool clear;
PrimExpr MakeInitValue() const;
PrimExpr MakeReduce(const PrimExpr& a, const PrimExpr& b) const;
std::string MakeCodegenReducer() const;
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_REDUCE_H_
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
 * \file tl/runtime/runtime.cc
* \brief Runtime functions.
*
*/
#include "runtime.h"
#include "../target/cuda.h"
#include <tvm/runtime/registry.h>
namespace tvm {
namespace tl {
using namespace runtime;
template <typename T>
static std::string ArrayToStr(const T* ptr, size_t n) {
std::stringstream ss;
ss << "[";
for (size_t i = 0; i < n; i++) {
if (i > 0) ss << ", ";
ss << ptr[i];
}
ss << "]";
return ss.str();
}
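// Example: with int64_t dims[3] = {2, 3, 4}, ArrayToStr(dims, 3) yields "[2, 3, 4]".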
struct TensorMapArgs {
CUtensorMap* map;
CUtensorMapDataType type;
cuuint32_t tensorRank;
void* globalAddress;
cuuint64_t globalDim[5], globalStride[5];
cuuint32_t boxDim[5], elementStrides[5];
CUtensorMapInterleave interleave;
CUtensorMapSwizzle swizzle;
CUtensorMapL2promotion l2Promotion;
CUtensorMapFloatOOBfill oobFill;
static TensorMapArgs Extract(TVMArgs args) {
TensorMapArgs T;
int idx = 0;
ICHECK(args.num_args >= 8);
T.map = reinterpret_cast<CUtensorMap*>(static_cast<void*>(args[idx++]));
T.type = static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++]));
T.tensorRank = static_cast<cuuint32_t>(static_cast<int64_t>(args[idx++]));
T.globalAddress = args[idx++];
ICHECK(T.tensorRank >= 1 && T.tensorRank <= 5);
ICHECK(args.num_args == static_cast<int>(8 + T.tensorRank * 4));
for (size_t i = 0; i < T.tensorRank; i++) {
T.globalDim[i] = static_cast<cuuint64_t>(args[idx++]);
}
for (size_t i = 0; i < T.tensorRank; i++) {
T.globalStride[i] = static_cast<cuuint64_t>(args[idx++]);
}
for (size_t i = 0; i < T.tensorRank; i++) {
T.boxDim[i] = static_cast<cuuint64_t>(args[idx++]);
}
for (size_t i = 0; i < T.tensorRank; i++) {
T.elementStrides[i] = static_cast<cuuint64_t>(args[idx++]);
}
T.interleave = static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++]));
T.swizzle = static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
T.l2Promotion = static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++]));
T.oobFill = static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++]));
return T;
}
std::string ToDebugString() {
std::stringstream ss;
ss << "TMA Desc Addr: " << map << std::endl
<< "format " << type << std::endl
<< "dim " << tensorRank << std::endl
<< "gmem_address " << globalAddress << std::endl
<< "globalDim " << ArrayToStr(globalDim, tensorRank) << std::endl
<< "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl
<< "boxDim " << ArrayToStr(boxDim, tensorRank) << std::endl
<< "elementStrides " << ArrayToStr(elementStrides, tensorRank) << std::endl
<< "interleave " << interleave << std::endl
<< "swizzle " << swizzle << std::endl
<< "l2Promotion " << l2Promotion << std::endl
<< "oobFill " << oobFill << std::endl;
return ss.str();
}
};
// Register host-side packed functions that encode CUDA tensor map (TMA) descriptors.
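// Expected packed-call layout (8 + 4*rank arguments): map, dtype, rank,
// global address, then rank entries each for globalDim, globalStride, boxDim
// and elementStrides, followed by interleave, swizzle, l2Promotion, oobFill.
// cuTensorMapEncodeTiled takes tensorRank - 1 strides (the innermost stride is
// implicit), hence the globalStride + 1 below.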
TVM_REGISTER_GLOBAL(tvm_tensormap_create_tiled).set_body([](TVMArgs args, TVMRetValue* ret) {
TensorMapArgs T = TensorMapArgs::Extract(args);
CUresult result = cuTensorMapEncodeTiled(
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim, T.globalStride + 1, T.boxDim,
T.elementStrides, T.interleave, T.swizzle, T.l2Promotion, T.oobFill);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to initialize the TMA descriptor " << result << std::endl
<< T.ToDebugString();
}
*ret = static_cast<int>(result);
});
struct TensorMapIm2ColArgs {
CUtensorMap* map;
CUtensorMapDataType type;
cuuint32_t tensorRank;
void* globalAddress;
cuuint64_t globalDim[5], globalStride[5];
cuuint32_t elementStrides[5];
int pixelBoxLowerCorner[3], pixelBoxUpperCorner[3];
cuuint32_t smem_box_channel, smem_box_pixel;
CUtensorMapInterleave interleave;
CUtensorMapSwizzle swizzle;
CUtensorMapL2promotion l2Promotion;
CUtensorMapFloatOOBfill oobFill;
static TensorMapIm2ColArgs Extract(TVMArgs args) {
TensorMapIm2ColArgs T;
int idx = 0;
ICHECK(args.num_args >= 8);
T.map = reinterpret_cast<CUtensorMap*>(static_cast<void*>(args[idx++]));
T.type = static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++]));
T.tensorRank = static_cast<cuuint32_t>(static_cast<int64_t>(args[idx++]));
T.globalAddress = args[idx++];
ICHECK(T.tensorRank >= 3 && T.tensorRank <= 5);
ICHECK(args.num_args == static_cast<int>(6 + T.tensorRank * 5));
for (size_t i = 0; i < T.tensorRank; i++) {
T.globalDim[i] = static_cast<cuuint64_t>(args[idx++]);
}
for (size_t i = 0; i < T.tensorRank; i++) {
T.globalStride[i] = static_cast<cuuint64_t>(args[idx++]);
}
for (size_t i = 0; i < T.tensorRank; i++) {
T.elementStrides[i] = static_cast<cuuint64_t>(args[idx++]);
}
for (size_t i = 0; i < T.tensorRank - 2; i++) {
T.pixelBoxLowerCorner[i] = static_cast<int>(args[idx++]);
}
for (size_t i = 0; i < T.tensorRank - 2; i++) {
T.pixelBoxUpperCorner[i] = static_cast<int>(args[idx++]);
}
T.smem_box_pixel = static_cast<cuuint64_t>(args[idx++]);
T.smem_box_channel = static_cast<cuuint64_t>(args[idx++]);
T.interleave = static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++]));
T.swizzle = static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
T.l2Promotion = static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++]));
T.oobFill = static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++]));
return T;
}
std::string ToDebugString() {
std::stringstream ss;
ss << "TMA Desc Addr: " << map << std::endl
<< "format " << type << std::endl
<< "dim " << tensorRank << std::endl
<< "gmem_address " << globalAddress << std::endl
<< "globalDim " << ArrayToStr(globalDim, tensorRank) << std::endl
<< "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl
<< "smem_box_pixel " << smem_box_pixel << std::endl
<< "smem_box_channel " << smem_box_channel << std::endl
<< "pixelBoxLowerCorner " << ArrayToStr(pixelBoxLowerCorner, tensorRank - 2) << std::endl
<< "pixelBoxUpperCorner " << ArrayToStr(pixelBoxUpperCorner, tensorRank - 2) << std::endl
<< "elementStrides " << ArrayToStr(elementStrides, tensorRank) << std::endl
<< "interleave " << interleave << std::endl
<< "swizzle " << swizzle << std::endl
<< "l2Promotion " << l2Promotion << std::endl
<< "oobFill " << oobFill << std::endl;
return ss.str();
}
};
TVM_REGISTER_GLOBAL(tvm_tensormap_create_im2col).set_body([](TVMArgs args, TVMRetValue* ret) {
TensorMapIm2ColArgs T = TensorMapIm2ColArgs::Extract(args);
CUresult result = cuTensorMapEncodeIm2col(
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim, T.globalStride + 1,
T.pixelBoxLowerCorner, T.pixelBoxUpperCorner, T.smem_box_channel, T.smem_box_pixel,
T.elementStrides, T.interleave, T.swizzle, T.l2Promotion, T.oobFill);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to initialize the TMA descriptor " << result << std::endl
<< T.ToDebugString();
}
*ret = static_cast<int>(result);
});
} // namespace tl
} // namespace tvm