Unverified Commit 6654064d authored by Lei Wang, committed by GitHub

[Layout] Enhance Free Layout Inference (#1375)

* [Refactor] Update condition for benchmarking in example_gemv.py and simplify cached library path handling in sparse.py

* [Enhancement] Extend support for float8 data types in GEMM operations

- Updated GEMM operations to recognize additional float8 data types: `float8_e4m3fn` and `float8_e5m2fnuz`.
- Refactored condition checks in `checkWgmma` methods to simplify float8 type handling.
- Adjusted test cases to ensure compatibility with the new float8 types in tile language examples.

* lint fix

* [Enhancement] Add injective layout detection and exception handling

- Introduced `DetectInjective` method in `FragmentNode` to check for injective layouts.
- Added `LoopLayoutInjectiveException` to handle errors related to non-injective layouts.
- Updated `InferLayout` methods in `ParallelOpNode` to utilize injective checks and log relevant information.
- Refactored layout inference queue management to use `std::deque` for improved performance and added prioritization logic for buffer layouts.

* remove debug print

* remove debug print

* remove debug print

* minor layout fix

* fix for T.view

* [Enhancement] Improve injective layout detection in FragmentNode

- Updated the `DetectInjective` method to handle symbolic dimensions more effectively by introducing a mechanism to collect symbolic shapes and adjust the detection level accordingly.
- Added logging for cases where the layout detection falls back to NoCheck due to symbolic dimensions.
- Minor update to the test file to include the tilelang testing module.

* [Refactor] Simplify layout inference for bulk copy operations

- Removed unnecessary conditions for bulk load/store operations in the layout inference logic.
- Streamlined the handling of layout application for bulk copy instances to enhance clarity and maintainability.

* remove debug print

* [Enhancement] Introduce layout-related exceptions and improve error handling

- Added `LayoutConflictException` and `LoopLayoutInjectiveException` classes for better exception management in layout operations.
- Updated `InferLayout` method in `ParallelOpNode` to throw `LoopLayoutInjectiveException` with detailed error information when injective layout checks fail.
- Removed redundant exception class definitions from `parallel.h` to streamline code organization.
parent 92121fc6
@@ -82,6 +82,7 @@ def postprocess(
     pass_configs={
         tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
         tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+        tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True,
     })
 def bwd(
     B,
@@ -159,9 +160,8 @@ def bwd(
         acc_dq_tail = T.alloc_fragment([padded_H, D_tail], accum_dtype)
         acc_dkv = T.alloc_fragment([BS, D], accum_dtype)
         acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype)
-        acc_dkv_shared = T.view(KV_shared, shape=[BS // split_store, D], dtype=accum_dtype)
-        acc_dkv_tail_shared = T.view(
-            KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype)
+        acc_dkv_shared = T.alloc_shared([BS // split_store, D], accum_dtype)
+        acc_dkv_tail_shared = T.alloc_shared([BS // split_store, D_tail], accum_dtype)
         max_kv_i = s_i
...
@@ -297,13 +297,17 @@ std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const {
 }
 
 Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
-                           arith::Analyzer *analyzer) const {
+                           arith::Analyzer *analyzer,
+                           const PrimExpr rescale_num,
+                           const PrimExpr rescale_den) const {
   // Fast path: if shape is the same, return the original layout
   if (StructuralEqual()(InputShape(), shape)) {
     return ffi::GetRef<Layout>(this);
   }
-  // Step 1. Prove the product of InputShape is equal to the product of shape
+  // Step 1. Prove the product relation holds under rescale:
+  //   prod(InputShape) * rescale_num == prod(shape) * rescale_den
   PrimExpr input_shape_product = Integer(1);
   for (const auto &dim : InputShape()) {
     input_shape_product *= dim;
@@ -317,8 +321,10 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
   // potential null dereference paths flagged by static analysis.
   arith::Analyzer fallback_analyzer;
   arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
-  ICHECK(az->CanProveEqual(input_shape_product, shape_product))
-      << "InputShape() = " << InputShape() << " shape = " << shape;
+  ICHECK(az->CanProveEqual(input_shape_product * rescale_num,
+                           shape_product * rescale_den))
+      << "InputShape() = " << InputShape() << " shape = " << shape
+      << ", rescale_num = " << rescale_num << ", rescale_den = " << rescale_den;
 
   // Step 2. Create new forward indices by reshaping
   // For each dimension in the new shape, we create a placeholder variable
@@ -339,13 +345,17 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
     }
     flat_index = flat_index + new_vars[i] * stride;
   }
+  // Convert the new flat index (in units of new elements) to the old flat
+  // index (in units of old elements) using the rational rescale factor:
+  //   old_flat = floor((flat_index * rescale_den) / rescale_num)
+  PrimExpr old_flat_index = floordiv(flat_index * rescale_den, rescale_num);
 
   // Step 4. Convert flat index back to original shape indices
   // For original shape [s0, s1, ..., sm]:
   //   i0 = flat_index // (s1 * s2 * ... * sm)
   //   i1 = (flat_index % (s1 * s2 * ... * sm)) // (s2 * s3 * ... * sm)
   //   ...
   Array<PrimExpr> original_indices;
-  PrimExpr remaining = flat_index;
+  PrimExpr remaining = old_flat_index;
   for (size_t i = 0; i < InputShape().size(); ++i) {
     PrimExpr stride = Integer(1);
     for (size_t j = i + 1; j < InputShape().size(); ++j) {
@@ -373,7 +383,10 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
 }
 
 Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
-                             arith::Analyzer *analyzer) const {
+                             arith::Analyzer *analyzer,
+                             const PrimExpr rescale_num,
+                             const PrimExpr rescale_den) const {
   // Fast path: identical input shape, return self
   if (StructuralEqual()(InputShape(), shape)) {
     return ffi::GetRef<Fragment>(this);
@@ -390,8 +403,9 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
   // Use provided analyzer if present, otherwise a local fallback.
   arith::Analyzer fallback_analyzer;
   arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
-  ICHECK(az->CanProveEqual(input_prod, shape_prod))
+  ICHECK(az->CanProveEqual(input_prod * rescale_num, shape_prod * rescale_den))
       << "InputShape() = " << InputShape() << " shape = " << shape
+      << ", rescale_num = " << rescale_num << ", rescale_den = " << rescale_den
       << " input fragment layout is = " << DebugOutput();
 
   // 2) Build flat index from new-shape indices
@@ -414,9 +428,12 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
       stride = stride * shape[j];
     flat = flat + new_vars[i] * stride;
   }
+  // Convert to old flat index units using the rational rescale factor:
+  //   old_flat = floor((flat * rescale_den) / rescale_num)
+  PrimExpr old_flat = floordiv(flat * rescale_den, rescale_num);
 
   // 3) Recover original indices from flat index
   Array<PrimExpr> orig_indices;
-  PrimExpr remain = flat;
+  PrimExpr remain = old_flat;
   for (size_t i = 0; i < InputShape().size(); ++i) {
     PrimExpr stride = Integer(1);
     for (size_t j = i + 1; j < InputShape().size(); ++j)
@@ -536,6 +553,52 @@ bool FragmentNode::IsCompletedReplicated() const {
                       ReplicationPlaceholder());
 }
 
+arith::IterMapResult FragmentNode::DetectInjective() const {
+  // lei: A true injective check would invert the layout and run a
+  // surjective check; for convenience we use a bijective check for now.
+  // This can be relaxed in the future.
+  arith::Analyzer analyzer;
+  // Build a flat indices array: [forward_thread_, forward_index_[...]]
+  Array<PrimExpr> indices;
+  indices.push_back(forward_thread_);
+  for (const auto &e : forward_index_) {
+    indices.push_back(e);
+  }
+  // Mirror Layout::InverseWithLevel(): if any participating shape is
+  // symbolic, relax to NoCheck and rely on runtime guards elsewhere.
+  auto collect_symbolic = [&](const Array<PrimExpr> &shape) {
+    Array<PrimExpr> symbolic_dims;
+    for (const auto &dim : shape) {
+      if (!as_const_int(dim)) {
+        symbolic_dims.push_back(dim);
+      }
+    }
+    return symbolic_dims;
+  };
+  Array<PrimExpr> symbolic_dims = collect_symbolic(InputShape());
+  Array<PrimExpr> output_shape = OutputShape();
+  symbolic_dims.insert(symbolic_dims.end(), output_shape.begin(),
+                       output_shape.end());
+  // Also consider replicate size for fragments
+  if (!as_const_int(ReplicateExtent())) {
+    symbolic_dims.push_back(ReplicateExtent());
+  }
+  symbolic_dims = collect_symbolic(symbolic_dims);
+  bool is_static_shape = symbolic_dims.empty();
+  auto level = is_static_shape ? arith::IterMapLevel::Bijective
+                               : arith::IterMapLevel::NoCheck;
+  if (!is_static_shape) {
+    DLOG(WARNING)
+        << "Fragment::DetectInjective on symbolic layout, falling back to "
+        << "NoCheck; symbolic dims: " << symbolic_dims;
+  }
+  return arith::DetectIterMap(indices, getVarMap(), 1, level, &analyzer);
+}
+
 PrimExpr FragmentNode::ThreadExtent() const {
   Array<PrimExpr> ret(OutputDim(), 1);
   arith::Analyzer analyzer;
...
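For intuition, the bijective check that `DetectInjective` delegates to `arith::DetectIterMap` rules out two loop coordinates landing on the same (thread, local index) slot, which would make stores through the layout race. The standalone sketch below brute-forces that property for a made-up fragment-style layout; it does not use the TVM API, and the layout function is purely hypothetical.

```cpp
// Standalone sketch (no TVM): brute-force what "injective" means for a
// fragment layout mapping loop coordinate (i, j) -> (thread, local index).
#include <cstdio>
#include <map>
#include <utility>

int main() {
  const int M = 8, N = 8, kThreads = 32;
  std::map<std::pair<int, int>, std::pair<int, int>> seen;
  bool injective = true;
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      // Hypothetical layout: thread = flat % kThreads, local = flat / kThreads.
      int flat = i * N + j;
      std::pair<int, int> slot{flat % kThreads, flat / kThreads};
      auto [it, inserted] = seen.emplace(slot, std::make_pair(i, j));
      if (!inserted) {  // a second coordinate hit an occupied slot
        injective = false;
        std::printf("collision: (%d,%d) and (%d,%d) -> thread %d idx %d\n",
                    it->second.first, it->second.second, i, j, slot.first,
                    slot.second);
      }
    }
  }
  std::printf("layout is %sinjective\n", injective ? "" : "NOT ");
  return 0;
}
```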
@@ -6,6 +6,7 @@
 #ifndef TVM_TL_LAYOUT_LAYOUT_H_
 #define TVM_TL_LAYOUT_LAYOUT_H_
 
+#include <exception>
 #include <tvm/arith/analyzer.h>
 #include <tvm/arith/iter_affine_map.h>
 #include <tvm/ffi/object.h>
@@ -18,6 +19,25 @@ namespace tl {
 using namespace tir;
 
+// Common layout-related exceptions
+class LayoutConflictException : public std::exception {
+public:
+  const char *what() const noexcept override { return msg_.c_str(); }
+  explicit LayoutConflictException(const std::string &msg) : msg_(msg) {}
+
+private:
+  std::string msg_;
+};
+
+class LoopLayoutInjectiveException : public std::exception {
+public:
+  const char *what() const noexcept override { return msg_.c_str(); }
+  explicit LoopLayoutInjectiveException(const std::string &msg) : msg_(msg) {}
+
+private:
+  std::string msg_;
+};
+
 class Layout;
 class Fragment;
@@ -42,8 +62,18 @@ public:
   virtual Layout Inverse() const;
 
+  // Reshape the layout to a new logical shape. When aliasing buffers of
+  // different dtypes, the element count may change while the underlying
+  // byte size stays equal. Use rescale_num/rescale_den to represent the
+  // ratio between the old element size and the new element size in bytes.
+  // Specifically, define factor = rescale_num / rescale_den, where:
+  //   new_num_elems = old_num_elems * factor
+  // For example, f32->i8 (4B -> 1B) uses rescale_num=4, rescale_den=1;
+  // i8->f32 (1B -> 4B) uses rescale_num=1, rescale_den=4.
   virtual Layout Reshape(const Array<PrimExpr> &shape,
-                         arith::Analyzer *analyzer) const;
+                         arith::Analyzer *analyzer,
+                         const PrimExpr rescale_num = Integer(1),
+                         const PrimExpr rescale_den = Integer(1)) const;
 
   virtual std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const;
@@ -86,7 +116,9 @@ public:
   Layout Inverse() const final;
 
-  Layout Reshape(const Array<PrimExpr> &shape, arith::Analyzer *analyzer) const;
+  Layout Reshape(const Array<PrimExpr> &shape, arith::Analyzer *analyzer,
+                 const PrimExpr rescale_num = Integer(1),
+                 const PrimExpr rescale_den = Integer(1)) const;
 
   std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const final;
@@ -116,6 +148,8 @@ public:
   bool IsCompletedReplicated() const;
 
+  arith::IterMapResult DetectInjective() const;
+
   static void RegisterReflection();
 
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Fragment", FragmentNode, LayoutNode);
...
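The `rescale_num`/`rescale_den` contract above can be sanity-checked with plain integer arithmetic. A minimal sketch, assuming an f32 buffer aliased as i8 (so `rescale_num = 4`, `rescale_den = 1`); the product assertion and the `old_flat` floor division mirror the formulas in `Reshape`:

```cpp
// Minimal sketch of the Reshape rescale arithmetic, using plain integers
// instead of PrimExpr. factor = num/den, new_num_elems = old_num_elems *
// factor, and old_flat = floor(new_flat * den / num).
#include <cassert>
#include <cstdio>

int main() {
  // Alias an f32 buffer of 256 elements (1024 bytes) as i8: 4B -> 1B,
  // so rescale_num = 4, rescale_den = 1 and the new view has 1024 elements.
  const long old_elems = 256, num = 4, den = 1;
  const long new_elems = old_elems * num / den;
  assert(old_elems * num == new_elems * den);  // the ICHECK in Reshape

  // Four consecutive i8 indices land in the same f32 element.
  for (long new_flat : {0L, 1L, 3L, 4L, 1023L}) {
    long old_flat = new_flat * den / num;  // floor division
    std::printf("i8 index %4ld -> f32 index %3ld\n", new_flat, old_flat);
  }
  return 0;
}
```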
@@ -551,7 +551,8 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
     // This must be a global/shared layout, so we can skip the parallel op
     // layout inference (parallel layout inference only annotates the loop
     // layout and the register layout).
-    bool is_load = copy_inst == CopyInst::kBulkLoad;
+    bool is_load =
+        copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkLoad1D;
     Buffer global_tensor = is_load ? src : dst;
     Buffer shared_tensor = is_load ? dst : src;
     // check shared layout is non-swizzle
@@ -561,6 +562,7 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
       Layout linear_layout = ComputeLinearLayout(shared_tensor);
       return Map<Buffer, Layout>({{shared_tensor, linear_layout}});
     }
+    return {};
   }
   // for LDSM/STSM, the layout was deduced from register layout
   // so we can directly apply the layout of normal copy
...
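`ComputeLinearLayout` itself is not shown in this hunk; as rough intuition only, a linear (non-swizzled) shared-memory layout is plain row-major flattening of the tile shape. The sketch below illustrates that under this assumption and is not the actual implementation:

```cpp
// Illustration of a "linear" (non-swizzled) layout: row-major flattening.
#include <cstdio>
#include <vector>

// Map an N-d index to a flat offset using row-major strides.
long LinearOffset(const std::vector<long> &shape,
                  const std::vector<long> &index) {
  long stride = 1, offset = 0;
  for (int i = (int)shape.size() - 1; i >= 0; --i) {
    offset += index[i] * stride;
    stride *= shape[i];
  }
  return offset;
}

int main() {
  std::vector<long> shape{64, 32};                     // e.g. a shared tile
  std::printf("%ld\n", LinearOffset(shape, {0, 0}));   // 0
  std::printf("%ld\n", LinearOffset(shape, {0, 31}));  // 31
  std::printf("%ld\n", LinearOffset(shape, {1, 0}));   // 32
  std::printf("%ld\n", LinearOffset(shape, {63, 31})); // 2047
  return 0;
}
```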
@@ -214,6 +214,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
                                       InferLevel level) const {
   if (loop_layout_.defined())
     return {};
+
   if (level == InferLevel::kStrict) {
     LayoutMap results;
     // Deduce buffers that should be completely replicated.
@@ -562,6 +563,16 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
   } else {
     return {};
   }
+  // Check that loop_layout_ is injective
+  auto injective_res = loop_layout_->DetectInjective();
+  if (!injective_res->errors.empty()) {
+    std::ostringstream oss;
+    oss << "Loop layout is not injective: " << loop_layout_->DebugOutput()
+        << '\n'
+        << "  errors: " << injective_res->errors << '\n'
+        << "  loop AST: " << root_;
+    throw LoopLayoutInjectiveException(oss.str());
+  }
 
   PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();
...
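The thrown `LoopLayoutInjectiveException` is caught by the inference driver (see the layout-inference hunks below), which discards the failed attempt and moves on to the next candidate root. A simplified sketch of that retry protocol follows; `InferWithRoot` and its even/odd failure rule are invented stand-ins for illustration:

```cpp
// Sketch of the retry protocol: inference throws on a non-injective
// candidate, and the search loop treats the exception as "this root
// failed, try the next one".
#include <cstdio>
#include <stdexcept>
#include <string>
#include <vector>

struct LoopLayoutInjectiveException : std::runtime_error {
  using std::runtime_error::runtime_error;
};

// Hypothetical per-root inference: only even roots yield a valid layout.
void InferWithRoot(int root) {
  if (root % 2 != 0)
    throw LoopLayoutInjectiveException("loop layout is not injective, root " +
                                       std::to_string(root));
}

int main() {
  std::vector<int> candidate_roots{1, 3, 4, 5};
  for (int root : candidate_roots) {
    try {
      InferWithRoot(root);
      std::printf("root %d accepted\n", root);
      break;  // keep the first successful attempt
    } catch (const LoopLayoutInjectiveException &e) {
      std::printf("root %d rejected: %s\n", root, e.what());
      // state would be rolled back here before the next attempt
    }
  }
  return 0;
}
```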
@@ -24,15 +24,6 @@ namespace tl {
 using namespace tir;
 
-class LayoutConflictException : public std::exception {
-public:
-  const char *what() const noexcept override { return msg_.c_str(); }
-  LayoutConflictException(const std::string &msg) : msg_(msg) {}
-
-private:
-  std::string msg_;
-};
-
 bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
                            Array<PrimExpr> small_frag_indices,
                            Array<PrimExpr> large_frag_indices,
...
@@ -12,6 +12,7 @@
 #include <tvm/tir/utils.h>
 
 #include <algorithm>
+#include <deque>
 #include <memory>
 #include <queue>
@@ -72,7 +73,7 @@ public:
   void RunInferStep(int cur_infer_id, InferLevel level, bool update_queue,
                     LayoutMap &layout_map, const LayoutMap &strict_layout_map,
-                    std::queue<int> &q, std::vector<bool> &in_queue) {
+                    std::deque<int> &q, std::vector<bool> &in_queue) {
     auto num_infer = infer_list_.size();
     // Range check for cur_infer_id
@@ -112,9 +113,9 @@ public:
     next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map,
                                       cur_analyzer, buffer_oob},
                       level);
     // Process the returned updates
     for (const auto &[buffer, layout] : updates) {
       // Basic validity checks
       ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
       ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";
@@ -140,8 +141,11 @@ public:
           }
         }
         Layout target_layout =
-            shapes_equal ? src_layout
-                         : src_layout->Reshape(sib->shape, &analyzer_);
+            shapes_equal
+                ? src_layout
+                : src_layout->Reshape(sib->shape, &analyzer_,
+                                      Integer(src_buffer->dtype.bytes()),
+                                      Integer(sib->dtype.bytes()));
         if (layout_map.count(sib)) {
           ICHECK(target_layout->IsEqual(layout_map[sib].get()))
               << "Get different layout for alias buffer " << sib
@@ -152,10 +156,7 @@ public:
         layout_map.Set(sib, target_layout);
         if (update_queue && use_list_.count(sib)) {
           for (int idx : use_list_[sib]) {
-            if (!in_queue[idx] && idx != cur_infer_id) {
-              in_queue[idx] = true;
-              q.push(idx);
-            }
+            EnqueueWithPriority(idx, q, in_queue, cur_infer_id, layout_map);
           }
         }
       }
@@ -233,22 +234,20 @@ public:
             << "Index in use_list_ for buffer " << buffer
             << " out of range: " << idx << " >= " << num_infer << ".";
-        if (!in_queue[idx] && idx != cur_infer_id) {
-          in_queue[idx] = true;
-          q.push(idx);
-        }
+        EnqueueWithPriority(idx, q, in_queue, cur_infer_id, layout_map);
       }
     }
   };
 
   void FinishInferQueue(InferLevel level, LayoutMap &layout_map,
-                        const LayoutMap &strict_layout_map, std::queue<int> &q,
+                        const LayoutMap &strict_layout_map, std::deque<int> &q,
                         std::vector<bool> &in_queue) {
     auto num_infer = infer_list_.size();
     while (!q.empty()) {
       int cur_infer_id = q.front();
-      q.pop();
+      q.pop_front();
       // Range check again, just to be safe
       ICHECK_GE(cur_infer_id, 0);
       ICHECK_LT(cur_infer_id, num_infer);
@@ -289,7 +288,7 @@ public:
     int num_infer = infer_list_.size();
 
     // Prepare BFS queue for iterative inference
-    std::queue<int> q;
+    std::deque<int> q;
     std::vector<bool> in_queue(num_infer, true);
     for (int i = 0; i < num_infer; i++) {
       // Check that each infer_list_ entry is valid
@@ -301,7 +300,7 @@ public:
       if (!thread_var_vec_[i].defined() && skip_thread_partition_) {
         thread_var_vec_[i] = thread_var_;
       }
-      q.push(i);
+      q.push_back(i);
     }
 
     // step 1: infer strict layout
@@ -352,10 +351,12 @@ public:
         }
       }
-      Layout reshaped =
-          shapes_equal
-              ? rep_layout.value()
-              : rep_layout.value()->Reshape(buf->shape, &analyzer_);
+      Layout reshaped = shapes_equal
+                            ? rep_layout.value()
+                            : rep_layout.value()->Reshape(
+                                  buf->shape, &analyzer_,
+                                  Integer(rep.value()->dtype.bytes()),
+                                  Integer(buf->dtype.bytes()));
       layout_map.Set(buf, reshaped);
     }
   }
@@ -431,6 +432,38 @@ private:
     return buffer_map;
   }
 
+  // Return true if all buffers that this op (idx) touches already have
+  // inferred layouts in layout_map. Used to prioritize enqueue order.
+  bool ShouldPrioritize(int idx, const LayoutMap &layout_map) const {
+    auto it = op_touched_buffers_.find(idx);
+    if (it == op_touched_buffers_.end() || it->second.empty())
+      return false;
+    for (const auto &buf : it->second) {
+      if (!layout_map.count(buf))
+        return false;
+    }
+    return true;
+  }
+
+  // Enqueue idx to q, with priority if all its buffers already have
+  // layouts. Also guards against duplicates and self-enqueue.
+  void EnqueueWithPriority(int idx, std::deque<int> &q,
+                           std::vector<bool> &in_queue, int cur_infer_id,
+                           const LayoutMap &layout_map) const {
+    if (idx == cur_infer_id)
+      return;
+    if (idx < 0 || idx >= static_cast<int>(in_queue.size()))
+      return;
+    if (in_queue[idx])
+      return;
+    in_queue[idx] = true;
+    if (ShouldPrioritize(idx, layout_map)) {
+      q.push_front(idx);
+    } else {
+      q.push_back(idx);
+    }
+  }
+
   void VisitExpr_(const CallNode *op) final {
     IRVisitorWithAnalyzer::VisitExpr_(op);
     // Do not analyze the call node to the global function.
@@ -536,11 +569,28 @@ private:
   }
 
   void addToUseList(const Buffer &buffer) {
+    // buffer scope must be local.fragment
+    if (buffer.scope() != "local.fragment") {
+      return;
+    }
     int infer_idx = infer_list_.size();
     if (use_list_.find(buffer) == use_list_.end()) {
       use_list_[buffer] = {};
     }
     use_list_[buffer].push_back(infer_idx);
+    // Track which buffers this op (infer_idx) touches for prioritization.
+    // Avoid duplicates.
+    auto &vec = op_touched_buffers_[infer_idx];
+    bool exists = false;
+    for (const auto &b : vec) {
+      if (b.same_as(buffer)) {
+        exists = true;
+        break;
+      }
+    }
+    if (!exists)
+      vec.push_back(buffer);
   }
 
   void VisitStmt_(const ForNode *op) final {
@@ -549,6 +599,71 @@ private:
     for (const auto &[buffer, _] : infer->GetIndiceMap()) {
       addToUseList(buffer);
     }
+    PostOrderVisit(op->body, [this](const ObjectRef &node) {
+      if (auto *buffer_load = node.as<BufferLoadNode>()) {
+        if (buffer_load->buffer.defined() &&
+            buffer_load->buffer->data.defined()) {
+          if (buffer_data_to_buffers_.count(buffer_load->buffer->data)) {
+            // Check if this buffer is already in the list
+            auto buffers = buffer_data_to_buffers_[buffer_load->buffer->data];
+            bool found = false;
+            for (const auto &buf : buffers) {
+              if (buf.same_as(buffer_load->buffer)) {
+                found = true;
+                break;
+              }
+            }
+            if (!found) {
+              buffers.push_back(buffer_load->buffer);
+              buffer_data_to_buffers_.Set(buffer_load->buffer->data, buffers);
+              DLOG(INFO) << "[LayoutInference] BufferLoad: added buffer "
+                         << buffer_load->buffer
+                         << " buffer.get() = " << buffer_load->buffer.get()
+                         << " data = " << buffer_load->buffer->data.get();
+            }
+          } else {
+            buffer_data_to_buffers_.Set(buffer_load->buffer->data,
+                                        {buffer_load->buffer});
+            DLOG(INFO) << "[LayoutInference] BufferLoad: new buffer "
+                       << buffer_load->buffer
+                       << " buffer.get() = " << buffer_load->buffer.get()
+                       << " data = " << buffer_load->buffer->data.get();
+          }
+        }
+      } else if (auto *buffer_store = node.as<BufferStoreNode>()) {
+        if (buffer_store->buffer.defined() &&
+            buffer_store->buffer->data.defined()) {
+          if (buffer_data_to_buffers_.count(buffer_store->buffer->data)) {
+            auto buffers =
+                buffer_data_to_buffers_[buffer_store->buffer->data];
+            bool found = false;
+            for (const auto &buf : buffers) {
+              if (buf.same_as(buffer_store->buffer)) {
+                found = true;
+                break;
+              }
+            }
+            if (!found) {
+              buffers.push_back(buffer_store->buffer);
+              buffer_data_to_buffers_.Set(buffer_store->buffer->data,
                                          buffers);
+              DLOG(INFO) << "[LayoutInference] BufferStore: added buffer "
+                         << buffer_store->buffer
+                         << " buffer.get() = " << buffer_store->buffer.get()
+                         << " data = " << buffer_store->buffer->data.get();
+            }
+          } else {
+            buffer_data_to_buffers_.Set(buffer_store->buffer->data,
+                                        {buffer_store->buffer});
+            DLOG(INFO) << "[LayoutInference] BufferStore: new buffer "
+                       << buffer_store->buffer
+                       << " buffer.get() = " << buffer_store->buffer.get()
+                       << " data = " << buffer_store->buffer->data.get();
+          }
+        }
+      }
+    });
+
     infer_list_stmt_.push_back(tvm::ffi::GetRef<ObjectRef>(op));
     infer_list_.push_back(std::move(infer));
     thread_var_vec_.push_back(thread_var_);
@@ -615,7 +730,11 @@ private:
         if (shapes_equal) {
           annotated_layout_map_.Set(buffer, layout);
         } else {
-          auto reshaped_layout = layout->Reshape(buffer->shape, &analyzer_);
+          // Use the first buffer sharing this var as the base for dtype ratio
+          int base_bytes = buffers[0]->dtype.bytes();
+          auto reshaped_layout =
+              layout->Reshape(buffer->shape, &analyzer_, Integer(base_bytes),
+                              Integer(buffer->dtype.bytes()));
           annotated_layout_map_.Set(buffer, reshaped_layout);
         }
       }
@@ -699,6 +818,8 @@ private:
   std::vector<TileOperator> infer_list_;
   std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
       use_list_;
+  // Per-op list of buffers it touches (fragment scope), used for prioritization
+  std::unordered_map<int, std::vector<Buffer>> op_touched_buffers_;
   // This is a workaround for the cpu backend:
   // we need to define a thread_var for the serial loop.
   IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
@@ -765,6 +886,7 @@ private:
         }
       }
     }
+
    std::unordered_map<int, std::vector<int>> components;
    for (int i = 0; i < infer_list_.size(); i++) {
      int root = uf.Find(i);
@@ -781,7 +903,7 @@ private:
     // For each component, try each op as root, and determine the least
     // replicated one
-    std::queue<int> q;
+    std::deque<int> q;
     std::vector<bool> in_queue(infer_list_.size(), false);
     for (auto &&[root, members] : components) {
@@ -795,7 +917,7 @@ private:
       // Try each member as the root of inference for this component
       for (int attempt_infer_root : members) {
         DLOG(INFO) << "----------------------- try root " << attempt_infer_root
-                   << '\n';
+                   << " members " << members.size() << '\n';
         // Backup the current infer_list_ state
         auto back_infer_list = BackupInferList();
         // Copy the current layout_map for temporary use
@@ -826,6 +948,10 @@ private:
           do_update = false;
           DLOG(INFO) << "attempt failed due to NormalizeIterException "
                      << e.what() << '\n';
+        } catch (const LoopLayoutInjectiveException &e) {
+          do_update = false;
+          DLOG(INFO) << "attempt failed due to LoopLayoutInjectiveException "
+                     << e.what() << '\n';
         }
         if (do_update) {
...
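The net effect of `EnqueueWithPriority` is a two-ended work list: an op whose touched buffers all have known layouts jumps to the front of the deque, so fully constrained ops are inferred before weakly constrained ones. A self-contained sketch of that queue discipline, with made-up dependency data:

```cpp
// Sketch of the deque-based work list: ops whose input layouts are all
// known are pushed to the front, others to the back.
#include <cstdio>
#include <deque>
#include <set>
#include <vector>

int main() {
  // op -> buffers it touches; layouts already known for buffers 0 and 1.
  std::vector<std::vector<int>> touched{{0}, {0, 1}, {2}, {1, 2}};
  std::set<int> known_layouts{0, 1};
  std::deque<int> q;
  std::vector<bool> in_queue(touched.size(), false);

  auto enqueue_with_priority = [&](int idx) {
    if (in_queue[idx])
      return;  // guard against duplicates
    in_queue[idx] = true;
    bool ready = true;
    for (int buf : touched[idx])
      ready = ready && known_layouts.count(buf) > 0;
    if (ready)
      q.push_front(idx);  // fully constrained: infer it first
    else
      q.push_back(idx);
  };

  for (int i = 0; i < (int)touched.size(); ++i)
    enqueue_with_priority(i);

  while (!q.empty()) {
    int cur = q.front();
    q.pop_front();
    std::printf("infer op %d\n", cur);  // expected order: 1, 0, 2, 3
  }
  return 0;
}
```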
 import tilelang
+import tilelang.testing
 import tilelang.language as T
 import pytest
...