Unverified Commit 17718bec authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Refactor] Enhance CopyNode's IterVar Creation and Range Handling (#1346)

* [Refactor] Enhance CopyNode's IterVar Creation and Range Handling

This commit refines the `MakeIterVars` method in `CopyNode` to select base ranges based on memory scope levels, ensuring that the chosen ranges are not smaller than the original source ranges. Additionally, it updates the Python `copy` function to clarify range handling, including broadcasting logic and extent alignment. These changes improve the robustness and clarity of the copy operation's implementation.

* test fix
parent f0c721a4
...@@ -179,15 +179,95 @@ TileOperator CopyNode::Clone() const { ...@@ -179,15 +179,95 @@ TileOperator CopyNode::Clone() const {
* copy operation. * copy operation.
*/ */
Array<IterVar> CopyNode::MakeIterVars() const { Array<IterVar> CopyNode::MakeIterVars() const {
// Choose the range set from the lowest-level memory scope between src and
// dst. Scope levels: global < shared/shared.dyn/shared.tmem < local.fragment
// (fragment)
auto scope_level = [](const Buffer &b) -> int {
String s = b.scope();
if (s == "local.fragment" || s == "local")
return 2;
if (s == "shared" || s == "shared.dyn" || s == "shared.tmem")
return 1;
// default to global level for unknown scopes
return 0;
};
int src_level = scope_level(src);
int dst_level = scope_level(dst);
bool base_is_src = (src_level >= dst_level);
const Array<Range> &base_ranges = base_is_src ? src_range : dst_range;
// Sanity check: when switching away from the original (src_range),
// ensure the chosen base ranges are not provably smaller than the original
// per dimension. This guards against generating undersized loop domains.
// Improved logic: use two pointers to traverse both base_ranges and
// src_range, skipping dimensions with extent == 1. The number of non-1
// extents must match.
arith::Analyzer analyzer;
size_t base_dim = 0, src_dim = 0;
while (base_dim < base_ranges.size() && src_dim < src_range.size()) {
// Skip base extents that are 1
while (base_dim < base_ranges.size() &&
is_one(base_ranges[base_dim]->extent)) {
++base_dim;
}
// Skip src extents that are 1
while (src_dim < src_range.size() && is_one(src_range[src_dim]->extent)) {
++src_dim;
}
// Both indices now at non-1, or at end
if (base_dim < base_ranges.size() && src_dim < src_range.size()) {
PrimExpr base_ext = base_ranges[base_dim]->extent;
PrimExpr src_ext = src_range[src_dim]->extent;
// Only fail if base extent is provably smaller than src extent
if (analyzer.CanProve(base_ext < src_ext)) {
std::ostringstream oss;
oss << "Selected loop range is smaller than original src range at "
"matched non-1 dimension: "
<< "base(extent=" << base_ext
<< ", scope=" << (base_is_src ? src.scope() : dst.scope())
<< ", min=" << base_ranges[base_dim]->min
<< ", base_dim=" << base_dim << ") < src(extent=" << src_ext
<< ", min=" << src_range[src_dim]->min << ", src_dim=" << src_dim
<< ", scope=" << src.scope() << ") for src=" << src->name
<< ", dst=" << dst->name << "\n";
oss << "src buffer: " << src->name << ", scope=" << src.scope() << "\n";
oss << "dst buffer: " << dst->name << ", scope=" << dst.scope() << "\n";
oss << "base_ranges[" << base_dim
<< "]: min=" << base_ranges[base_dim]->min
<< ", extent=" << base_ext << "\n";
oss << "src_ranges[" << src_dim << "]: min=" << src_range[src_dim]->min
<< ", extent=" << src_ext << "\n";
LOG(FATAL) << oss.str();
}
++base_dim;
++src_dim;
}
}
// Any remaining unmatched dimensions in either range must all have extent ==
// 1
while (base_dim < base_ranges.size()) {
ICHECK(is_one(base_ranges[base_dim]->extent))
<< "base_ranges has extra non-1 extent at dim " << base_dim;
++base_dim;
}
while (src_dim < src_range.size()) {
ICHECK(is_one(src_range[src_dim]->extent))
<< "src_range has extra non-1 extent at dim " << src_dim;
++src_dim;
}
Array<IterVar> loop_vars; Array<IterVar> loop_vars;
size_t idx = 0; size_t idx = 0;
for (size_t i = 0; i < src_range.size(); i++) { for (size_t i = 0; i < base_ranges.size(); i++) {
if (is_one(src_range[i]->extent)) if (is_one(base_ranges[i]->extent))
continue; continue;
Var var = Var(std::string{char('i' + idx)}, src_range[i]->extent->dtype); Var var = Var(std::string{char('i' + idx)}, base_ranges[i]->extent->dtype);
idx++; idx++;
loop_vars.push_back( loop_vars.push_back(
{Range(0, src_range[i]->extent), var, IterVarType::kDataPar}); {Range(0, base_ranges[i]->extent), var, IterVarType::kDataPar});
} }
return loop_vars; return loop_vars;
} }
......
...@@ -27,6 +27,22 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, ...@@ -27,6 +27,22 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
Returns: Returns:
tir.Call: A handle to the copy operation tir.Call: A handle to the copy operation
Range handling notes:
- Accepts `Buffer`/`BufferRegion`/`BufferLoad` on either side. Extents are
derived as follows: `Buffer -> shape`, `BufferRegion -> [r.extent]`,
`BufferLoad -> extents from its inferred/encoded region`.
- If both `src` and `dst` are scalar `BufferLoad` without region extents,
lowers to a direct store: `dst[...] = src`.
- If one side is missing extents, it is treated as all-ones with the other
side's rank to enable broadcasting.
- Extents are right-aligned and legalized via `legalize_pairwise_extents`:
per tail-dimension, equal keeps as-is, a `1` broadcasts to the other,
otherwise a conservative `tir.max` is used to remain safe for dynamic
shapes.
- The finalized extents are encoded with `tl.region` via `to_buffer_region`
and passed through to the backend; low-level loop construction and any
scope-specific decisions happen during lowering.
""" """
if isinstance(src, tir.Buffer) and isinstance(dst, tir.Buffer): if isinstance(src, tir.Buffer) and isinstance(dst, tir.Buffer):
ir.assert_structural_equal(src.shape, dst.shape) ir.assert_structural_equal(src.shape, dst.shape)
...@@ -57,16 +73,11 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, ...@@ -57,16 +73,11 @@ def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
return tir.BufferStore(dst.buffer, src, dst.indices) return tir.BufferStore(dst.buffer, src, dst.indices)
assert src_extent or dst_extent, "Can't deduce copy extents from args" assert src_extent or dst_extent, "Can't deduce copy extents from args"
# Treat missing extent as length-matched ones to enable broadcasting logic. # Treat missing extent as length-matched ones to enable broadcasting.
src_extent = list(src_extent) if src_extent else [1] * len(dst_extent) src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent) dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
# Align and broadcast extents from the right (tail) side independently # Align and broadcast extents from the right (tail) side.
# for src and dst, so we can pass them unchanged into _to_region.
# Rules per-dim from the right:
# - equal -> keep both
# - one is 1 -> set that side to the other side's dim
# - otherwise -> error
src_extent, dst_extent = legalize_pairwise_extents(src_extent, dst_extent) src_extent, dst_extent = legalize_pairwise_extents(src_extent, dst_extent)
# Use legalized extents for src and dst respectively. # Use legalized extents for src and dst respectively.
......
...@@ -46,8 +46,9 @@ def reshape(src: Buffer, shape: list[PrimExpr]) -> Buffer: ...@@ -46,8 +46,9 @@ def reshape(src: Buffer, shape: list[PrimExpr]) -> Buffer:
Returns: Returns:
Buffer: A new buffer view with the specified shape Buffer: A new buffer view with the specified shape
""" """
assert prim_expr_equal(bits_product(shape, src.dtype), assert prim_expr_equal(
bits_product(src.shape, src.dtype)), "T.reshape/view shape check failed." bits_product(shape, src.dtype), bits_product(src.shape, src.dtype)
), f"T.reshape/view shape check failed. src {src} src.shape: {src.shape}, src.dtype: {src.dtype}, target shape: {shape}, target dtype: {src.dtype}"
return T.Tensor(shape, src.dtype, src.data) return T.Tensor(shape, src.dtype, src.data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment