"src/transform/atomicadd_vectorize.h" did not exist on "c5df7938902e68a835d8423fdc08753fb1834a6b"
Unverified Commit 407117e1 authored by Lei Wang, committed by GitHub

[Layout] Introduce a new layout inference mechanism (#699)



* Implement new free stage layout inference.

* Fix bug

* Make replication upcasting and unnormalizable iterators safe.

* Better handling of updating with more replicas

* Remove unnecessary check.

* Fix compilation.

* Fix setup.py.

* Simplify development mode.

* Allow ParallelOp layout when there's already a compatible layout specified

* lint fix

* Add ProveFragmentContains function to validate thread access between small and large fragments

This function checks if the threads accessing elements of a smaller fragment are a subset of those accessing a larger fragment, ensuring valid access during updates. The implementation includes deriving thread indices, computing logical indices, and verifying thread mappings.
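For intuition, here is a minimal standalone sketch of that subset check. It is illustrative only: `SimpleFragment`, `FragmentContains`, and the exhaustive enumeration are assumptions of this sketch, not the in-tree implementation, which works symbolically on `Fragment` layouts with `arith::Analyzer`.

```cpp
#include <cassert>
#include <functional>
#include <set>

// Toy stand-in for a fragment layout: maps (element, replica) -> thread id.
struct SimpleFragment {
  int num_elems;                          // logical element domain [0, num_elems)
  int num_replicas;                       // replication extent
  std::function<int(int, int)> thread_of; // (elem, replica) -> thread id
};

// Conceptual ProveFragmentContains: for every element, the threads that touch
// it under the small fragment must be a subset of the threads that touch it
// under the large fragment. Here we enumerate; the real pass proves this
// symbolically, so it also handles non-constant extents.
bool FragmentContains(const SimpleFragment &small_frag,
                      const SimpleFragment &large_frag) {
  assert(small_frag.num_elems == large_frag.num_elems);
  for (int e = 0; e < small_frag.num_elems; ++e) {
    std::set<int> small_threads, large_threads;
    for (int r = 0; r < small_frag.num_replicas; ++r)
      small_threads.insert(small_frag.thread_of(e, r));
    for (int r = 0; r < large_frag.num_replicas; ++r)
      large_threads.insert(large_frag.thread_of(e, r));
    for (int t : small_threads)
      if (!large_threads.count(t))
        return false; // a thread of the small fragment is missing
  }
  return true;
}

int main() {
  // 4 elements spread over threads 0..3, no replication.
  SimpleFragment plain{4, 1, [](int e, int) { return e; }};
  // Same mapping, but each element is also replicated onto threads 4..7.
  SimpleFragment replicated{4, 2, [](int e, int r) { return e + 4 * r; }};
  assert(FragmentContains(plain, replicated));  // upcasting replication is safe
  assert(!FragmentContains(replicated, plain)); // the reverse is not
  return 0;
}
```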

* Update dependencies in requirements files

* Remove 'thefuzz' from requirements-dev.txt
* Specify exact versions for 'torch' and add 'flash_attn' in requirements-test.txt

* Update CI workflow to use SHA256 hash for requirements file

* Update requirements and CI workflow for flash attention

* Removed specific version for 'torch' in requirements-test.txt
* Added installation of 'flash_attn==2.5.8' in CI workflow to ensure compatibility

* Refactor flash attention import handling in examples

* Removed availability checks for 'flash_attn' in multiple example scripts.
* Simplified import statements for 'flash_attn' to ensure consistent usage across examples.

---------
Co-authored-by: Huanqi Cao <caohuanqi@deepseek.com>
parent 87aae294
@@ -17,6 +17,20 @@ namespace tl {
using namespace tir;
class LayoutConflictException : public std::exception {
public:
const char *what() const noexcept override { return msg_.c_str(); }
LayoutConflictException(const std::string &msg) : msg_(msg) {}
private:
std::string msg_;
};
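// Per the commit description: returns true if, for matching logical indices,
// the set of threads accessing the smaller fragment is a subset of the threads
// accessing the larger fragment, so that updating to the larger layout keeps
// every existing access valid.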
bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
Array<PrimExpr> small_frag_indices,
Array<PrimExpr> large_frag_indices,
arith::Analyzer &analyzer_);
class ParallelOp;
class ParallelLoopNestVisitor : public StmtExprVisitor {
@@ -36,6 +50,14 @@ public:
ParallelOp(For root);
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
ParallelOp(const ParallelOp &other) : ParallelOp(other.root_) {
loop_layout_ = other.loop_layout_;
predicate_ = other.predicate_;
}
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<ParallelOp>(*this);
}
Fragment GetLoopLayout() const { return loop_layout_; }
For GetRoot() const { return root_; }
Map<Buffer, Array<PrimExpr>> GetIndiceMap() const { return indice_map_; }
......
@@ -12,6 +12,7 @@
#include <tvm/tir/stmt_functor.h>
#include "../layout/utils.h"
#include "../op/parallel.h"
#include "../transform/loop_partition.h"
#include "tir/transforms/ir_utils.h"
@@ -287,7 +288,7 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
if (level >= InferLevel::kStrict)
return {};
if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
- T.layout_map.count(src) && !T.layout_map.count(dst)) {
T.layout_map.count(src)) {
auto src_layout = T.layout_map[src].as<Fragment>().value();
PrimExpr indice_rep_extent = src->shape[dim];
@@ -310,7 +311,46 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt)
->CondenseReplicateVar()
->BindThreadRange(T.thread_bounds);
- return {{dst, dst_layout}};
if (!T.layout_map.count(dst))
return {{dst, dst_layout}};
else {
// Check if the computed layout is compatible with the existing one: the
// computed layout must contain the existing one (see ProveFragmentContains)
auto orig_dst_layout =
T.layout_map.Get(dst).value().as<Fragment>().value();
ICHECK(dst_layout->InputDim() == orig_dst_layout->InputDim());
Array<PrimExpr> indices;
indices.reserve(dst_layout->InputDim());
arith::Analyzer inner_analyzer;
for (int i = 0; i < dst_layout->InputDim(); ++i) {
auto x = InputPlaceholder(i);
indices.push_back(x);
// shapes are constants, so literal - literal simplifies to 0 under any analyzer
ICHECK(is_zero(inner_analyzer.Simplify(
dst_layout->InputShape()[i] - orig_dst_layout->InputShape()[i])));
inner_analyzer.Bind(x, Range(0, dst_layout->InputShape()[i]));
}
ICHECK(as_const_int(dst_layout->ReplicateExtent()));
ICHECK(as_const_int(src_layout->ReplicateExtent()));
auto dst_rep = *as_const_int(dst_layout->ReplicateExtent());
auto src_rep = *as_const_int(src_layout->ReplicateExtent());
if (dst_rep < src_rep ||
!ProveFragmentContains(orig_dst_layout, dst_layout, indices, indices,
inner_analyzer)) {
std::ostringstream oss;
oss << "Layout may conflict with ReduceOp for buffer " << dst << " vs. "
<< src << "\nLHS = " << src_layout->DebugOutput()
<< "\nRHS = " << orig_dst_layout->DebugOutput()
<< "\nYou may need to use a shared memory to transform the "
"layout";
throw LayoutConflictException(oss.str());
}
if (dst_rep > src_rep) {
return {{dst, dst_layout}};
}
}
}
return {};
}
......
@@ -21,6 +21,10 @@ public:
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
static const Op &Get();
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<ReduceOp>(*this);
}
private:
tir::Buffer src, dst;
int dim;
@@ -45,6 +49,10 @@ public:
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
static const Op &Get();
std::unique_ptr<Operator> Clone() const final {
return std::make_unique<CumSumOp>(*this);
}
private:
tir::Buffer src, dst;
int dim;
......
#ifndef TVM_TL_TRANSFORM_COMMON_UNION_FIND_H_
#define TVM_TL_TRANSFORM_COMMON_UNION_FIND_H_
#include <unordered_map>
#include <vector>
namespace tvm {
namespace tl {
template <typename T> class UnionFind {
public:
void MakeSet(const T &x) {
if (parent_.find(x) == parent_.end()) {
parent_[x] = x;
rank_[x] = 0;
}
}
T Find(const T &x) {
if (parent_[x] != x) {
parent_[x] = Find(parent_[x]); // Path compression
}
return parent_[x];
}
void Union(const T &x, const T &y) {
T x_root = Find(x);
T y_root = Find(y);
if (x_root == y_root)
return;
// Union by rank
if (rank_[x_root] < rank_[y_root]) {
parent_[x_root] = y_root;
} else if (rank_[x_root] > rank_[y_root]) {
parent_[y_root] = x_root;
} else {
parent_[y_root] = x_root;
rank_[x_root]++;
}
}
private:
std::unordered_map<T, T> parent_;
std::unordered_map<T, int> rank_;
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_TRANSFORM_COMMON_UNION_FIND_H_
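As a quick illustration of how the inference pass uses this helper to group operator indices into connected components, here is a hypothetical standalone usage sketch (not part of the commit):

```cpp
#include <cassert>
#include "common/union_find.h" // header added by this commit

int main() {
  tvm::tl::UnionFind<int> uf;
  for (int i = 0; i < 4; ++i)
    uf.MakeSet(i); // one set per operator index
  uf.Union(0, 1);  // ops 0 and 1 access the same buffer
  uf.Union(1, 2);  // ops 1 and 2 share another buffer
  assert(uf.Find(0) == uf.Find(2)); // 0, 1, 2 form one component
  assert(uf.Find(0) != uf.Find(3)); // op 3 is its own component
  return 0;
}
```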
@@ -13,11 +13,13 @@
#include <queue>
#include "../layout/utils.h"
#include "../op/parallel.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_fusion_utils.h"
#include "common/loop_parallel_transform_utils.h"
#include "common/union_find.h"
#include "loop_partition.h"
#include "loop_vectorize.h"
#include "runtime/thread_storage_scope.h"
@@ -60,6 +62,131 @@ public:
BufferUseDefCollector(bool skip_thread_partition)
: skip_thread_partition_(skip_thread_partition) {}
void RunInferStep(int cur_infer_id, InferLevel level, bool update_queue,
LayoutMap &layout_map, const LayoutMap &strict_layout_map,
std::queue<int> &q, std::vector<bool> &in_queue) {
auto num_infer = infer_list_.size();
// Range check for cur_infer_id
ICHECK_GE(cur_infer_id, 0) << "cur_infer_id is negative, which is invalid.";
ICHECK_LT(cur_infer_id, num_infer)
<< "cur_infer_id " << cur_infer_id << " is out of range, must be < "
<< num_infer << ".";
// Make sure we can safely access infer_list_[cur_infer_id] and
// thread_var_vec_[cur_infer_id]
auto &next = infer_list_[cur_infer_id];
auto iter_var = thread_var_vec_[cur_infer_id];
auto thread_bounds = thread_bounds_vec_[cur_infer_id];
// Double-check that 'next' is valid
ICHECK(next != nullptr)
<< "infer_list_[" << cur_infer_id << "] is null inside run_infer_step.";
// Check iter_var->dom and dom->extent
ICHECK(iter_var.defined())
<< "thread_var_vec_[" << cur_infer_id << "] is not defined.";
ICHECK(iter_var->dom.defined())
<< "iter_var->dom is not defined for infer_list_[" << cur_infer_id
<< "].";
ICHECK(iter_var->dom->extent.defined())
<< "iter_var->dom->extent is not defined for infer_list_["
<< cur_infer_id << "].";
const int64_t *extent_ptr = as_const_int(iter_var->dom->extent);
ICHECK(extent_ptr != nullptr)
<< "iter_var->dom->extent is not a constant integer, which is "
"required for layout inference.";
// Run InferLayout
auto updates = next->InferLayout(
LayoutInferArgs{target_, thread_bounds, layout_map}, level);
// Process the returned updates
for (const auto &[buffer, layout] : updates) {
// Basic validity checks
ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";
if (layout_map.count(buffer)) {
// If new layout contains the old one, update map
if (buffer.scope() == "local.fragment" &&
level != InferLevel::kStrict && !strict_layout_map.count(buffer)) {
// Actually this test has been done in ParallelOp::InferLayout
// already. Just do it again to avoid missing implementations in other
// `Operator`s.
auto dst_layout = layout.as<Fragment>().value();
auto src_layout = layout_map[buffer].as<Fragment>().value();
ICHECK(dst_layout->InputDim() == src_layout->InputDim());
Array<PrimExpr> indices;
indices.reserve(dst_layout->InputDim());
arith::Analyzer inner_analyzer;
for (int i = 0; i < dst_layout->InputDim(); ++i) {
auto x = InputPlaceholder(i);
indices.push_back(x);
// shapes are constants, so literal - literal simplifies to 0 under any analyzer
ICHECK(is_zero(inner_analyzer.Simplify(
dst_layout->InputShape()[i] - src_layout->InputShape()[i])));
inner_analyzer.Bind(x, Range(0, dst_layout->InputShape()[i]));
}
if (ProveFragmentContains(src_layout, dst_layout, indices, indices,
inner_analyzer)) {
layout_map.Set(buffer, layout);
continue;
}
}
// If already in map, ensure they are structurally equal
ICHECK(StructuralEqual()(layout, layout_map[buffer]))
<< "Get different layout for " << buffer
<< "\n current layout: " << layout->DebugOutput()
<< "\n previous layout: " << layout_map[buffer]->DebugOutput();
} else {
// Otherwise, update map
layout_map.Set(buffer, layout);
if (!update_queue)
continue;
// Check if buffer exists in use_list_
if (!use_list_.count(buffer)) {
LOG(WARNING) << "Layout inference failed for buffer " << buffer
<< ". "
<< "The buffer cannot be inferred with current layout "
"inference rules.";
continue;
}
// Push back into BFS queue
for (int idx : use_list_[buffer]) {
ICHECK_GE(idx, 0)
<< "Index in use_list_ for buffer " << buffer << " is negative.";
ICHECK_LT(idx, num_infer)
<< "Index in use_list_ for buffer " << buffer
<< " out of range: " << idx << " >= " << num_infer << ".";
if (!in_queue[idx] && idx != cur_infer_id) {
in_queue[idx] = true;
q.push(idx);
}
}
}
}
};
void FinishInferQueue(InferLevel level, LayoutMap &layout_map,
const LayoutMap &strict_layout_map, std::queue<int> &q,
std::vector<bool> &in_queue) {
auto num_infer = infer_list_.size();
while (!q.empty()) {
int cur_infer_id = q.front();
q.pop();
// Range check again, just to be safe
ICHECK_GE(cur_infer_id, 0);
ICHECK_LT(cur_infer_id, num_infer);
in_queue[cur_infer_id] = false;
RunInferStep(cur_infer_id, level, true, layout_map, strict_layout_map, q,
in_queue);
}
};
LayoutInferenceResult Run() {
// Basic consistency check: infer_list_ and thread_var_vec_ should have the
// same size
@@ -94,124 +221,10 @@ public:
q.push(i);
}
- auto run_infer_step = [&](int cur_infer_id, InferLevel level,
-                           bool update_queue) {
-   // Range check for cur_infer_id
-   ICHECK_GE(cur_infer_id, 0)
-       << "cur_infer_id is negative, which is invalid.";
-   ICHECK_LT(cur_infer_id, num_infer)
-       << "cur_infer_id " << cur_infer_id << " is out of range, must be < "
-       << num_infer << ".";
-   // Make sure we can safely access infer_list_[cur_infer_id] and
-   // thread_var_vec_[cur_infer_id]
-   auto &next = infer_list_[cur_infer_id];
-   auto iter_var = thread_var_vec_[cur_infer_id];
-   auto thread_bounds = thread_bounds_vec_[cur_infer_id];
-   // Double-check that 'next' is valid
-   ICHECK(next != nullptr) << "infer_list_[" << cur_infer_id
-                           << "] is null inside run_infer_step.";
-   // Check iter_var->dom and dom->extent
-   ICHECK(iter_var.defined())
-       << "thread_var_vec_[" << cur_infer_id << "] is not defined.";
-   ICHECK(iter_var->dom.defined())
-       << "iter_var->dom is not defined for infer_list_[" << cur_infer_id
-       << "].";
-   ICHECK(iter_var->dom->extent.defined())
-       << "iter_var->dom->extent is not defined for infer_list_["
-       << cur_infer_id << "].";
-   const int64_t *extent_ptr = as_const_int(iter_var->dom->extent);
-   ICHECK(extent_ptr != nullptr)
-       << "iter_var->dom->extent is not a constant integer, which is "
-          "required for layout inference.";
-   // Run InferLayout
-   auto updates = next->InferLayout(
-       LayoutInferArgs{target_, thread_bounds, layout_map}, level);
-   // Process the returned updates
-   for (const auto &[buffer, layout] : updates) {
-     // Basic validity checks
-     ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
-     ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";
-     if (layout_map.count(buffer)) {
-       // If replicate size of this buffer is greater than the old one
-       if (buffer.scope() == "local.fragment" &&
-           level != InferLevel::kStrict) {
-         const FragmentNode *dst_layout = layout.as<FragmentNode>();
-         const FragmentNode *src_layout =
-             layout_map[buffer].as<FragmentNode>();
-         if (as_const_int(dst_layout->ReplicateExtent()) &&
-             as_const_int(src_layout->ReplicateExtent()) &&
-             (*as_const_int(dst_layout->ReplicateExtent()) >
-              *as_const_int(src_layout->ReplicateExtent()))) {
-           // update map
-           layout_map.Set(buffer, layout);
-           continue;
-         }
-       }
-       // If already in map, ensure they are structurally equal
-       // (zhengju) We can not modify the strict layout map when current
-       // level is not strict. This check should be done in certain
-       // conditions, since the strict layout map is not updated in the
-       // above code when current level is not strict
-       if (level == InferLevel::kStrict ||
-           !strict_layout_map.count(buffer)) {
-         ICHECK(StructuralEqual()(layout, layout_map[buffer]))
-             << "Get different layout for " << buffer
-             << "\n current layout: " << layout->DebugOutput()
-             << "\n previous layout: " << layout_map[buffer]->DebugOutput();
-       }
-     } else {
-       // Otherwise, update map
-       layout_map.Set(buffer, layout);
-       if (!update_queue)
-         continue;
-       // Check if buffer exists in use_list_
-       if (!use_list_.count(buffer)) {
-         LOG(WARNING) << "Layout inference failed for buffer " << buffer
-                      << ". "
-                      << "The buffer cannot be inferred with current layout "
-                         "inference rules.";
-         continue;
-       }
-       // Push back into BFS queue
-       for (int idx : use_list_[buffer]) {
-         ICHECK_GE(idx, 0) << "Index in use_list_ for buffer " << buffer
-                           << " is negative.";
-         ICHECK_LT(idx, num_infer)
-             << "Index in use_list_ for buffer " << buffer
-             << " out of range: " << idx << " >= " << num_infer << ".";
-         if (!in_queue[idx] && idx != cur_infer_id) {
-           in_queue[idx] = true;
-           q.push(idx);
-         }
-       }
-     }
-   }
- };
- auto finish_infer_queue = [&]() {
-   while (!q.empty()) {
-     int cur_infer_id = q.front();
-     q.pop();
-     // Range check again, just to be safe
-     ICHECK_GE(cur_infer_id, 0);
-     ICHECK_LT(cur_infer_id, num_infer);
-     in_queue[cur_infer_id] = false;
-     run_infer_step(cur_infer_id, InferLevel::kCommon, true);
-   }
- };
// step 1: infer strict layout
for (int i = 0; i < num_infer; i++) {
- run_infer_step(i, InferLevel::kStrict, false);
RunInferStep(i, InferLevel::kStrict, false, layout_map, strict_layout_map,
q, in_queue);
}
for (const auto &[buffer, layout] : layout_map) {
@@ -219,13 +232,12 @@ public:
}
// step 2: infer common layout with BFS
- finish_infer_queue();
FinishInferQueue(InferLevel::kCommon, layout_map, strict_layout_map, q,
in_queue);
// step 3: relax constraints to free and re-run
- for (int i = 0; i < num_infer; i++) {
-   run_infer_step(i, InferLevel::kFree, true);
-   finish_infer_queue();
- }
InferInFreeMode(layout_map, strict_layout_map);
// Check that all local.fragment buffers have inferred layouts
for (const auto &[buffer, _] : use_list_) {
if (buffer.scope() == "local.fragment") {
@@ -291,6 +303,7 @@ private:
addToUseList(buffer.value());
}
}
infer_list_stmt_.push_back(GetRef<ObjectRef>(op));
infer_list_.push_back(std::move(p));
thread_var_vec_.push_back(thread_var_);
if (analyzer_.const_int_bound.IsBound(thread_var_->var)) {
@@ -309,9 +322,14 @@ private:
Optional<Buffer> getBufferFromAccessPtr(const PrimExpr &expr) {
auto call = expr.as<CallNode>();
- if (call && call->op.same_as(builtin::tvm_access_ptr())) {
if (!call) {
return std::nullopt;
}
if (call->op.same_as(builtin::tvm_access_ptr())) {
auto var = call->args[1].as<Var>().value();
return buffer_data_to_buffer_[var];
} else if (call->op.same_as(RegionOp::Get())) {
return call->args[0].as<BufferLoadNode>()->buffer;
}
return std::nullopt;
}
@@ -330,6 +348,7 @@ private:
for (const auto &[buffer, _] : infer->GetIndiceMap()) {
addToUseList(buffer);
}
infer_list_stmt_.push_back(GetRef<ObjectRef>(op));
infer_list_.push_back(std::move(infer));
thread_var_vec_.push_back(thread_var_);
if (thread_var_.defined() &&
@@ -379,6 +398,7 @@ private:
}
Map<Var, Buffer> buffer_data_to_buffer_;
std::vector<ObjectRef> infer_list_stmt_;
std::vector<std::unique_ptr<Operator>> infer_list_;
std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
use_list_;
@@ -391,6 +411,122 @@ private:
Target target_;
LayoutMap annotated_layout_map_;
bool skip_thread_partition_{false};
std::vector<std::unique_ptr<Operator>> BackupInferList() {
std::vector<std::unique_ptr<Operator>> back_infer_list;
back_infer_list.reserve(infer_list_.size());
for (auto &&p : infer_list_) {
back_infer_list.push_back(p->Clone());
}
return back_infer_list;
}
void InferInFreeMode(LayoutMap &layout_map,
const LayoutMap &strict_layout_map) {
// Group operators into connected components
UnionFind<int> uf;
for (int i = 0; i < infer_list_.size(); i++) {
uf.MakeSet(i);
}
for (const auto &[buffer, infer_indices] : use_list_) {
if (infer_indices.empty())
continue;
// Union all infer_list_ indices that share the same buffer
int first_idx = infer_indices[0];
for (size_t i = 1; i < infer_indices.size(); i++) {
uf.Union(first_idx, infer_indices[i]);
}
}
std::unordered_map<int, std::vector<int>> components;
for (int i = 0; i < infer_list_.size(); i++) {
int root = uf.Find(i);
components[root].push_back(i);
}
std::unordered_map<int, std::vector<Buffer>> components_buffers;
for (const auto &[buffer, infer_indices] : use_list_) {
int root = uf.Find(infer_indices[0]);
components_buffers[root].push_back(buffer);
}
// For each component, try each op as root, and determine the least
// replicated one
std::queue<int> q;
std::vector<bool> in_queue(infer_list_.size(), false);
for (auto &&[root, members] : components) {
decltype(infer_list_) best_infer_list;
LayoutMap best_layout_map;
int64_t min_reg_num = INT64_MAX;
for (int attempt_infer_root : members) {
// backup infer_list_ in class member
auto back_infer_list = BackupInferList();
// create temporarily used layout_map, new handle so that it copies on
// write
LayoutMap tmp_layout_map = layout_map;
// infer from attempt_infer_root in free mode
bool do_update = true;
try {
RunInferStep(attempt_infer_root, InferLevel::kFree, true,
tmp_layout_map, strict_layout_map, q, in_queue);
FinishInferQueue(InferLevel::kFree, tmp_layout_map, strict_layout_map,
q, in_queue);
// Silly workaround: we have no clue whether a single root will reach
// the entire component, since the InferLayout implementations have
// complicated conditioning inside and we know nothing about it. That
// would constantly leave buffers in this component with incomplete
// layouts. Instead of trying all root-selection orders, we simply go
// through all other loops in order after the first search from
// attempt_infer_root.
for (int other_infer_root : members) {
if (other_infer_root != attempt_infer_root) {
RunInferStep(other_infer_root, InferLevel::kFree, true,
tmp_layout_map, strict_layout_map, q, in_queue);
// must also be kFree here to avoid conflicts.
FinishInferQueue(InferLevel::kFree, tmp_layout_map,
strict_layout_map, q, in_queue);
}
}
} catch (const LayoutConflictException &e) {
// this ordering fails; try another
do_update = false;
} catch (const NormalizeIterException &e) {
// this ordering hits an iterator that cannot be normalized
// (e.g. i * 576 % 2048); try another
do_update = false;
}
if (do_update) {
// compute total register number
int64_t reg_num = 0;
for (auto &&[buffer, layout] : tmp_layout_map) {
if (auto frag = layout.as<Fragment>()) {
int64_t frag_reg_num = 1;
for (auto i : frag.value()->OutputShape()) {
auto pci = as_const_int(i);
ICHECK(pci != nullptr);
frag_reg_num *= *pci;
}
reg_num += frag_reg_num;
}
}
// if it's any better, update the best_* storage
if (reg_num < min_reg_num) {
best_infer_list = std::move(infer_list_);
best_layout_map = tmp_layout_map;
min_reg_num = reg_num;
}
}
// restore the stateful infer_list_ before trying the next root
infer_list_ = std::move(back_infer_list);
}
if (min_reg_num < INT64_MAX) {
// now apply the best plan for this component
infer_list_ = std::move(best_infer_list);
layout_map = best_layout_map;
}
}
}
};
class LayoutInferencer : public IRMutatorWithAnalyzer {
......