"include/vscode:/vscode.git/clone" did not exist on "56863b9a06b8a94c3982a6b1de83b702868d8461"
Commit bf8a6fc1 authored by Lei Wang, committed by LeiWang1999

[Refactor] Deprecate `T.Buffer` as argument annotations and rename related calls to `T.Tensor` (#281)

* [Refactor] Improve flash attention example and layout comparison logic

- Removed unnecessary annotation for `lse_local_split` in the flash attention example to streamline the code.
- Updated the handling of `lse_local_split` to utilize parallel processing for better performance.
- Refactored kernel compilation and profiling logic to enhance clarity and maintainability in the flash attention example (a rough sketch of that flow follows this list).
- Added a condition in `FragmentNode::IsEqual` to handle broadcast cases, improving the robustness of layout comparisons.
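
A rough sketch of the compile-and-profile flow mentioned above, assuming `tilelang.compile`, `get_profiler`, and `do_bench` are the entry points used by the example (the `out_idx` indices and parameter values are illustrative, not taken from the commit):

```python
import tilelang

# Assumes the `flashattn_fwd` factory shown in the diff below is in scope.
program = flashattn_fwd(
    batch=1, heads=8, seq_len=1024, dim=64, is_casual=True, block_M=64, block_N=64)

# Compile the prim_func into a callable kernel; out_idx marks Output and lse
# as kernel outputs (indices are illustrative).
kernel = tilelang.compile(program, out_idx=[3, 4])

# Benchmark the generated kernel with the built-in profiler.
profiler = kernel.get_profiler()
latency = profiler.do_bench()
print(f"flash attention fwd latency: {latency:.3f} ms")
```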

* lint fix

* [Enhancement] Add support for shared memory scope in Fill operation

- Introduced handling for `shared.dyn` and `shared` memory scopes in the Fill operation.
- Implemented parallel operation and layout inference for improved performance in shared memory scenarios (see the sketch after this list).
- Updated thread loop partitioning and vectorization logic to accommodate new memory scope handling.
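
A minimal sketch of what this enables, assuming `T.fill` is the Fill entry point exposed by `tilelang.language` (names and shapes are illustrative):

```python
import tilelang.language as T

M, N, dtype = 128, 128, "float16"

@T.prim_func
def fill_shared(A: T.Tensor((M, N), dtype)):
    with T.Kernel(1, threads=128) as bx:
        # Allocation defaults to the shared.dyn scope covered by this change.
        A_shared = T.alloc_shared((M, N), dtype)
        # Fill on a shared buffer now lowers to a parallel, vectorized loop.
        T.fill(A_shared, 0)
        T.copy(A_shared, A)
```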

* [Refactor] Remove deprecated decorator and enhance Cython kernel handling

- Removed the `deprecated` decorator from the main module and added a new implementation in the utils module for better organization (a sketch of such a helper follows this list).
- Introduced a pointer map in the Cython kernel adapter to manage pointer arguments, improving runtime shape resolution.
- Updated the Cython kernel wrapper to utilize the new pointer map for handling kernel arguments.
- Enhanced error checking in the tensor utility functions to ensure static shapes are enforced.
- Added a new proxy module for buffer and tensor handling, streamlining the interface for TIR programs.
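
The relocated helper is presumably a small warning wrapper along these lines (a generic sketch, not the actual tilelang implementation):

```python
import functools
import warnings

def deprecated(reason: str):
    """Mark a callable as deprecated and emit a DeprecationWarning when it is used."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"{func.__name__} is deprecated: {reason}",
                DeprecationWarning,
                stacklevel=2,
            )
            return func(*args, **kwargs)

        return wrapper

    return decorator
```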

* [Feature] Add matrix multiplication test and kernel implementation

- Introduced a new test file `test_tilelang_language_ptr.py` that implements a matrix multiplication function using TileLang's primitives.
- The `matmul_test` function defines a kernel for performing tile-level GEMM operations with customizable block sizes and data types.
- Added a `run_matmul` function to compile and execute the kernel, along with a test function to validate the implementation.
- Updated the `proxy.py` file to enhance type handling for buffer and tensor proxies, ensuring compatibility with TIR programs.
- Minor formatting improvements in `deprecated.py` for better readability.

* lint fix

* [Refactor] Update tensor creation in matrix multiplication test

- Replaced `T.Tensor.from_ptr` with `T.make_tensor` in `matmul_test` for improved clarity and consistency (see the sketch after this list).
- Updated imports in `__init__.py` to include `make_tensor`.
- Added `make_tensor` function in `proxy.py` to streamline tensor creation from pointers.
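
Sketch of the new call pattern, mirroring the `matmul_test` hunk further down in the diff; the pointer annotation (`T.handle`) and the kernel body are assumptions for illustration:

```python
import tilelang.language as T

def matmul_dyn(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            a_ptr: T.handle,  # raw pointers plus runtime shape scalars
            b_ptr: T.handle,
            c_ptr: T.handle,
            m: T.int32,
            n: T.int32,
            k: T.int32,
    ):
        # Bind each pointer and its dynamic shape to a tensor view.
        A = T.make_tensor(a_ptr, (m, k), dtype)
        B = T.make_tensor(b_ptr, (k, n), dtype)
        C = T.make_tensor(c_ptr, (m, n), accum_dtype)

        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```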

* [Refactor] Update tensor definitions across multiple files

- Replaced instances of `T.Tensor` with updated tensor definitions in various benchmark and example files to enhance consistency and clarity.
- Adjusted tensor shapes and types in functions related to matrix multiplication, attention mechanisms, and other operations.
- Improved documentation in README and example files to reflect changes in tensor usage.

* lint fix

* [Refactor] Update tensor types in attention and matrix multiplication examples

- Replaced instances of `T.Tensor` with `T.SharedTensor` and `T.FragmentTensor` in various attention and matrix multiplication functions to improve consistency and clarity.
- Adjusted tensor definitions in benchmark and example files to align with the new tensor types.
- Enhanced the overall structure and readability of the code by standardizing tensor usage across multiple files.

* lint fix

* [Refactor] Update tensor types in GEMM example and test files

- Replaced instances of `T.Tensor` with `T.LocalTensor` and `T.Buffer` in the GEMM example and related test functions to improve consistency and clarity.
- Enhanced the overall structure of the code by standardizing tensor usage across multiple files, aligning with recent updates in tensor definitions.

* [Refactor] Update tensor usage in customize.py

- Replaced instances of `T.Tensor` with `T.Buffer` in the `reshape` and `view` functions to enhance consistency with recent tensor definitions (the affected call sites are sketched after this list).
- Improved code clarity by standardizing buffer usage across the file.
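
For reference, the affected call sites look roughly like the reshape test shown later in the diff (a sketch; the trailing copy and concrete sizes are assumed):

```python
import tilelang.language as T

N, M, dtype = 256, 16, "float16"

@T.prim_func
def reshape_kernel(
        A: T.Tensor((N,), dtype),
        B: T.Tensor((N // M, M), dtype),
):
    with T.Kernel(1) as _:
        # reshape/view now receive the underlying buffer typed as T.Buffer.
        A_reshaped = T.reshape(A, [N // M, M])
        T.copy(A_reshaped, B)
```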

* [Refactor] Update tensor types in test_tilelang_transform_annotate_device_regions.py

- Replaced instances of `T.Tensor` with `T.Buffer` in the `before` and `expected` methods of the `TestAnnotateThreadExtent` and `TestAnnotateDeviceScope` classes to enhance consistency with recent tensor definitions.
- Improved code clarity by standardizing buffer usage across the test file.

* [Refactor] Update tensor types to SharedBuffer and FragmentBuffer

- Replaced instances of `T.SharedTensor` and `T.FragmentTensor` with `T.SharedBuffer` and `T.FragmentBuffer` across multiple benchmark, example, and test files to enhance consistency with recent tensor definitions (see the annotation sketch after this list).
- Improved code clarity and structure by standardizing buffer usage in attention and matrix multiplication functions.
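
After the rename, intermediate buffers handed between helpers are annotated like this (a hedged sketch; the `T.macro` helper, its name, and the GEMM flags are illustrative rather than taken from a specific file):

```python
import tilelang.language as T

block_M, block_N, dim = 64, 64, 64
dtype, accum_dtype = "float16", "float"

@T.macro
def compute_scores(
        Q_shared: T.SharedBuffer([block_M, dim], dtype),
        K_shared: T.SharedBuffer([block_N, dim], dtype),
        acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
):
    # Shared-memory operands feed a register-fragment accumulator.
    T.clear(acc_s)
    T.gemm(Q_shared, K_shared, acc_s, transpose_B=True)
```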

* [Refactor] Introduce Tensor alias for Buffer in proxy.py

- Added a new alias `Tensor` for `Buffer` in `proxy.py` to facilitate JIT compilation, ensuring that inputs and outputs are mapped to `torch.Tensor` (see the end-to-end sketch after this list).
- This change enhances clarity and consistency in tensor usage across the codebase.
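
An end-to-end sketch of what the alias means in practice: kernels keep `T.Tensor` annotations, and the compiled kernel is called with `torch.Tensor` objects. The `tilelang.compile(..., out_idx=...)` entry point and the kernel body are assumptions for illustration:

```python
import torch
import tilelang
import tilelang.language as T

N, block_N, dtype = 1024, 128, "float16"

@T.prim_func
def add_one(
        A: T.Tensor((N,), dtype),  # T.Tensor resolves to the Buffer proxy
        B: T.Tensor((N,), dtype),
):
    with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
        A_shared = T.alloc_shared([block_N], dtype)
        T.copy(A[bx * block_N], A_shared)
        for i in T.Parallel(block_N):
            A_shared[i] = A_shared[i] + T.float16(1)
        T.copy(A_shared, B[bx * block_N])

# Inputs and outputs cross the JIT boundary as torch.Tensor (out_idx is illustrative).
kernel = tilelang.compile(add_one, out_idx=[1])
a = torch.randn(N, dtype=torch.float16, device="cuda")
b = kernel(a)
torch.testing.assert_close(b, a + 1, rtol=1e-2, atol=1e-2)
```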
parent 73d2c62e
@@ -17,11 +17,11 @@ def flashattn_fwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
@T.prim_func
def flash_fwd(
Q: T.Buffer(shape, dtype), # type: ignore
K: T.Buffer(shape, dtype), # type: ignore
V: T.Buffer(shape, dtype), # type: ignore
Output: T.Buffer(shape, dtype), # type: ignore
lse: T.Buffer([batch, heads, seq_len], accum_dtype), # type: ignore
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
Output: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=32) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
@@ -86,9 +86,9 @@ def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
@T.prim_func
def flash_bwd_prep(
O: T.Buffer(shape, dtype), # type: ignore
dO: T.Buffer(shape, dtype), # type: ignore
Delta: T.Buffer([batch, heads, seq_len], accum_dtype), # type: ignore
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
@@ -121,8 +121,8 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
@T.prim_func
def flash_bwd_post(
dQ: T.Buffer(shape, accum_dtype), # type: ignore
dQ_out: T.Buffer(shape, dtype), # type: ignore
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.annotate_layout({dQ: make_dq_layout(dQ)})
@@ -143,15 +143,15 @@ def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
@T.prim_func
def flash_bwd(
Q: T.Buffer(shape, dtype), # type: ignore
K: T.Buffer(shape, dtype), # type: ignore
V: T.Buffer(shape, dtype), # type: ignore
dO: T.Buffer(shape, dtype), # type: ignore
lse: T.Buffer([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Buffer([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Buffer(shape, accum_dtype), # type: ignore
dK: T.Buffer(shape, dtype), # type: ignore
dV: T.Buffer(shape, dtype), # type: ignore
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dK: T.Tensor(shape, dtype), # type: ignore
dV: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=32) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
@@ -6,9 +6,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Buffer((M, K), dtype),
B: T.Buffer((N, K), dtype),
C: T.Buffer((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
@@ -10,8 +10,8 @@ def alloc_var(
@T.prim_func
def main(
A: T.Buffer((N,), dtype),
B: T.Buffer((N,), dtype),
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
A_shared = T.alloc_shared([block_N], dtype)
@@ -50,8 +50,8 @@ def alloc_var_add(
@T.prim_func
def main(
A: T.Buffer((N,), dtype),
B: T.Buffer((N,), dtype),
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
A_shared = T.alloc_shared([block_N], dtype)
import tilelang.testing
from tilelang.utils.tensor import map_torch_type
def clamp_within_bounds(
N,
block_N,
@@ -12,8 +13,8 @@ def clamp_within_bounds(
@T.prim_func
def main(
A: T.Buffer((N,), dtype),
B: T.Buffer((N,), dtype),
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
A_shared = T.alloc_shared([block_N], dtype)
@@ -55,8 +56,8 @@ def clamp_value_range(
@T.prim_func
def main(
A: T.Buffer((1, N), dtype),
B: T.Buffer((1, N), dtype),
A: T.Tensor((1, N), dtype),
B: T.Tensor((1, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=block_N) as bx:
# A_shared = T.alloc_shared([1, block_N], dtype=dtype)
@@ -91,7 +92,7 @@ def run_clamp_value_range(
# Convert string dtype to torch.dtype
torch_dtype = map_torch_type(dtype)
def ref_program(A):
min_val = torch.min(A) * 0.5
max_val = torch.max(A) * 0.5
@@ -6,9 +6,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Buffer((M, K), dtype),
B: T.Buffer((N, K), dtype),
C: T.Buffer((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
@@ -17,9 +17,9 @@ def matmul_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype
n: T.int32,
k: T.int32,
):
A = T.Tensor.from_ptr(a_ptr, (m, k), dtype)
B = T.Tensor.from_ptr(b_ptr, (k, n), dtype)
C = T.Tensor.from_ptr(c_ptr, (m, n), accum_dtype)
A = T.make_tensor(a_ptr, (m, k), dtype)
B = T.make_tensor(b_ptr, (k, n), dtype)
C = T.make_tensor(c_ptr, (m, n), accum_dtype)
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
@@ -8,8 +8,8 @@ def reduce_max_test(M, N, dtype="float16"):
@T.prim_func
def main(
A: T.Buffer((M, N), dtype),
B: T.Buffer((M,), dtype),
A: T.Tensor((M, N), dtype),
B: T.Tensor((M,), dtype),
):
with T.Kernel(1) as _:
A_local = T.alloc_fragment((M, N), dtype)
@@ -8,8 +8,8 @@ def reshape_test(N, M, dtype):
@T.prim_func
def main(
A: T.Buffer((N,), dtype),
B: T.Buffer((N // M, M), dtype),
A: T.Tensor((N,), dtype),
B: T.Tensor((N // M, M), dtype),
):
with T.Kernel(1) as _:
A_reshaped = T.reshape(A, [N // M, M])
@@ -40,8 +40,8 @@ def reshape_test_smem(N, M, dtype):
@T.prim_func
def main(
A: T.Buffer((N,), dtype),
B: T.Buffer((N // M, M), dtype),
A: T.Tensor((N,), dtype),
B: T.Tensor((N // M, M), dtype),
):
with T.Kernel(1) as _:
A_shared = T.alloc_shared((N,), dtype)
@@ -18,8 +18,8 @@ def view_test(N, M, dtype, new_dtype=None):
@T.prim_func
def main(
A: T.Buffer((N,), dtype),
B: T.Buffer(new_shape, new_dtype if new_dtype else dtype),
A: T.Tensor((N,), dtype),
B: T.Tensor(new_shape, new_dtype if new_dtype else dtype),
):
with T.Kernel(1) as _:
A_viewed = T.view(A, new_shape, dtype=new_dtype)
@@ -27,9 +27,9 @@ def matmul_ssr(
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
@@ -145,9 +145,9 @@ def matmul_rsr(
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
@@ -265,9 +265,9 @@ def matmul_rrr(
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -26,9 +26,9 @@ def matmul(
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -164,9 +164,9 @@ def matmul_rs(
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
@@ -16,7 +16,7 @@ def _check(original, transformed):
def test_trival_pipeline():
@T.prim_func
def before(A: T.Buffer((16, 1), "float32"), C: T.Buffer((16, 1), "float32")):
def before(A: T.Tensor((16, 1), "float32"), C: T.Tensor((16, 1), "float32")):
for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
for i in T.serial(
0,
@@ -39,7 +39,7 @@ def test_trival_pipeline():
C[tx, i] = B[tx, 0] + T.float32(1)
@T.prim_func
def expected(A: T.Buffer((16, 1), "float32"), C: T.Buffer((16, 1), "float32")) -> None:
def expected(A: T.Tensor((16, 1), "float32"), C: T.Tensor((16, 1), "float32")) -> None:
for tx in T.thread_binding(16, thread="threadIdx.x"):
with T.block():
T.reads(A[tx, 0])
@@ -23,7 +23,7 @@ def _check(original, transformed):
def test_cluster_planning():
@T.prim_func
def before(A: T.Buffer((1024, 32), "float16"), B: T.Buffer((32, 1024), "float16"), C: T.Buffer(
def before(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor(
(1024, 1024), "float16")):
with T.Kernel(8, 8, threads=128) as (bx, by):
A_shared = T.alloc_shared((128, 32), "float16")
@@ -41,7 +41,7 @@ def test_cluster_planning():
T.copy(C_local, C[by * 128, bx * 128])
@T.prim_func
def after(A: T.Buffer((1024, 32), "float16"), B: T.Buffer((32, 1024), "float16"), C: T.Buffer(
def after(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor(
(1024, 1024), "float16")):
T.func_attr({"clusterIdx.y": 2})
with T.Kernel(8, 8, threads=128) as (bx, by):
@@ -16,7 +16,7 @@ def _check(original, transformed):
def test_let_binding():
@T.prim_func
def before(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")):
def before(A: T.Tensor((128, 128), "float32"), B: T.Tensor((128, 128), "float32")):
for i in range(128):
for j in range(128):
with T.block("compute"):
@@ -25,7 +25,7 @@ def test_let_binding():
B[i, j] = value
@T.prim_func
def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")):
def expected(A: T.Tensor((128, 128), "float32"), B: T.Tensor((128, 128), "float32")):
for i in range(128):
for j in range(128):
with T.block("compute"):
@@ -37,14 +37,14 @@ def test_let_binding():
def test_parallel_scope():
@T.prim_func
def before(A: T.Buffer((128,), "float32")):
def before(A: T.Tensor((128,), "float32")):
for i in T.Parallel(128):
with T.block("parallel"):
value = T.float32(1.0)
A[i] = value
@T.prim_func
def expected(A: T.Buffer((128,), "float32")):
def expected(A: T.Tensor((128,), "float32")):
for i in T.Parallel(128):
with T.block("parallel"):
A[i] = T.float32(1.0)
@@ -19,7 +19,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
class Before:
@T.prim_func
def main(B: T.Buffer((K, N), dtype),):
def main(B: T.Tensor((K, N), dtype),):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared((block_K, block_N), dtype)
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
@@ -42,7 +42,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
class After:
@T.prim_func
def main(B: T.Buffer((K, N), dtype),):
def main(B: T.Tensor((K, N), dtype),):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared((block_K, block_N), dtype)
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
@@ -8,7 +8,7 @@ def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_off
dtype = "float32"
@T.prim_func
def main(A: T.Buffer((M, N), dtype="float32"),):
def main(A: T.Tensor((M, N), dtype="float32"),):
with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N), dtype=dtype)
tid = T.get_thread_binding()
@@ -16,7 +16,7 @@ def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_off
A_shared[tid, j] = A[tid + M_offset, j + N_offset]
@T.prim_func
def expected(A: T.Buffer((M, N), dtype="float32"),):
def expected(A: T.Tensor((M, N), dtype="float32"),):
with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N), dtype=dtype)
tid = T.get_thread_binding()
@@ -9,7 +9,7 @@ def vectorize_access_legalize(M: int = 64, N: int = 64):
vec_len = 8
@T.prim_func
def main(A: T.Buffer((M, N, vec_len), dtype="float32"),):
def main(A: T.Tensor((M, N, vec_len), dtype="float32"),):
with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype)
tid = T.get_thread_binding()
@@ -18,7 +18,7 @@ def vectorize_access_legalize(M: int = 64, N: int = 64):
A_shared[tid, j, v] = A[tid, j, v]
@T.prim_func
def expected(A: T.Buffer((M, N, vec_len), dtype="float32"),):
def expected(A: T.Tensor((M, N, vec_len), dtype="float32"),):
with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype)
tid = T.get_thread_binding()
@@ -19,7 +19,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
class Before:
@T.prim_func
def main(B: T.Buffer((K, N), dtype),):
def main(B: T.Tensor((K, N), dtype),):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared((block_K, block_N), dtype)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
@@ -29,7 +29,7 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
class After:
@T.prim_func
def main(B: T.Buffer((K, N), dtype),):
def main(B: T.Tensor((K, N), dtype),):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared((block_K, block_N), dtype)
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
@@ -186,7 +186,7 @@ def test_target_host_removed():
class before:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
def main(A: T.Tensor(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("cuda", host=host)})
T.evaluate(0)
@@ -208,7 +208,7 @@ def test_internal_subroutine_call():
class before:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
def main(A: T.Tensor(1, "float32")):
T.func_attr({"target": T.target("llvm", host="llvm")})
before.subroutine(A.data)
@@ -241,7 +241,7 @@ def test_subroutine_call_to_externally_visible_subroutine():
class before:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
def main(A: T.Tensor(1, "float32")):
T.func_attr({"global_symbol": "main", "target": T.target("llvm", host="llvm")})
before.subroutine(A.data)
@@ -271,10 +271,10 @@ def test_function_call_with_wrong_argument_count():
@T.prim_func
def func(
A: T.Buffer([16, 16], "int32"),
B: T.Buffer([16, 16], "int32"),
C: T.Buffer([16, 16], "int32"),
D: T.Buffer([16, 16], "int32"),
A: T.Tensor([16, 16], "int32"),
B: T.Tensor([16, 16], "int32"),
C: T.Tensor([16, 16], "int32"),
D: T.Tensor([16, 16], "int32"),
):
pass
@@ -289,7 +289,7 @@ def test_function_call_with_wrong_type_code():
"""Type codes must be checked before accessing the arguments"""
@T.prim_func
def func(A: T.Buffer([16, 16], "int32")):
def func(A: T.Tensor([16, 16], "int32")):
pass
built = tvm.build(func, target="llvm")
@@ -303,7 +303,7 @@ def test_function_call_with_null_data_pointer():
"""The data pointer must be checked before accessing the array"""
@T.prim_func
def func(A: T.Buffer([16, 16], "int32"), B: T.Buffer([16, 16], "int32")):
def func(A: T.Tensor([16, 16], "int32"), B: T.Tensor([16, 16], "int32")):
for i, j in T.grid(16, 16):
B[i, j] = A[i, j]
@@ -323,7 +323,7 @@ def test_function_call_with_wrong_dimensionality():
"""The dimensionality must be checked before validating the shape"""
@T.prim_func
def func(A: T.Buffer([16, 16], "int32"), B: T.Buffer([16, 16], "int32")):
def func(A: T.Tensor([16, 16], "int32"), B: T.Tensor([16, 16], "int32")):
for i, j in T.grid(16, 16):
B[i, j] = A[i, j]
@@ -33,7 +33,7 @@ block_K = 32
def test_multi_version_buffer():
@T.prim_func
def before(A: T.Buffer((M, K), dtype), B: T.Buffer((K, N), dtype)):
def before(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)):
bx = T.launch_thread("blockIdx.x", 8)
by = T.launch_thread("blockIdx.y", 8)
v = T.launch_thread("threadIdx.x", 128)
@@ -66,7 +66,7 @@ def test_multi_version_buffer():
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
@T.prim_func
def after(A: T.Buffer((M, K), dtype), B: T.Buffer((K, N), dtype)):
def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)):
bx = T.launch_thread("blockIdx.x", 8)
by = T.launch_thread("blockIdx.y", 8)
v = T.launch_thread("threadIdx.x", 128)