Unverified commit 6654064d authored by Lei Wang, committed by GitHub
Browse files

[Layout] Enhance Free Layout Inference (#1375)

* [Refactor] Update condition for benchmarking in example_gemv.py and simplify cached library path handling in sparse.py

* [Enhancement] Extend support for float8 data types in GEMM operations

- Updated GEMM operations to recognize additional float8 data types: `float8_e4m3fn` and `float8_e5m2fnuz`.
- Refactored condition checks in `checkWgmma` methods to simplify float8 type handling.
- Adjusted test cases to ensure compatibility with the new float8 types in tile language examples.

* lint fix

* [Enhancement] Add injective layout detection and exception handling

- Introduced `DetectInjective` method in `FragmentNode` to check for injective layouts.
- Added `LoopLayoutInjectiveException` to handle errors related to non-injective layouts.
- Updated `InferLayout` methods in `ParallelOpNode` to utilize injective checks and log relevant information.
- Refactored layout inference queue management to use `std::deque` for improved performance and added prioritization logic for buffer layouts.

* remove debug print

* remove debug print

* remove debug print

* minor layout fix

* fix for T.view

* [Enhancement] Improve injective layout detection in FragmentNode

- Updated the `DetectInjective` method to handle symbolic dimensions more effectively by introducing a mechanism to collect symbolic shapes and adjust the detection level accordingly.
- Added logging for cases where the layout detection falls back to NoCheck due to symbolic dimensions.
- Minor update to the test file to include the tilelang testing module.

* [Refactor] Simplify layout inference for bulk copy operations

- Removed unnecessary conditions for bulk load/store operations in the layout inference logic.
- Streamlined the handling of layout application for bulk copy instances to enhance clarity and maintainability.

* remove debug print

* [Enhancement] Introduce layout-related exceptions and improve error handling

- Added `LayoutConflictException` and `LoopLayoutInjectiveException` classes for better exception management in layout operations.
- Updated `InferLayout` method in `ParallelOpNode` to throw `LoopLayoutInjectiveException` with detailed error information when injective layout checks fail.
- Removed redundant exception class definitions from `parallel.h` to streamline code organization.
parent 92121fc6
......@@ -82,6 +82,7 @@ def postprocess(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True,
})
def bwd(
B,
......@@ -159,9 +160,8 @@ def bwd(
acc_dq_tail = T.alloc_fragment([padded_H, D_tail], accum_dtype)
acc_dkv = T.alloc_fragment([BS, D], accum_dtype)
acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype)
acc_dkv_shared = T.view(KV_shared, shape=[BS // split_store, D], dtype=accum_dtype)
acc_dkv_tail_shared = T.view(
KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype)
acc_dkv_shared = T.alloc_shared([BS // split_store, D], accum_dtype)
acc_dkv_tail_shared = T.alloc_shared([BS // split_store, D_tail], accum_dtype)
max_kv_i = s_i
......
......@@ -297,13 +297,17 @@ std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const {
}
Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
arith::Analyzer *analyzer) const {
arith::Analyzer *analyzer,
const PrimExpr rescale_num,
const PrimExpr rescale_den) const {
// Fast path: if shape is the same, return the original layout
if (StructuralEqual()(InputShape(), shape)) {
return ffi::GetRef<Layout>(this);
}
// Step 1. Prove the product of InputShape is equal to the product of shape
// Step 1. Prove the product relation holds under rescale:
// prod(InputShape) * rescale_num == prod(shape) * rescale_den
PrimExpr input_shape_product = Integer(1);
for (const auto &dim : InputShape()) {
input_shape_product *= dim;
......@@ -317,8 +321,10 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
// potential null dereference paths flagged by static analysis.
arith::Analyzer fallback_analyzer;
arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
ICHECK(az->CanProveEqual(input_shape_product, shape_product))
<< "InputShape() = " << InputShape() << " shape = " << shape;
ICHECK(az->CanProveEqual(input_shape_product * rescale_num,
shape_product * rescale_den))
<< "InputShape() = " << InputShape() << " shape = " << shape
<< ", rescale_num = " << rescale_num << ", rescale_den = " << rescale_den;
// Step 2. Create new forward indices by reshaping
// For each dimension in the new shape, we create a placeholder variable
......@@ -339,13 +345,17 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
}
flat_index = flat_index + new_vars[i] * stride;
}
// Convert new flat index (in units of new elements) to the old flat index
// (in units of old elements) using the rational rescale factor.
// old_flat = floor((flat_index * rescale_den) / rescale_num)
PrimExpr old_flat_index = floordiv(flat_index * rescale_den, rescale_num);
// Step 4. Convert flat index back to original shape indices
// For original shape [s0, s1, ..., sm]:
// i0 = flat_index // (s1 * s2 * ... * sm)
// i1 = (flat_index % (s1 * s2 * ... * sm)) // (s2 * s3 * ... * sm)
// ...
Array<PrimExpr> original_indices;
PrimExpr remaining = flat_index;
PrimExpr remaining = old_flat_index;
for (size_t i = 0; i < InputShape().size(); ++i) {
PrimExpr stride = Integer(1);
for (size_t j = i + 1; j < InputShape().size(); ++j) {
......@@ -373,7 +383,10 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
}
Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
arith::Analyzer *analyzer) const {
arith::Analyzer *analyzer,
const PrimExpr rescale_num,
const PrimExpr rescale_den) const {
// Fast path: identical input shape, return self
if (StructuralEqual()(InputShape(), shape)) {
return ffi::GetRef<Fragment>(this);
......@@ -390,8 +403,9 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
// Use provided analyzer if present, otherwise a local fallback.
arith::Analyzer fallback_analyzer;
arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
ICHECK(az->CanProveEqual(input_prod, shape_prod))
ICHECK(az->CanProveEqual(input_prod * rescale_num, shape_prod * rescale_den))
<< "InputShape() = " << InputShape() << " shape = " << shape
<< ", rescale_num = " << rescale_num << ", rescale_den = " << rescale_den
<< " input fragment layout is = " << DebugOutput();
// 2) Build flat index from new-shape indices
......@@ -414,9 +428,12 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
stride = stride * shape[j];
flat = flat + new_vars[i] * stride;
}
// Convert to old flat index units using the rational rescale factor.
// old_flat = floor((flat * rescale_den) / rescale_num)
PrimExpr old_flat = floordiv(flat * rescale_den, rescale_num);
// 3) Recover original indices from flat index
Array<PrimExpr> orig_indices;
PrimExpr remain = flat;
PrimExpr remain = old_flat;
for (size_t i = 0; i < InputShape().size(); ++i) {
PrimExpr stride = Integer(1);
for (size_t j = i + 1; j < InputShape().size(); ++j)
......@@ -536,6 +553,52 @@ bool FragmentNode::IsCompletedReplicated() const {
ReplicationPlaceholder());
}
arith::IterMapResult FragmentNode::DetectInjective() const {
  // Check whether this fragment layout is injective by running iter-map
  // detection over the flat index list [forward_thread_, forward_index_...].
  // NOTE: a true injectivity check would invert the layout and test
  // surjectivity; for convenience we use the stricter Bijective level here,
  // which can be relaxed in the future.
  arith::Analyzer analyzer;

  // Build the flat indices array: thread index first, then each forward
  // index expression.
  Array<PrimExpr> indices;
  indices.push_back(forward_thread_);
  for (const auto &e : forward_index_) {
    indices.push_back(e);
  }

  // Mirror Layout::InverseWithLevel(): if any participating extent
  // (input shape, output shape, or the fragment's replicate extent) is
  // symbolic, relax to NoCheck and rely on runtime guards elsewhere.
  // Gather every candidate dimension once, then filter in a single pass
  // (the original filtered the input shape, appended the raw output shape,
  // and filtered again — one pass is equivalent and simpler).
  Array<PrimExpr> candidate_dims = InputShape();
  for (const auto &dim : OutputShape()) {
    candidate_dims.push_back(dim);
  }
  candidate_dims.push_back(ReplicateExtent());

  Array<PrimExpr> symbolic_dims;
  for (const auto &dim : candidate_dims) {
    if (!as_const_int(dim)) {
      symbolic_dims.push_back(dim);
    }
  }

  bool is_static_shape = symbolic_dims.empty();
  auto level = is_static_shape ? arith::IterMapLevel::Bijective
                               : arith::IterMapLevel::NoCheck;
  if (!is_static_shape) {
    DLOG(WARNING)
        << "Fragment::DetectInjective on symbolic layout, falling back to "
        << "NoCheck; symbolic dims: " << symbolic_dims;
  }
  return arith::DetectIterMap(indices, getVarMap(), 1, level, &analyzer);
}
PrimExpr FragmentNode::ThreadExtent() const {
Array<PrimExpr> ret(OutputDim(), 1);
arith::Analyzer analyzer;
......
......@@ -6,6 +6,7 @@
#ifndef TVM_TL_LAYOUT_LAYOUT_H_
#define TVM_TL_LAYOUT_LAYOUT_H_
#include <exception>
#include <tvm/arith/analyzer.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/ffi/object.h>
......@@ -18,6 +19,25 @@ namespace tl {
using namespace tir;
// Common layout-related exceptions
// Signals a layout conflict encountered during layout operations, e.g. when
// two incompatible layouts are proposed for the same buffer.
class LayoutConflictException : public std::exception {
public:
  // Construct with a human-readable diagnostic message.
  explicit LayoutConflictException(const std::string &msg) : message_(msg) {}
  // Return the stored diagnostic message.
  const char *what() const noexcept override { return message_.c_str(); }

private:
  std::string message_;
};
// Signals that an inferred loop layout failed the injectivity check
// (see ParallelOpNode::InferLayout, which throws this with diagnostics).
class LoopLayoutInjectiveException : public std::exception {
public:
  // Construct with a human-readable diagnostic message.
  explicit LoopLayoutInjectiveException(const std::string &msg)
      : message_(msg) {}
  // Return the stored diagnostic message.
  const char *what() const noexcept override { return message_.c_str(); }

private:
  std::string message_;
};
class Layout;
class Fragment;
......@@ -42,8 +62,18 @@ public:
virtual Layout Inverse() const;
// Reshape the layout to a new logical shape. When aliasing buffers of
// different dtypes, the element count may change while the underlying
// byte-size stays equal. Use rescale_num/rescale_den to represent the
// ratio between the old element size and the new element size in bytes.
// Specifically, define factor = rescale_num / rescale_den where:
// new_num_elems = old_num_elems * factor
// For example, f32->i8 (4B -> 1B) uses rescale_num=4, rescale_den=1.
// i8->f32 (1B -> 4B) uses rescale_num=1, rescale_den=4.
virtual Layout Reshape(const Array<PrimExpr> &shape,
arith::Analyzer *analyzer) const;
arith::Analyzer *analyzer,
const PrimExpr rescale_num = Integer(1),
const PrimExpr rescale_den = Integer(1)) const;
virtual std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const;
......@@ -86,7 +116,9 @@ public:
Layout Inverse() const final;
Layout Reshape(const Array<PrimExpr> &shape, arith::Analyzer *analyzer) const;
Layout Reshape(const Array<PrimExpr> &shape, arith::Analyzer *analyzer,
const PrimExpr rescale_num = Integer(1),
const PrimExpr rescale_den = Integer(1)) const;
std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const final;
......@@ -116,6 +148,8 @@ public:
bool IsCompletedReplicated() const;
arith::IterMapResult DetectInjective() const;
static void RegisterReflection();
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Fragment", FragmentNode, LayoutNode);
......
......@@ -551,7 +551,8 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
// This must be a global/shared layout, so we can skip the parallel op
// layout inference (parallel layout inference only annotate the loop layout
// and the register layout).
bool is_load = copy_inst == CopyInst::kBulkLoad;
bool is_load =
copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkLoad1D;
Buffer global_tensor = is_load ? src : dst;
Buffer shared_tensor = is_load ? dst : src;
// check shared layout is non-swizzle
......@@ -561,6 +562,7 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
Layout linear_layout = ComputeLinearLayout(shared_tensor);
return Map<Buffer, Layout>({{shared_tensor, linear_layout}});
}
return {};
}
// for LDSM/STSM, the layout was deduced from register layout
// so we can directly apply the layout of normal copy
......
......@@ -214,6 +214,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (loop_layout_.defined())
return {};
if (level == InferLevel::kStrict) {
LayoutMap results;
// Deduce buffers that should be complicated replicated.
......@@ -562,6 +563,16 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
} else {
return {};
}
// check loop_layout_ is injective
auto injective_res = loop_layout_->DetectInjective();
if (!injective_res->errors.empty()) {
std::ostringstream oss;
oss << "Loop layout is not injective: " << loop_layout_->DebugOutput()
<< '\n'
<< " errors: " << injective_res->errors << '\n'
<< " loop AST: " << root_;
throw LoopLayoutInjectiveException(oss.str());
}
PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();
......
......@@ -24,15 +24,6 @@ namespace tl {
using namespace tir;
// Signals a layout conflict encountered during layout inference.
class LayoutConflictException : public std::exception {
public:
  // Return the stored diagnostic message.
  const char *what() const noexcept override { return msg_.c_str(); }
  // `explicit` prevents accidental implicit conversion from std::string
  // (and matches the declaration of this class elsewhere in the codebase).
  explicit LayoutConflictException(const std::string &msg) : msg_(msg) {}

private:
  std::string msg_;
};
bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
Array<PrimExpr> small_frag_indices,
Array<PrimExpr> large_frag_indices,
......
......@@ -12,6 +12,7 @@
#include <tvm/tir/utils.h>
#include <algorithm>
#include <deque>
#include <memory>
#include <queue>
......@@ -72,7 +73,7 @@ public:
void RunInferStep(int cur_infer_id, InferLevel level, bool update_queue,
LayoutMap &layout_map, const LayoutMap &strict_layout_map,
std::queue<int> &q, std::vector<bool> &in_queue) {
std::deque<int> &q, std::vector<bool> &in_queue) {
auto num_infer = infer_list_.size();
// Range check for cur_infer_id
......@@ -112,9 +113,9 @@ public:
next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map,
cur_analyzer, buffer_oob},
level);
// Process the returned updates
for (const auto &[buffer, layout] : updates) {
// Basic validity checks
ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";
......@@ -140,8 +141,11 @@ public:
}
}
Layout target_layout =
shapes_equal ? src_layout
: src_layout->Reshape(sib->shape, &analyzer_);
shapes_equal
? src_layout
: src_layout->Reshape(sib->shape, &analyzer_,
Integer(src_buffer->dtype.bytes()),
Integer(sib->dtype.bytes()));
if (layout_map.count(sib)) {
ICHECK(target_layout->IsEqual(layout_map[sib].get()))
<< "Get different layout for alias buffer " << sib
......@@ -152,10 +156,7 @@ public:
layout_map.Set(sib, target_layout);
if (update_queue && use_list_.count(sib)) {
for (int idx : use_list_[sib]) {
if (!in_queue[idx] && idx != cur_infer_id) {
in_queue[idx] = true;
q.push(idx);
}
EnqueueWithPriority(idx, q, in_queue, cur_infer_id, layout_map);
}
}
}
......@@ -233,22 +234,20 @@ public:
<< "Index in use_list_ for buffer " << buffer
<< " out of range: " << idx << " >= " << num_infer << ".";
if (!in_queue[idx] && idx != cur_infer_id) {
in_queue[idx] = true;
q.push(idx);
}
EnqueueWithPriority(idx, q, in_queue, cur_infer_id, layout_map);
}
}
}
};
void FinishInferQueue(InferLevel level, LayoutMap &layout_map,
const LayoutMap &strict_layout_map, std::queue<int> &q,
const LayoutMap &strict_layout_map, std::deque<int> &q,
std::vector<bool> &in_queue) {
auto num_infer = infer_list_.size();
while (!q.empty()) {
int cur_infer_id = q.front();
q.pop();
q.pop_front();
// Range check again, just to be safe
ICHECK_GE(cur_infer_id, 0);
ICHECK_LT(cur_infer_id, num_infer);
......@@ -289,7 +288,7 @@ public:
int num_infer = infer_list_.size();
// Prepare BFS queue for iterative inference
std::queue<int> q;
std::deque<int> q;
std::vector<bool> in_queue(num_infer, true);
for (int i = 0; i < num_infer; i++) {
// Check that each infer_list_ entry is valid
......@@ -301,7 +300,7 @@ public:
if (!thread_var_vec_[i].defined() && skip_thread_partition_) {
thread_var_vec_[i] = thread_var_;
}
q.push(i);
q.push_back(i);
}
// step 1: infer strict layout
......@@ -352,10 +351,12 @@ public:
}
}
Layout reshaped =
shapes_equal
Layout reshaped = shapes_equal
? rep_layout.value()
: rep_layout.value()->Reshape(buf->shape, &analyzer_);
: rep_layout.value()->Reshape(
buf->shape, &analyzer_,
Integer(rep.value()->dtype.bytes()),
Integer(buf->dtype.bytes()));
layout_map.Set(buf, reshaped);
}
}
......@@ -431,6 +432,38 @@ private:
return buffer_map;
}
// Return true if all buffers that this op (idx) touches already have
// inferred layouts in layout_map. Used to prioritize enqueue order.
bool ShouldPrioritize(int idx, const LayoutMap &layout_map) const {
auto it = op_touched_buffers_.find(idx);
if (it == op_touched_buffers_.end() || it->second.empty())
return false;
for (const auto &buf : it->second) {
if (!layout_map.count(buf))
return false;
}
return true;
}
// Enqueue idx into the inference work deque, placing it at the front when
// all of its touched buffers already have layouts (see ShouldPrioritize).
// Guards against self-enqueue, out-of-range indices, and duplicates.
void EnqueueWithPriority(int idx, std::deque<int> &q,
                         std::vector<bool> &in_queue, int cur_infer_id,
                         const LayoutMap &layout_map) const {
  const bool in_range = idx >= 0 && idx < static_cast<int>(in_queue.size());
  // Skip the op currently being processed, invalid indices, and ops that
  // are already waiting in the queue.
  if (idx == cur_infer_id || !in_range || in_queue[idx])
    return;
  in_queue[idx] = true;
  if (ShouldPrioritize(idx, layout_map)) {
    // Fully-constrained ops go to the front of the deque.
    q.push_front(idx);
  } else {
    q.push_back(idx);
  }
}
void VisitExpr_(const CallNode *op) final {
IRVisitorWithAnalyzer::VisitExpr_(op);
// Do not analysis the call node to the global function.
......@@ -536,11 +569,28 @@ private:
}
void addToUseList(const Buffer &buffer) {
// buffer scope must be local.fragment
if (buffer.scope() != "local.fragment") {
return;
}
int infer_idx = infer_list_.size();
if (use_list_.find(buffer) == use_list_.end()) {
use_list_[buffer] = {};
}
use_list_[buffer].push_back(infer_idx);
// Track which buffers this op (infer_idx) touches for prioritization.
// Avoid duplicates.
auto &vec = op_touched_buffers_[infer_idx];
bool exists = false;
for (const auto &b : vec) {
if (b.same_as(buffer)) {
exists = true;
break;
}
}
if (!exists)
vec.push_back(buffer);
}
void VisitStmt_(const ForNode *op) final {
......@@ -549,6 +599,71 @@ private:
for (const auto &[buffer, _] : infer->GetIndiceMap()) {
addToUseList(buffer);
}
PostOrderVisit(op->body, [this](const ObjectRef &node) {
if (auto *buffer_load = node.as<BufferLoadNode>()) {
if (buffer_load->buffer.defined() &&
buffer_load->buffer->data.defined()) {
if (buffer_data_to_buffers_.count(buffer_load->buffer->data)) {
// Check if this buffer is already in the list
auto buffers = buffer_data_to_buffers_[buffer_load->buffer->data];
bool found = false;
for (const auto &buf : buffers) {
if (buf.same_as(buffer_load->buffer)) {
found = true;
break;
}
}
if (!found) {
buffers.push_back(buffer_load->buffer);
buffer_data_to_buffers_.Set(buffer_load->buffer->data, buffers);
DLOG(INFO) << "[LayoutInference] BufferStore: added buffer "
<< buffer_load->buffer
<< " buffer.get() = " << buffer_load->buffer.get()
<< " data = " << buffer_load->buffer->data.get();
}
} else {
buffer_data_to_buffers_.Set(buffer_load->buffer->data,
{buffer_load->buffer});
DLOG(INFO) << "[LayoutInference] BufferStore: new buffer "
<< buffer_load->buffer
<< " buffer.get() = " << buffer_load->buffer.get()
<< " data = " << buffer_load->buffer->data.get();
}
}
} else if (auto *buffer_store = node.as<BufferStoreNode>()) {
if (buffer_store->buffer.defined() &&
buffer_store->buffer->data.defined()) {
if (buffer_data_to_buffers_.count(buffer_store->buffer->data)) {
auto buffers =
buffer_data_to_buffers_[buffer_store->buffer->data];
bool found = false;
for (const auto &buf : buffers) {
if (buf.same_as(buffer_store->buffer)) {
found = true;
break;
}
}
if (!found) {
buffers.push_back(buffer_store->buffer);
buffer_data_to_buffers_.Set(buffer_store->buffer->data,
buffers);
DLOG(INFO) << "[LayoutInference] BufferStore: added buffer "
<< buffer_store->buffer
<< " buffer.get() = " << buffer_store->buffer.get()
<< " data = " << buffer_store->buffer->data.get();
}
} else {
buffer_data_to_buffers_.Set(buffer_store->buffer->data,
{buffer_store->buffer});
DLOG(INFO) << "[LayoutInference] BufferStore: new buffer "
<< buffer_store->buffer
<< " buffer.get() = " << buffer_store->buffer.get()
<< " data = " << buffer_store->buffer->data.get();
}
}
}
});
infer_list_stmt_.push_back(tvm::ffi::GetRef<ObjectRef>(op));
infer_list_.push_back(std::move(infer));
thread_var_vec_.push_back(thread_var_);
......@@ -615,7 +730,11 @@ private:
if (shapes_equal) {
annotated_layout_map_.Set(buffer, layout);
} else {
auto reshaped_layout = layout->Reshape(buffer->shape, &analyzer_);
// Use the first buffer sharing this var as the base for dtype ratio
int base_bytes = buffers[0]->dtype.bytes();
auto reshaped_layout =
layout->Reshape(buffer->shape, &analyzer_, Integer(base_bytes),
Integer(buffer->dtype.bytes()));
annotated_layout_map_.Set(buffer, reshaped_layout);
}
}
......@@ -699,6 +818,8 @@ private:
std::vector<TileOperator> infer_list_;
std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
use_list_;
// Per-op list of buffers it touches (fragment scope), used for prioritization
std::unordered_map<int, std::vector<Buffer>> op_touched_buffers_;
// This is a workaround for cpu backend,
// we need to define a thread_var for the serial loop.
IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
......@@ -765,6 +886,7 @@ private:
}
}
}
std::unordered_map<int, std::vector<int>> components;
for (int i = 0; i < infer_list_.size(); i++) {
int root = uf.Find(i);
......@@ -781,7 +903,7 @@ private:
// For each component, try each op as root, and determine the least
// replicated one
std::queue<int> q;
std::deque<int> q;
std::vector<bool> in_queue(infer_list_.size(), false);
for (auto &&[root, members] : components) {
......@@ -795,7 +917,7 @@ private:
// Try each member as the root of inference for this component
for (int attempt_infer_root : members) {
DLOG(INFO) << "----------------------- try root " << attempt_infer_root
<< '\n';
<< " members " << members.size() << '\n';
// Backup the current infer_list_ state
auto back_infer_list = BackupInferList();
// Copy the current layout_map for temporary use
......@@ -826,6 +948,10 @@ private:
do_update = false;
DLOG(INFO) << "attempt failed due to NormalizeIterException "
<< e.what() << '\n';
} catch (const LoopLayoutInjectiveException &e) {
do_update = false;
DLOG(INFO) << "attempt failed due to LoopLayoutInjectiveException "
<< e.what() << '\n';
}
if (do_update) {
......
import tilelang
import tilelang.testing
import tilelang.language as T
import pytest
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment