Commit 10804a0d authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Enhancement] Support region padding when convert buffer load to buffer region (#342)

* Enhance error checking in RegionOp and buffer_load_to_tile_region

- Added detailed error messages to the index size check in `RegionOp` to aid debugging.
- Implemented a check in `buffer_load_to_tile_region` to ensure the length of indices matches extents, with a fallback to expand extents if necessary. This improves robustness in handling buffer loads with mismatched dimensions.

* lint fix
parent 73885cfd
......@@ -53,7 +53,8 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
size_t ndim = n - 2;
auto load = args[0].as<BufferLoadNode>();
ICHECK(load);
ICHECK(load->indices.size() == ndim);
ICHECK(load->indices.size() == ndim)
<< "load->indices.size() = " << load->indices << " ndim = " << ndim;
buffer_ = load->buffer;
access_mask_ = static_cast<int>(*as_const_int(args[1]));
for (size_t i = 0; i < ndim; i++) {
......
......@@ -46,6 +46,17 @@ def buffer_load_to_tile_region(load: tir.BufferLoad, access_type: str, extents:
Returns:
tir.Call: A region descriptor for the loaded area
"""
indices = load.indices
if len(indices) > len(extents):
# (f"mismatch between indices and extents for buffer load {load}: indices = {indices}, extents = {extents}, "
# f"region will be expanded in the last 2 dimensions")
new_extents = []
for _ in range(len(indices) - len(extents)):
new_extents.append(1)
for i in range(len(extents)):
new_extents.append(extents[i])
extents = new_extents
assert len(indices) == len(extents), f"indices = {indices}, extents = {extents}"
return region(load, access_type, *extents)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment