Unverified Commit 921b96a3 authored by Chaofan Lin's avatar Chaofan Lin Committed by GitHub
Browse files

[Language] Add shape check in `T.view/reshape` (#1277)

* [Language] Add shape check in T.view/reshape

* address comments
parent 1b0efb65
......@@ -2,6 +2,7 @@ from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import torch
import pytest
def reshape_test(N, M, dtype):
......@@ -262,5 +263,25 @@ def test_reduce_after_reshape():
run_reduce_after_reshape(2048, 64, "float16")
def reshape_shape_mismatch_test(N, M, dtype):
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor((N,), dtype),
B: T.Tensor((N // M, M), dtype),
):
with T.Kernel(1) as _:
A_reshaped = T.reshape(A, [N // M, M + 1])
T.copy(A_reshaped, B)
return main
def test_reshape_shape_mismatch():
with pytest.raises(AssertionError):
reshape_shape_mismatch_test(1024, 32, "float32")
if __name__ == "__main__":
tilelang.testing.main()
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import pytest
def view_test(N, M, dtype, new_dtype=None):
......@@ -54,5 +55,35 @@ def test_reshape_view():
run_view(2048, 64, "float16", "float32")
def view_shape_mismatch_test(N, M, dtype, new_dtype=None):
import tilelang.language as T
new_shape = [N // M, M + 1]
if new_dtype:
from tvm import DataType
dtype_src = DataType(dtype)
dtype_dst = DataType(new_dtype)
src_bits = dtype_src.bits
dst_bits = dtype_dst.bits
scale = src_bits / dst_bits
new_shape[-1] = int(M * scale)
@T.prim_func
def main(
A: T.Tensor((N,), dtype),
B: T.Tensor(new_shape, new_dtype if new_dtype else dtype),
):
with T.Kernel(1) as _:
A_viewed = T.view(A, new_shape, dtype=new_dtype)
T.copy(A_viewed, B)
return main
def test_view_shape_mismatch():
with pytest.raises(AssertionError):
view_shape_mismatch_test(1024, 32, "float32")
if __name__ == "__main__":
tilelang.testing.main()
......@@ -2,6 +2,7 @@
from __future__ import annotations
import tilelang.language as T
from tvm.tir import PrimExpr, Buffer, op
from tilelang.utils.language import (bits_product, prim_expr_equal)
from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401
......@@ -45,19 +46,22 @@ def reshape(src: Buffer, shape: list[PrimExpr]) -> Buffer:
Returns:
Buffer: A new buffer view with the specified shape
"""
assert prim_expr_equal(bits_product(shape, src.dtype),
bits_product(src.shape, src.dtype)), "T.reshape/view shape check failed."
return T.Tensor(shape, src.dtype, src.data)
def view(src: Buffer, shape: list[PrimExpr] | None = None, dtype: str | None = None) -> Buffer:
"""
Return a Tensor view of the input buffer with an optional new shape and dtype.
"""Return a Tensor view of the input buffer with an optional new shape and dtype.
If `shape` is None the source buffer's shape is used; if `dtype` is None the source buffer's dtype is used. The returned buffer shares the same underlying data as `src` (no copy).
"""
If `shape` is None the source buffer's shape is used; if `dtype` is None the source buffer's dtype is used. The returned buffer shares the same underlying data as `src` (no copy).
"""
if shape is None:
shape = src.shape
if dtype is None:
dtype = src.dtype
assert prim_expr_equal(bits_product(shape, dtype),
bits_product(src.shape, src.dtype)), "T.reshape/view shape check failed."
return T.Tensor(shape, dtype, src.data)
......
from __future__ import annotations
from tvm.tir import Buffer, BufferLoad, BufferRegion, PrimExpr
from functools import reduce
from tvm import IRModule
from tvm import IRModule, DataType
from tvm.tir import PrimFunc
from tvm import ir, tir
......@@ -349,6 +349,17 @@ def retrieve_offset(obj: Buffer | BufferRegion | BufferLoad) -> list:
raise ValueError(f"Unsupported retrieve_offset argument type: {type(obj)} for object {obj}")
def bits_product(shape: list[PrimExpr], dtype: str) -> PrimExpr:
"""
Compute the number of bits in a Buffer (shape with dtype)."""
if len(shape) == 0:
return tir.IntImm("int32", 1)
result = shape[0]
for i in range(1, len(shape)):
result = result * shape[i]
return result * DataType(dtype).bits
def prim_expr_equal(lhs, rhs) -> bool:
"""
Robust equality for PrimExpr shapes/extents.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment