"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "0e4961551d3b9cd6e766381cb7539531de20450b"
Commit c5a989f5 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Refactor] Simplify buffer_region_to_tile_region function in copy.py (#470)

* Removed redundant logic for handling region extents in the buffer_region_to_tile_region function, streamlining the code for better readability and maintainability.
* Enhanced error handling by focusing on essential checks while eliminating unnecessary complexity related to variable extents.
parent 1f2f1554
...@@ -74,38 +74,9 @@ def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: s ...@@ -74,38 +74,9 @@ def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: s
mins = [x.min for x in buffer_region.region] mins = [x.min for x in buffer_region.region]
region_extents = [x.extent for x in buffer_region.region] region_extents = [x.extent for x in buffer_region.region]
assert len(region_extents) >= len( assert len(region_extents) >= len(
extents), f"region_extents = {region_extents}, extents = {extents}" extents
), f"region_extents must be >= extents, region_extents = {region_extents}, extents = {extents}"
# If region_extents already contains all elements
# of extents (in any order), pass directly
tmp_extents = list(extents)
variable_extent_count = 0
for i in range(len(region_extents)):
v = region_extents[i]
if not isinstance(v, tir.IntImm):
variable_extent_count += 1
continue
if v in tmp_extents:
tmp_extents.remove(v)
elif isinstance(v, tir.IntImm) and v != 1:
raise ValueError(
f"buffer {buffer_region.buffer} region_extents[{i}] = {v}, extents[{i}] = {extents[i]}"
)
tmp_len = len(tmp_extents) - variable_extent_count
if tmp_len > 0:
# Otherwise, align extents from the last dimension, region_extents
# can only replace 1 with extents value, otherwise raise error
for i in range(len(extents)):
idx = len(region_extents) - len(extents) + i
if region_extents[idx] != extents[i]:
if region_extents[idx] == 1:
region_extents[idx] = extents[i]
else:
raise ValueError(
f"buffer {buffer_region.buffer} region_extents[{idx}] = {region_extents[idx]}, extents[{i}] = {extents[i]}"
)
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents) return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment