import tilelang
import tilelang.language as T
import tilelang.testing
import torch


def _assert_kernel_zeroes(compiled, length=16):
    """Launch ``compiled`` on a CUDA int32 tensor and check it is zeroed.

    The input is ``torch.arange(length)`` so that most elements start
    non-zero; after the launch it must equal an all-zero tensor of the
    same length, dtype and device.

    Args:
        compiled: Callable returned by a ``tilelang.jit``-decorated factory;
            it is invoked with the tensor as its only argument.
        length: Number of int32 elements in the tensor handed to the kernel.
    """
    buf = torch.arange(length, dtype=torch.int32, device="cuda")
    compiled(buf)
    zeros = torch.zeros(length, dtype=torch.int32, device="cuda")
    assert torch.equal(buf, zeros)


def test_tensor_annot_mul():
    """Tensor annotation whose extent is a product: ``n * 4``.

    A 16-element tensor matches this shape when ``n`` is 4.
    """

    @tilelang.jit
    def example_tensor_annot():
        # Symbolic extent shared by the annotation and the loop bound.
        n = T.symbolic("n")

        @T.prim_func
        def kernel(
            A: T.Tensor((n * 4,), T.int32),
        ):
            with T.Kernel(1) as _:
                # Overwrite every element of A with zero.
                for i in range(n * 4):
                    A[i] = 0

        return kernel

    _assert_kernel_zeroes(example_tensor_annot())


def test_tensor_annot_add():
    """Tensor annotation whose extent is a sum: ``n + 1``.

    A 16-element tensor matches this shape when ``n`` is 15.
    """

    @tilelang.jit
    def example_tensor_annot():
        # Symbolic extent shared by the annotation and the loop bound.
        n = T.symbolic("n")

        @T.prim_func
        def kernel(
            A: T.Tensor((n + 1,), T.int32),
        ):
            with T.Kernel(1) as _:
                # Overwrite every element of A with zero.
                for i in range(n + 1):
                    A[i] = 0

        return kernel

    _assert_kernel_zeroes(example_tensor_annot())


def test_tensor_annot_mul_add():
    """Tensor annotation whose extent mixes both: ``n * 3 + 1``.

    A 16-element tensor matches this shape when ``n`` is 5.
    """

    @tilelang.jit
    def example_tensor_annot():
        # Symbolic extent shared by the annotation and the loop bound.
        n = T.symbolic("n")

        @T.prim_func
        def kernel(
            A: T.Tensor((n * 3 + 1,), T.int32),
        ):
            with T.Kernel(1) as _:
                # Overwrite every element of A with zero.
                for i in range(n * 3 + 1):
                    A[i] = 0

        return kernel

    _assert_kernel_zeroes(example_tensor_annot())


if __name__ == "__main__":
    tilelang.testing.main()