"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "5a5ca8e7ff042842013cf7ad1462a5d6c705d39b"
Commit 2ffbd369 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Bugfix] Fix copy region automation for dynamic extent (#465)

* [Refactor] Enhance TMA barrier validation and support for additional architectures (#463)

* Updated the TMA barrier validation in `inject_tma_barrier.cc` to check for non-empty `barrier_id_to_range_` before raising an error for missing `create_list_of_mbarrier`.
* Refactored architecture checks in `phase.py` to utilize a new constant `SUPPORTED_TMA_ARCHS`, allowing for easier updates and improved readability in the target architecture validation logic.

* [Refactor] Improve buffer region validation in copy.py

* Added handling for variable extents in buffer_region_to_tile_region function to enhance type checking and error handling.
* Introduced debug print statements to trace values of region extents and temporary extents during validation.
* Updated logic to account for variable extent counts when determining alignment of extents.

* [Refactor] Remove debug print statements in buffer_region_to_tile_region function

* Eliminated unnecessary print statements that were used for debugging temporary extents and region extents.
* Streamlined the code for better readability while maintaining the existing functionality of buffer region validation.

* [Refactor] Clean up whitespace in buffer_region_to_tile_region function

* Removed an unnecessary blank line in the buffer_region_to_tile_region function to improve code readability and maintain consistency in formatting.
parent f41c467c
......@@ -79,15 +79,22 @@ def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: s
# If region_extents already contains all elements
# of extents (in any order), pass directly
tmp_extents = list(extents)
variable_extent_count = 0
for i in range(len(region_extents)):
v = region_extents[i]
if not isinstance(v, tir.IntImm):
variable_extent_count += 1
continue
if v in tmp_extents:
tmp_extents.remove(v)
elif isinstance(v, tir.IntImm) and v != 1:
raise ValueError(
f"buffer {buffer_region.buffer} region_extents[{i}] = {v}, extents[{i}] = {extents[i]}"
)
if len(tmp_extents) > 0:
tmp_len = len(tmp_extents) - variable_extent_count
if tmp_len > 0:
# Otherwise, align extents from the last dimension, region_extents
# can only replace 1 with extents value, otherwise raise error
for i in range(len(extents)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment