Unverified Commit c2fe91e0 authored by Kurisu's avatar Kurisu Committed by GitHub
Browse files

[Enhancement] Add shape checking for reduce options (#748)



* Add shape checking for reduce options

* lint fix

* Handle special case reducing into shape-1 tensor

Allow reducing [X, d, Y] into [X, Y] or [X, 1, Y]

---------
Co-authored-by: default avatarLeiWang1999 <leiwang1999@outlook.com>
parent e68fdab8
...@@ -24,6 +24,16 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea ...@@ -24,6 +24,16 @@ def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clea
Returns: Returns:
tir.Call: Handle to the reduction operation tir.Call: Handle to the reduction operation
""" """
# input shape: [X, d, Y], expected output shape: [X, Y] or [X, 1, Y]
expected_shapes = [
buffer.shape[:dim] + buffer.shape[dim + 1:],
buffer.shape[:dim] + [1] + buffer.shape[dim + 1:]
]
if list(out.shape) not in expected_shapes:
expected_shapes_str = ' or '.join(map(str, expected_shapes))
raise ValueError(
f"Invalid reduce output shape, buffer shape is {buffer.shape}, dim is {dim}, "
f"output shape is {out.shape}, expected shapes are {expected_shapes_str}")
buffer = buffer.access_ptr("r") buffer = buffer.access_ptr("r")
out = out.access_ptr("w") out = out.access_ptr("w")
return tir.call_intrin( return tir.call_intrin(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment