Commit bf8a6fc1 authored by Lei Wang, committed by LeiWang1999

[Refactor] Deprecate `T.Buffer` as an argument annotation and rename related calls to `T.Tensor` (#281)

* [Refactor] Improve flash attention example and layout comparison logic

- Removed unnecessary annotation for `lse_local_split` in the flash attention example to streamline the code.
- Updated the handling of `lse_local_split` to utilize parallel processing for better performance.
- Refactored kernel compilation and profiling logic to enhance clarity and maintainability in the flash attention example.
- Added a condition in `FragmentNode::IsEqual` to handle broadcast cases, improving the robustness of layout comparisons.

* lint fix

* [Enhancement] Add support for shared memory scope in Fill operation

- Introduced handling for `shared.dyn` and `shared` memory scopes in the Fill operation.
- Implemented parallel operation and layout inference for improved performance in shared memory scenarios.
- Updated thread loop partitioning and vectorization logic to accommodate the new memory-scope handling (see the sketch below).
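
For context, a minimal sketch of what the new scope handling enables. The kernel below is illustrative rather than taken from this diff, and assumes the `T.fill`, `T.alloc_shared`, and `T.copy` primitives exposed by `tilelang.language`:

```python
import tilelang.language as T


@T.prim_func
def fill_then_copy(A: T.Tensor((128, 128), "float16")):
    with T.Kernel(1, threads=128) as bx:
        # alloc_shared defaults to the "shared.dyn" scope, which the Fill
        # op can now lower with its own parallel, vectorized loop nest.
        A_shared = T.alloc_shared((128, 128), "float16")
        T.fill(A_shared, 0)
        T.copy(A_shared, A)
```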

* [Refactor] Remove deprecated decorator and enhance Cython kernel handling

- Removed the `deprecated` decorator from the main module and reimplemented it in the utils module for better organization.
- Introduced a pointer map in the Cython kernel adapter to manage pointer arguments, improving runtime shape resolution.
- Updated the Cython kernel wrapper to utilize the new pointer map for handling kernel arguments.
- Enhanced error checking in the tensor utility functions to ensure static shapes are enforced.
- Added a new proxy module for buffer and tensor handling, streamlining the interface for TIR programs.

* [Feature] Add matrix multiplication test and kernel implementation

- Introduced a new test file `test_tilelang_language_ptr.py` that implements a matrix multiplication function using TileLang's primitives.
- The `matmul_test` function defines a kernel for performing tile-level GEMM operations with customizable block sizes and data types.
- Added a `run_matmul` function to compile and execute the kernel, along with a test function to validate the implementation.
- Updated the `proxy.py` file to enhance type handling for buffer and tensor proxies, ensuring compatibility with TIR programs.
- Minor formatting improvements in `deprecated.py` for better readability.

* lint fix

* [Refactor] Update tensor creation in matrix multiplication test

- Replaced `T.Tensor.from_ptr` with `T.make_tensor` in `matmul_test` for improved clarity and consistency.
- Updated imports in `__init__.py` to include `make_tensor`.
- Added a `make_tensor` helper in `proxy.py` to streamline tensor creation from pointers (illustrative sketch below).
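
A hypothetical sketch of the new spelling; the argument names and shapes are placeholders, and the kernel body only demonstrates that the bound views behave like ordinary buffers:

```python
import tilelang.language as T


@T.prim_func
def main(a_ptr: T.handle, c_ptr: T.handle):
    # make_tensor binds a raw handle to a typed buffer view; it replaces
    # the former T.Tensor.from_ptr(...) spelling.
    A = T.make_tensor(a_ptr, (1024, 64), "float16")
    C = T.make_tensor(c_ptr, (1024, 64), "float16")
    with T.Kernel(1, threads=128) as bx:
        T.copy(A, C)
```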

* [Refactor] Update tensor definitions across multiple files

- Replaced `T.Buffer` argument annotations with `T.Tensor` in various benchmark and example files to enhance consistency and clarity.
- Adjusted tensor shapes and types in functions related to matrix multiplication, attention mechanisms, and other operations.
- Improved documentation in README and example files to reflect changes in tensor usage.

* lint fix

* [Refactor] Update tensor types in attention and matrix multiplication examples

- Replaced instances of `T.Tensor` with `T.SharedTensor` and `T.FragmentTensor` in various attention and matrix multiplication functions to improve consistency and clarity.
- Adjusted tensor definitions in benchmark and example files to align with the new tensor types.
- Enhanced the overall structure and readability of the code by standardizing tensor usage across multiple files.

* lint fix

* [Refactor] Update tensor types in GEMM example and test files

- Replaced instances of `T.Tensor` with `T.LocalTensor` and `T.Buffer` in the GEMM example and related test functions to improve consistency and clarity.
- Enhanced the overall structure of the code by standardizing tensor usage across multiple files, aligning with recent updates in tensor definitions.

* [Refactor] Update tensor usage in customize.py

- Replaced instances of `T.Tensor` with `T.Buffer` in the `reshape` and `view` functions to enhance consistency with recent tensor definitions.
- Improved code clarity by standardizing buffer usage across the file.

* [Refactor] Update tensor types in test_tilelang_transform_annotate_device_regions.py

- Replaced instances of `T.Tensor` with `T.Buffer` in the `before` and `expected` methods of the `TestAnnotateThreadExtent` and `TestAnnotateDeviceScope` classes to enhance consistency with recent tensor definitions.
- Improved code clarity by standardizing buffer usage across the test file.

* [Refactor] Update tensor types to SharedBuffer and FragmentBuffer

- Replaced instances of `T.SharedTensor` and `T.FragmentTensor` with `T.SharedBuffer` and `T.FragmentBuffer` across multiple benchmark, example, and test files to enhance consistency with recent tensor definitions.
- Improved code clarity and structure by standardizing buffer usage in attention and matrix multiplication functions (see the sketch below).
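
As an illustration (not taken from this diff), the scoped annotations let a helper macro declare where each operand lives; the macro name, shapes, and dtypes below are hypothetical:

```python
import tilelang.language as T


@T.macro
def rescale(
        acc: T.FragmentBuffer((64, 64), "float32"),  # register fragment ("local.fragment")
        scale: T.SharedBuffer((64,), "float32"),  # dynamic shared memory ("shared.dyn")
):
    for i, j in T.Parallel(64, 64):
        acc[i, j] = acc[i, j] * scale[i]
```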

* [Refactor] Introduce Tensor alias for Buffer in proxy.py

- Added a new alias `Tensor` for `Buffer` in `proxy.py` to facilitate JIT compilation, ensuring that inputs and outputs are mapped with `torch.Tensor`.
- This change enhances clarity and consistency in tensor usage across the codebase (illustrated below).
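
A sketch of why the alias matters at the JIT boundary. The `tilelang.compile` entry point below is an assumption for illustration and is not part of this diff:

```python
import torch
import tilelang
import tilelang.language as T


@T.prim_func
def add_one(A: T.Tensor((1024,), "float32"), B: T.Tensor((1024,), "float32")):
    with T.Kernel(T.ceildiv(1024, 128), threads=128) as bx:
        for i in T.Parallel(128):
            B[bx * 128 + i] = A[bx * 128 + i] + 1.0


# Arguments annotated as T.Tensor are bound to torch.Tensor objects at call time.
kernel = tilelang.compile(add_one)  # assumed entry point
a = torch.randn(1024, device="cuda")
b = torch.empty_like(a)
kernel(a, b)
```
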
parent 73d2c62e
......@@ -21,7 +21,7 @@ def _check(original, transformed):
def test_simple_pipeline():
@T.prim_func
def before(A: T.Buffer((1024, 32), "float32"), B: T.Buffer((32, 1024), "float32"), C: T.Buffer(
def before(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor(
(1024, 1024), "float32")):
with T.Kernel(8, 8, threads=128) as (bx, by):
A_shared = T.alloc_shared((128, 32), "float32")
......@@ -39,7 +39,7 @@ def test_simple_pipeline():
T.copy(C_local, C[by * 128, bx * 128])
@T.prim_func
def after(A: T.Buffer((1024, 32), "float32"), B: T.Buffer((32, 1024), "float32"), C: T.Buffer(
def after(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor(
(1024, 1024), "float32")):
with T.Kernel(8, 8, threads=128) as (bx, by):
A_shared = T.alloc_shared((128, 32), "float32")
......
......@@ -11,11 +11,11 @@ def modify(
@T.prim_func
def main(
A: T.Buffer((64, 64)),
B: T.Buffer((64, 64)),
C: T.Buffer((64, 64)),
D: T.Buffer((64, 64)),
bias: T.Buffer((64, 64)),
A: T.Tensor((64, 64)),
B: T.Tensor((64, 64)),
C: T.Tensor((64, 64)),
D: T.Tensor((64, 64)),
bias: T.Tensor((64, 64)),
):
if with_B:
if with_bias:
......
......@@ -19,7 +19,7 @@ def test_vectorize_loop(extent, target):
class Before:
@T.prim_func
def main(A: T.Buffer((16,), "float32")):
def main(A: T.Tensor((16,), "float32")):
for j in T.vectorized(0, extent):
A[j] = 1
......@@ -27,7 +27,7 @@ def test_vectorize_loop(extent, target):
class After:
@T.prim_func
def main(A: T.Buffer((16,), "float32")):
def main(A: T.Tensor((16,), "float32")):
A[T.Ramp(0, 1, extent)] = T.Broadcast(1, extent)
with tvm.target.Target(target):
......@@ -63,7 +63,7 @@ def test_vectorize_vector_scalable_error():
class Module:
@T.prim_func
def main(A: T.Buffer((25,), "float32")):
def main(A: T.Tensor((25,), "float32")):
for j in T.vectorized(T.vscale() * 4):
A[j * 4:j * 4 + 4] = T.Broadcast(T.float32(1), 4)
......@@ -80,7 +80,7 @@ def test_vectorize_vector_scalable_error2():
class Module:
@T.prim_func
def main(A: T.Buffer((25,), "float32xvscalex4")):
def main(A: T.Tensor((25,), "float32xvscalex4")):
for j in T.vectorized(4):
A[j] = T.Broadcast(T.float32(1), T.vscale() * 4)
......@@ -96,7 +96,7 @@ def test_vectorize_vector_scalable_error3():
class Module:
@T.prim_func
def main(A: T.Buffer((25,), "float32")):
def main(A: T.Tensor((25,), "float32")):
for j in T.vectorized(4):
A[j * T.vscale() * 4:j * T.vscale() * 4 + T.vscale() * 4] = T.Broadcast(
T.float32(1),
......@@ -115,7 +115,7 @@ def test_vectorize_vector_scalable_error4():
class Module:
@T.prim_func(private=True)
def main(A: T.Buffer((25,), "float32")):
def main(A: T.Tensor((25,), "float32")):
for j in T.vectorized(T.vscale() * 4):
A[j * T.vscale() * 4:j * T.vscale() * 4 + T.vscale() * 4] = T.Broadcast(
T.float32(1),
......@@ -135,7 +135,7 @@ def test_vectorize_with_if(extent, target):
class Before:
@T.prim_func
def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32):
def main(A: T.Tensor((25,), "float32"), n: T.int32, x: T.int32):
for i in T.vectorized(extent):
if x < n:
A[i] = A[i] + T.float32(1)
......@@ -147,7 +147,7 @@ def test_vectorize_with_if(extent, target):
class After:
@T.prim_func
def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32):
def main(A: T.Tensor((25,), "float32"), n: T.int32, x: T.int32):
if x < n:
A[T.Ramp(0, 1,
extent)] = A[T.Ramp(0, 1, extent)] + T.Broadcast(T.float32(1), extent)
......@@ -180,7 +180,7 @@ def test_vectorize_let(extent, target):
class Before:
@T.prim_func
def main(A: T.Buffer((25,), "float32")):
def main(A: T.Tensor((25,), "float32")):
for i in T.vectorized(extent):
v = A[i] + T.float32(1)
A[i] = v + T.float32(2)
......@@ -189,7 +189,7 @@ def test_vectorize_let(extent, target):
class After:
@T.prim_func
def main(A: T.Buffer((25,), "float32")):
def main(A: T.Tensor((25,), "float32")):
v = A[T.Ramp(0, 1, extent)] + T.Broadcast(T.float32(1), extent)
A[T.Ramp(0, 1, extent)] = v + T.Broadcast(T.float32(2), extent)
......@@ -246,7 +246,7 @@ def test_vectorize_if_then_else_scalarize(extent, target):
class Before:
@T.prim_func
def main(A: T.Buffer((25,), "float32")):
def main(A: T.Tensor((25,), "float32")):
for i in T.vectorized(extent):
A[i] = T.if_then_else(i > 0, A[i] + T.float32(1), A[i])
......@@ -254,7 +254,7 @@ def test_vectorize_if_then_else_scalarize(extent, target):
class After:
@T.prim_func
def main(A: T.Buffer((25,), "float32")):
def main(A: T.Tensor((25,), "float32")):
for i_s in range(extent):
A[i_s] = T.if_then_else(i_s > 0, A[i_s] + T.float32(1), A[i_s])
......@@ -271,7 +271,7 @@ def test_vectorize_if_then_else_vector(extent, target):
class Before:
@T.prim_func
def main(A: T.Buffer((25,), "float32"), n: T.int32):
def main(A: T.Tensor((25,), "float32"), n: T.int32):
for i in range(n):
for j in T.vectorized(extent):
A[i * extent + j] = T.if_then_else(i > 0, A[i * extent + j], 0)
......@@ -280,7 +280,7 @@ def test_vectorize_if_then_else_vector(extent, target):
class After:
@T.prim_func
def main(A: T.Buffer((25,), "float32"), n: T.int32):
def main(A: T.Tensor((25,), "float32"), n: T.int32):
for i in range(n):
A[T.Ramp(i * extent, 1, extent)] = T.if_then_else(i > 0,
A[T.Ramp(i * extent, 1, extent)],
......@@ -359,7 +359,7 @@ def test_vectorize_with_reinterpret(extent, vec_str, target):
class Before:
@T.prim_func
def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")):
def main(A: T.Tensor((16,), "int32"), B: T.Tensor((16,), "float32")):
for i in T.vectorized(0, extent):
B[i] = T.reinterpret("float32", A[i])
......@@ -367,7 +367,7 @@ def test_vectorize_with_reinterpret(extent, vec_str, target):
class After:
@T.prim_func
def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")):
def main(A: T.Tensor((16,), "int32"), B: T.Tensor((16,), "float32")):
B[T.Ramp(0, 1, extent)] = T.reinterpret(vec_str, A[T.Ramp(0, 1, extent)])
with tvm.target.Target(target):
......@@ -403,7 +403,7 @@ def test_vectorize_binary(op, extent, target):
class Before:
@T.prim_func
def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
def main(A: T.Tensor((25,), "float32"), B: T.Tensor((25,), "float32")):
for j in T.vectorized(extent):
A[j] = op(T.float32(3), B[j])
......@@ -411,7 +411,7 @@ def test_vectorize_binary(op, extent, target):
class After:
@T.prim_func
def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
def main(A: T.Tensor((25,), "float32"), B: T.Tensor((25,), "float32")):
A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.float32(3), extent), B[T.Ramp(0, 1, extent)])
with tvm.target.Target(target):
......@@ -428,7 +428,7 @@ def test_vectorize_logical(op, extent, target):
class Before:
@T.prim_func
def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")):
def main(A: T.Tensor((25,), "bool"), B: T.Tensor((25,), "bool")):
for j in T.vectorized(extent):
A[j] = op(T.bool(1), B[j])
......@@ -436,7 +436,7 @@ def test_vectorize_logical(op, extent, target):
class After:
@T.prim_func
def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")):
def main(A: T.Tensor((25,), "bool"), B: T.Tensor((25,), "bool")):
A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.bool(1), extent), B[T.Ramp(0, 1, extent)])
with tvm.target.Target(target):
......@@ -452,7 +452,7 @@ def test_vectorize_select(extent, target):
class Before:
@T.prim_func
def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
def main(A: T.Tensor((25,), "float32"), B: T.Tensor((25,), "float32")):
for j in T.vectorized(extent):
A[j] = T.Select(T.bool(True), A[j], B[j])
......@@ -460,7 +460,7 @@ def test_vectorize_select(extent, target):
class After:
@T.prim_func
def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
def main(A: T.Tensor((25,), "float32"), B: T.Tensor((25,), "float32")):
A[T.Ramp(0, 1, extent)] = T.Select(
T.Broadcast(T.bool(True), extent),
A[T.Ramp(0, 1, extent)],
......@@ -483,7 +483,7 @@ def test_vectorize_cast(extent, vec_str, target):
class Before:
@T.prim_func
def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
def main(A: T.Tensor((25,), "int32"), B: T.Tensor((25,), "float32")):
for j in T.vectorized(extent):
A[j] = T.Cast("int32", B[j])
......@@ -491,7 +491,7 @@ def test_vectorize_cast(extent, vec_str, target):
class After:
@T.prim_func
def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
def main(A: T.Tensor((25,), "int32"), B: T.Tensor((25,), "float32")):
A[T.Ramp(0, 1, extent)] = T.Cast(vec_str, B[T.Ramp(0, 1, extent)])
with tvm.target.Target(target):
......@@ -506,7 +506,7 @@ def test_illegal_extent():
class Mod:
@T.prim_func
def main(A: T.Buffer((25,), "int32")):
def main(A: T.Tensor((25,), "int32")):
n = T.Var("n", dtype="int32")
for j in T.vectorized(n):
A[j] = 3
......@@ -523,7 +523,7 @@ def test_illegal_vscale_in_non_sve_compilation():
class Mod:
@T.prim_func
def main(A: T.Buffer((16,), "float32")):
def main(A: T.Tensor((16,), "float32")):
for j in T.vectorized(0, 4 * T.vscale()):
A[j] = 13
......
......@@ -34,7 +34,7 @@ block_K = 32
def test_warp_specialized():
@T.prim_func
def before(A: T.Buffer((M, K), dtype), B: T.Buffer((K, N), dtype)):
def before(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)):
bx = T.launch_thread("blockIdx.x", 8)
by = T.launch_thread("blockIdx.y", 8)
v = T.launch_thread("threadIdx.x", 128)
......@@ -68,7 +68,7 @@ def test_warp_specialized():
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
@T.prim_func
def after(A: T.Buffer((M, K), dtype), B: T.Buffer((K, N), dtype)):
def after(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)):
bx = T.launch_thread("blockIdx.x", 8)
by = T.launch_thread("blockIdx.y", 8)
v = T.launch_thread("threadIdx.x", 256)
......
......@@ -8,9 +8,9 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
@T.prim_func
def main(
A: T.Buffer((M, K), dtype),
B: T.Buffer((K, N), dtype),
C: T.Buffer((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
......
......@@ -6,7 +6,15 @@ from typing import Optional
# tir script
from tvm.script.parser.tir import *
from tilelang.layout import Layout, Fragment # noqa: F401
from .proxy import Buffer, Tensor, ptr # noqa: F401
from .proxy import (
ptr, # noqa: F401
make_tensor, # noqa: F401
Buffer, # noqa: F401
Tensor, # noqa: F401
FragmentBuffer, # noqa: F401
SharedBuffer, # noqa: F401
LocalBuffer, # noqa: F401
)
from .parallel import Parallel # noqa: F401
from .pipeline import Pipelined # noqa: F401
from .frame import has_let_value, get_let_value # noqa: F401
......
"""The language interface for tl programs."""
from tvm.script import tir as T
import tilelang.language as T
from tvm.tir import PrimExpr, Buffer
from typing import List, Union
......
......@@ -127,12 +127,12 @@ def macro(*args, hygienic: bool = True) -> Callable:
@T.prim_func
def use1(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
def use1(A: T.Tensor((1024,), "int32"), B: T.Tensor((), "int32")) -> None:
for x_value in T.serial(10):
static_capture(A, B) ### Produces B[()] = A[128]
@T.prim_func
def use2(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
def use2(A: T.Tensor((1024,), "int32"), B: T.Tensor((), "int32")) -> None:
for x_value in T.serial(10):
dynamic_capture(A, B) ### Produces B[()] = A[x_value]
```
......@@ -182,7 +182,7 @@ class BufferProxy:
axis_separators=axis_separators,
)
@deprecated("T.Buffer[...]", "T.Buffer(...)")
@deprecated("T.Tensor[...]", "T.Tensor(...)")
def __getitem__(self, keys) -> Buffer:
if not isinstance(keys, tuple):
return self(keys)
......
......@@ -6,12 +6,14 @@ from typing import Optional
from tvm import tir
from tvm.tir import Var, PrimExpr
from tvm.script.ir_builder.tir import buffer, handle, match_buffer
from tilelang.utils import deprecated
class BufferProxy:
"""Buffer proxy class for constructing tir buffer."""
# Index via T.Buffer(...)
@deprecated("T.Buffer(...)", "T.Tensor(...)")
def __call__(
self,
shape,
......@@ -39,6 +41,7 @@ class BufferProxy:
)
# Index via T.Buffer[...]
@deprecated("T.Buffer[...]", "T.Tensor(...)")
def __getitem__(self, keys) -> tir.Buffer:
if not isinstance(keys, tuple):
return self(keys)
......@@ -50,10 +53,12 @@ class BufferProxy:
return match_buffer(ptr, shape, dtype=dtype)
class TensorProxy:
"""Buffer proxy class for constructing tir buffer."""
class BaseTensorProxy:
"""Base proxy class for tensor types with configurable defaults"""
default_scope = "global"
default_align = 0
default_offset_factor = 0
# Index via T.Tensor(...)
def __call__(
self,
shape,
......@@ -61,12 +66,17 @@ class TensorProxy:
data=None,
strides=None,
elem_offset=None,
scope="global",
align=0,
offset_factor=0,
scope=None, # Changed to None to use class default
align=None,
offset_factor=None,
buffer_type="",
axis_separators=None,
) -> tir.Buffer:
# Use class defaults if not specified
scope = scope or self.default_scope
align = align or self.default_align
offset_factor = offset_factor or self.default_offset_factor
return buffer(
shape,
dtype=dtype,
......@@ -80,20 +90,41 @@ class TensorProxy:
axis_separators=axis_separators,
)
# Index via T.Tensor[...]
def __getitem__(self, keys) -> tir.Buffer:
if not isinstance(keys, tuple):
return self(keys)
if len(keys) >= 2 and not isinstance(keys[1], str):
return self(keys)
return self(*keys) # type: ignore[attr-defined] # pylint: disable=no-member
return self(*keys)
def from_ptr(self, ptr: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32") -> tir.Buffer:
return match_buffer(ptr, shape, dtype=dtype)
class TensorProxy(BaseTensorProxy):
"""Main tensor proxy with default global scope"""
class FragmentBufferProxy(BaseTensorProxy):
default_scope = "local.fragment"
class SharedBufferProxy(BaseTensorProxy):
default_scope = "shared.dyn"
class LocalBufferProxy(BaseTensorProxy):
default_scope = "local"
Buffer = BufferProxy() # pylint: disable=invalid-name
# Tensor is an alias for Buffer: when the user JIT-compiles, the inputs
# and outputs will be mapped to torch.Tensor.
Tensor = TensorProxy() # pylint: disable=invalid-name
FragmentBuffer = FragmentBufferProxy() # pylint: disable=invalid-name
SharedBuffer = SharedBufferProxy() # pylint: disable=invalid-name
LocalBuffer = LocalBufferProxy() # pylint: disable=invalid-name
def ptr(dtype: Optional[str] = None,
......@@ -119,3 +150,7 @@ def ptr(dtype: Optional[str] = None,
The new tir.Var with type handle or casted expression with type handle.
"""
return handle(dtype=dtype, storage_scope=storage_scope, is_size_var=is_size_var)
def make_tensor(ptr: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32") -> tir.Buffer:
    """Bind a raw pointer (handle) to a typed tir.Buffer view with the given shape and dtype."""
    return Tensor.from_ptr(ptr, shape, dtype)