"git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "b44830905df5cd0088eaa2f820e4686df21407c3"
Commit 2957afca authored by Lei Wang, committed by GitHub

[Enhancement] Improve iterator handling in layout utilities and parallel operations (#1221)

* [Enhancement] Improve iterator handling in layout utilities and parallel operations

* Added a new function, DivideUnusedIterators, to detect per-iterator gaps in fused index expressions, enhancing the accuracy of unused iterator detection.
* Updated CompleteBufferFragment to prefer direct inversion for bijective index mappings and introduced a fallback mechanism for non-bijective cases, improving layout inversion robustness.
* Added a new test for layout inference in fused kernels to ensure correct compilation and execution without layout inversion failures.

* lint fix
parent cf46b7bd
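
For intuition before the diffs, a hedged pure-Python sketch (illustrative only, not the tilelang/TVM API): the replicate extent of a loop-to-index map is the ratio of the loop-space size to the number of distinct index tuples it produces. This is the same cardinality reasoning the changes below apply symbolically.

from itertools import product


def replicate_extent(extents, index_fn):
    """Brute-force model: how many loop points map to each index tuple.

    extents  -- iteration extents of the loop vars, e.g. (M, N)
    index_fn -- maps a loop point to the buffer index tuple
    This toy assumes replication is uniform across the image.
    """
    points = list(product(*[range(e) for e in extents]))
    images = {index_fn(*p) for p in points}
    assert len(points) % len(images) == 0
    return len(points) // len(images)


# Iterator j (extent 4) never appears in the index: 4 loop points per index.
print(replicate_extent((8, 4), lambda i, j: (i,)))  # -> 4
# Fused but bijective use of both iterators: no replication.
print(replicate_extent((8, 4),
                       lambda i, j: ((i * 4 + j) // 4, (i * 4 + j) % 4)))  # -> 1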
@@ -115,6 +115,10 @@ Array<IterSplitExpr> get_unused_iters(const IterMark &mark,
   return results;
 }
 
+// Heuristic: detect per-iterator gaps ("unused" pieces) even when the iterator
+// appears in fused forms across multiple index expressions. We first normalize
+// every index into IterSumExpr, collect all splits per source Var, then
+// consolidate them to avoid misclassifying a used split as unused.
 Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr> &exprs,
                                            const Array<IterVar> input_iters,
                                            Analyzer *analyzer) {
@@ -134,17 +138,25 @@ Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr> &exprs,
   }
   for (const IterVar &iter : input_iters) {
-    IterMark iv_mark;
+    // Merge splits from all IterMark that share the same source Var as `iter`.
+    std::vector<IterSplitExpr> merged_splits;
     for (const IterMark &mark : collector.visited_) {
-      if (mark->source.as<Var>()->same_as(iter->var)) { // NOLINT(*)
-        iv_mark = mark;
-        break;
+      auto vexpr = mark->source.as<Var>();
+      if (vexpr && vexpr.value().same_as(iter->var)) {
+        auto it = collector.mark2splits_.find(mark);
+        if (it != collector.mark2splits_.end()) {
+          const auto &vec = it->second;
+          merged_splits.insert(merged_splits.end(), vec.begin(), vec.end());
+        }
       }
     }
-    if (iv_mark.defined()) {
-      auto splits =
-          get_unused_iters(iv_mark, collector.mark2splits_[iv_mark], analyzer);
-      // Put the small axis last
+
+    if (!merged_splits.empty()) {
+      // Use a unified mark (Var + full extent) to compute the missing pieces
+      // so that fused usages are honored as "used" and not reintroduced.
+      IterMark unified_mark(iter->var, iter->dom->extent);
+      auto splits = get_unused_iters(unified_mark, merged_splits, analyzer);
+      // Put the small axis last for a flattened ordering.
       results.insert(results.end(), splits.rbegin(), splits.rend());
     } else if (!is_one(iter->dom->extent)) {
       auto mark = IterMark(iter->var, iter->dom->extent);
...
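
To illustrate why the splits must be merged per source Var before gap detection, here is a heavily simplified Python model (hypothetical helper, not the real IterSplitExpr machinery; it represents a split as a (lower_factor, extent) pair covering the factor range [lower_factor, lower_factor * extent) and assumes exact divisibility):

def unused_pieces(extent, used_splits):
    """Gaps of an iterator of `extent`, given used (lower_factor, extent) splits.

    Merging splits from *all* index expressions first prevents a piece used in
    one expression from being reported as unused just because another
    expression does not touch it.
    """
    covered = sorted((f, f * e) for f, e in used_splits)
    gaps, cursor = [], 1
    for lo, hi in covered:
        if cursor < lo:                # factors [cursor, lo) are uncovered
            gaps.append((cursor, lo // cursor))
        cursor = max(cursor, hi)
    if cursor < extent:                # top piece above the last split
        gaps.append((cursor, extent // cursor))
    return gaps                        # [(lower_factor, extent), ...]


# Iterator extent 16; expression A uses split (1, 4) (low bits), expression B
# uses split (8, 2) (high bit). Merged, only factors [4, 8) are truly unused.
print(unused_pieces(16, [(1, 4), (8, 2)]))  # -> [(4, 2)]
# Considering A alone would wrongly report all high bits as unused.
print(unused_pieces(16, [(1, 4)]))          # -> [(4, 4)]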
@@ -620,11 +620,66 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
   if (IsCommonAccessIndice(buffer)) {
     return loop_layout_;
   }
+  // Prefer a simple path: if original 2D indices form a bijective map, invert
+  // them directly and avoid introducing a synthetic replicate dimension.
+  {
+    auto res2d =
+        arith::DetectIterMap(indice_map_[buffer], ToVMap(loop_vars_), 1,
+                             arith::IterMapLevel::Bijective,
+                             const_cast<arith::Analyzer *>(&analyzer_));
+    if (res2d->errors.empty()) {
+      Layout ind_inv2d = Layout(loop_vars_, indice_map_[buffer])->Inverse();
+      PrimExpr indice_rep_extent = 1;
+      PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
+      PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
+      Array<PrimExpr> fwd2;
+      for (size_t i = 0; i < buffer->shape.size(); i++) {
+        fwd2.push_back(InputPlaceholder(i));
+      }
+      PrimExpr thd_b2 =
+          loop_layout_->ForwardThread(ind_inv2d->Forward(fwd2), std::nullopt);
+      return Fragment(buffer->shape, {}, thd_b2, dest_buffer_rep_extent,
+                      std::nullopt)
+          ->CondenseReplicateVar();
+    }
+  }
+  // Otherwise, infer an extra flattened iterator that captures truly-unused
+  // pieces of the loop space (if any), then try inversion with it.
   PrimExpr rep_b = MakeFlattenedExpression(
       DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
   auto bijective_indice = indice_map_[buffer];
   bijective_indice.push_back(rep_b);
-  Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();
+  Layout layout_before_inv = Layout(loop_vars_, bijective_indice);
+  // Pre-check cardinality to guard non-bijective combinations after adding
+  // rep_b.
+  PrimExpr in_prod = 1;
+  for (const auto &iv : loop_vars_)
+    in_prod *= iv->dom->extent;
+  PrimExpr out_prod = 1;
+  for (const auto &d : layout_before_inv->OutputShape())
+    out_prod *= d;
+  if (!analyzer_.CanProveEqual(in_prod, out_prod)) {
+    DLOG(WARNING) << " Non-bijective mapping after appending rep_b; falling "
+                     "back to no-rep inversion.";
+    Layout ind_inv_fallback =
+        Layout(loop_vars_, indice_map_[buffer])->Inverse();
+    PrimExpr indice_rep_extent = 1;
+    PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
+    PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
+    Array<PrimExpr> fwd2;
+    for (size_t i = 0; i < buffer->shape.size(); i++) {
+      fwd2.push_back(InputPlaceholder(i));
+    }
+    PrimExpr thd_b = loop_layout_->ForwardThread(
+        ind_inv_fallback->Forward(fwd2), std::nullopt);
+    return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent,
+                    std::nullopt)
+        ->CondenseReplicateVar();
+  }
+  Layout ind_inv = layout_before_inv->Inverse();
   PrimExpr indice_rep_extent =
       ind_inv->InputShape().back();  // this is the size of rep_b
   PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
...
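
The guard above is a necessary (not sufficient) condition for bijectivity: the loop space and the layout's output shape must have equal cardinality. A numeric analogue of the symbolic in_prod/out_prod comparison (plain Python, illustrative names):

from math import prod


def can_invert(loop_extents, output_shape):
    """Mirrors the in_prod/out_prod guard: if the products differ, inverting
    the rep_b-augmented layout is ill-defined, so the code falls back to
    inverting the original indices without the replicate dimension."""
    return prod(loop_extents) == prod(output_shape)


print(can_invert((32, 64), (64, 32)))     # True: inversion may proceed
print(can_invert((32, 64), (64, 32, 2)))  # False: take the fallback path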
import pytest
import torch

import tilelang
import tilelang.testing
import tilelang.language as T

tilelang.testing.set_random_seed()

VEC_SIZE = 32


@tilelang.jit
def fused_index_kernel(B: int, M: int, N: int, BLOCK_MN: int, BLOCK_K: int):

    @T.prim_func
    def main(
            a: T.Buffer((B, M, N), "bfloat16"),
            a_out: T.Buffer((B, M, N), "float32"),
    ):
        with T.Kernel(
                T.ceildiv(M, BLOCK_MN),
                T.ceildiv(N, BLOCK_K),
                B,
                threads=128,
        ) as (pid_m, pid_n, pid_b):
            a_fp32_local = T.alloc_fragment(
                (BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE), "float32")

            offs_m = pid_m * BLOCK_MN
            offs_n = pid_n * BLOCK_K

            for i, j in T.Parallel(BLOCK_MN, BLOCK_K):
                idx = i * BLOCK_K + j
                a_out[pid_b, offs_m + i, offs_n + j] = a_fp32_local[
                    idx // VEC_SIZE, idx % VEC_SIZE]

    return main


def _require_cuda_tensor(shape, dtype):
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    try:
        return torch.randn(*shape, device="cuda", dtype=dtype)
    except RuntimeError as err:
        pytest.skip(f"CUDA runtime unavailable: {err}")


def test_layout_infer_compiles_and_runs():
    B, M, N = 1, 32, 64
    BLOCK_MN, BLOCK_K = 32, 64

    kernel = fused_index_kernel(B, M, N, BLOCK_MN, BLOCK_K)

    a = _require_cuda_tensor((B, M, N), torch.bfloat16)
    a_out = torch.empty((B, M, N), dtype=torch.float32, device=a.device)

    # Ensure the kernel compiles and executes without layout inversion failure.
    kernel(a, a_out)

    assert a_out.shape == a.shape
    assert a_out.dtype == torch.float32


if __name__ == "__main__":
    tilelang.testing.main()
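
As a sanity check on why the new direct-inversion path covers this test: the fused map (i, j) -> ((i * BLOCK_K + j) // VEC_SIZE, (i * BLOCK_K + j) % VEC_SIZE) is a bijection from the 32x64 loop space onto the 64x32 fragment, which can be confirmed by brute force (hedged sketch, plain Python, standalone constants):

from itertools import product

BLOCK_MN, BLOCK_K, VEC_SIZE = 32, 64, 32

images = {((i * BLOCK_K + j) // VEC_SIZE, (i * BLOCK_K + j) % VEC_SIZE)
          for i, j in product(range(BLOCK_MN), range(BLOCK_K))}

# Every (row, lane) of the (BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE) fragment
# is hit exactly once, so an iter-map detector can prove the map bijective and
# the simple inversion path applies without a synthetic replicate dimension.
assert len(images) == BLOCK_MN * BLOCK_K
assert images == set(product(range(BLOCK_MN * BLOCK_K // VEC_SIZE),
                             range(VEC_SIZE)))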