Commit 8ad53855 authored by Lei Wang, committed by LeiWang1999

[Language] Introduce `T.ptr` and `T.Tensor` (#276)

* [Refactor] Improve flash attention example and layout comparison logic

- Removed unnecessary annotation for `lse_local_split` in the flash attention example to streamline the code.
- Updated the handling of `lse_local_split` to utilize parallel processing for better performance.
- Refactored kernel compilation and profiling logic to enhance clarity and maintainability in the flash attention example.
- Added a condition in `FragmentNode::IsEqual` to handle broadcast cases, improving the robustness of layout comparisons.

* lint fix

* [Enhancement] Add support for shared memory scope in Fill operation

- Introduced handling for `shared.dyn` and `shared` memory scopes in the Fill operation.
- Implemented parallel operation and layout inference for improved performance in shared memory scenarios.
- Updated thread loop partitioning and vectorization logic to accommodate new memory scope handling.

* [Refactor] Remove deprecated decorator and enhance Cython kernel handling

- Removed the deprecated decorator from the main module and added a new implementation in the utils module for better organization.
- Introduced a pointer map in the Cython kernel adapter to track which kernel arguments are raw pointers so they can be passed correctly at runtime.
- Updated the Cython kernel wrapper to utilize the new pointer map for handling kernel arguments.
- Enhanced error checking in the tensor utility functions to ensure static shapes are enforced.
- Added a new proxy module for buffer and tensor handling, streamlining the interface for TIR programs.

* [Feature] Add matrix multiplication test and kernel implementation

- Introduced a new test file `test_tilelang_language_ptr.py` that implements a matrix multiplication function using TileLang's primitives.
- The `matmul_test` function defines a kernel for performing tile-level GEMM operations with customizable block sizes and data types.
- Added a `run_matmul` function to compile and execute the kernel, along with a test function to validate the implementation.
- Updated the `proxy.py` file to enhance type handling for buffer and tensor proxies, ensuring compatibility with TIR programs.
- Minor formatting improvements in `deprecated.py` for better readability.

* lint fix
parent 18f29277
import torch
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import tilelang.language as T
from tilelang.utils import map_torch_type


def matmul_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            a_ptr: T.ptr,
            b_ptr: T.ptr,
            c_ptr: T.ptr,
            m: T.int32,
            n: T.int32,
            k: T.int32,
    ):
        A = T.Tensor.from_ptr(a_ptr, (m, k), dtype)
        B = T.Tensor.from_ptr(b_ptr, (k, n), dtype)
        C = T.Tensor.from_ptr(c_ptr, (m, n), accum_dtype)
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(k, block_K), num_stages=3):
                # Load the current K tiles of A and B into shared memory
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[bx * block_N, ko * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    program = matmul_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
    jit_kernel = tl.compile(program, target="cuda", execution_backend="cython")

    def ref_program(a, b):
        return (a @ b.T).to(torch.float32)

    a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype))
    b = torch.randn(N, K, device="cuda", dtype=map_torch_type(dtype))
    c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype))

    # The kernel is launched with raw device pointers plus explicit integer extents.
    jit_kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), M, N, K)

    ref_c = (a @ b.T).to(map_torch_type(accum_dtype))
    torch.testing.assert_close(c, ref_c, atol=1e-2, rtol=1e-2)


def test_matmul():
    run_matmul(1024, 1024, 1024, 128, 128, 32)


if __name__ == "__main__":
    tilelang.testing.main()
@@ -2,8 +2,6 @@ import sys
import os
import ctypes
import warnings
import functools
import logging
from tqdm import tqdm
@@ -53,29 +51,6 @@ def _init_logger():
_init_logger()
def deprecated(reason):
    """
    This is a decorator which can be used to mark functions as deprecated.
    It will result in a warning being emitted when the function is used.
    """

    def decorator(func):

        @functools.wraps(func)
        def new_func(*args, **kwargs):
            warnings.warn(
                f"Call to deprecated function {func.__name__} ({reason}).",
                category=DeprecationWarning,
                stacklevel=2,
            )
            return func(*args, **kwargs)

        return new_func

    return decorator
logger = logging.getLogger(__name__)
from .env import SKIP_LOADING_TILELANG_SO
@@ -109,6 +84,7 @@ from .cache import cached # noqa: F401
from .utils import (
    TensorSupplyType,  # noqa: F401
    deprecated,  # noqa: F401
)
from .layout import (
    Layout,  # noqa: F401
......
@@ -139,6 +139,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
    wrapped_source: Optional[str] = None  # Generated C++ wrapper code
    # Maps symbolic variables to their corresponding buffer and shape indices
    dynamic_symbolic_map: Optional[Dict[tir.Var, Tuple[int, int]]] = None
    # Maps the positions of handle (pointer) arguments to their parameter names
    ptr_map: Optional[Dict[int, str]] = None
    # Maps buffer variables to their corresponding dtypes
    buffer_dtype_map: Optional[Dict[tir.Var, Tuple[int, torch.dtype]]] = None
    # Maps buffer variables to their corresponding static shapes
@@ -183,6 +185,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
        self.dynamic_symbolic_map = self._process_dynamic_symbolic()
        self.buffer_dtype_map = self._process_buffer_dtype()
        self.ptr_map = self._process_ptr_map()
        self.static_shape_map = self._process_static_shape()
        self.buffer_device_map = self._process_buffer_device()
@@ -211,6 +214,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
        self.cython_wrapper.set_buffer_dtype_map(self.buffer_dtype_map)
        self.cython_wrapper.set_static_shape_map(self.static_shape_map)
        self.cython_wrapper.set_buffer_device_map(self.buffer_device_map)
        self.cython_wrapper.set_ptr_map(self.ptr_map)
        self._post_init()

    @classmethod
@@ -240,6 +244,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
        adapter.dynamic_symbolic_map = adapter._process_dynamic_symbolic()
        adapter.buffer_dtype_map = adapter._process_buffer_dtype()
        adapter.static_shape_map = adapter._process_static_shape()
        adapter.ptr_map = adapter._process_ptr_map()
        adapter.buffer_device_map = adapter._process_buffer_device()
        adapter.verbose = verbose
@@ -258,6 +263,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
        adapter.cython_wrapper.set_buffer_dtype_map(adapter.buffer_dtype_map)
        adapter.cython_wrapper.set_static_shape_map(adapter.static_shape_map)
        adapter.cython_wrapper.set_buffer_device_map(adapter.buffer_device_map)
        adapter.cython_wrapper.set_ptr_map(adapter.ptr_map)
        adapter._post_init()
        return adapter
@@ -275,7 +282,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
            if param in buffer_map:
                buffer = buffer_map[param]
                for j, shape in enumerate(buffer.shape):
                    # Previously: if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map):
                    # Shape vars that are already explicit scalar parameters (e.g. m, n, k) are now skipped.
                    if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and
                            (shape not in params)):
                        dynamic_symbolic_map[shape] = (i, j)
        return dynamic_symbolic_map
@@ -295,6 +303,20 @@
            buffer_dtype_map[name] = (i, map_torch_type(dtype))
        return buffer_dtype_map

    def _process_ptr_map(self) -> Dict[int, str]:
        """Extract information about pointer arguments from the TIR function.

        Maps the position of each handle-typed parameter to its name so that
        raw pointer values can be matched to kernel arguments at runtime.
        """
        func = self.prim_func
        params = func.params
        ptr_map = {}
        for i, param in enumerate(params):
            if param.dtype == 'handle':
                ptr_map[i] = param.name
        return ptr_map
    def _process_static_shape(self) -> Dict[tir.Var, List[Tuple[int, int]]]:
        """Extract information about static shapes from the TIR function.
......
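
For the `main` kernel in the test at the top of this commit, `_process_ptr_map` would produce a mapping like the one sketched below (an illustration, not part of the diff): only the handle-typed parameters are recorded, keyed by their argument position.

# a_ptr, b_ptr and c_ptr are declared as T.ptr (handle) and sit at positions 0-2;
# m, n and k are int32 scalars, so they do not appear in the map.
ptr_map = {0: "a_ptr", 1: "b_ptr", 2: "c_ptr"}
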
@@ -15,6 +15,7 @@ cdef class CythonKernelWrapper:
    object buffer_device_map  # Maps buffer variables to their corresponding devices
    object buffer_dtype_map  # Maps buffer variables to their corresponding dtypes
    object static_shape_map  # Maps buffer variables to their corresponding static shapes
    object ptr_map  # Maps handle (pointer) argument positions to their parameter names
    list result_idx  # Indices of output tensors in the params list
    list params  # List of parameter specifications (includes both inputs and outputs)
    object lib  # Reference to the compiled library containing the kernel
@@ -54,6 +55,10 @@ cdef class CythonKernelWrapper:
        self.static_shape_map = static_shape_map
        return self

    def set_ptr_map(self, ptr_map):
        self.ptr_map = ptr_map
        return self

    def set_buffer_device_map(self, buffer_device_map):
        self.buffer_device_map = buffer_device_map
        return self
@@ -109,6 +114,9 @@ cdef class CythonKernelWrapper:
                call_args.append(ctypes.c_void_p(tensor_list[i].data_ptr()))
            elif isinstance(tensor_list[i], int):
                # Integer arguments are either raw device pointers (positions
                # recorded in ptr_map) or dynamic symbolic shape values
                if i in self.ptr_map:
                    call_args.append(ctypes.c_void_p(tensor_list[i]))
                else:
                    call_args.append(tensor_list[i])
            elif isinstance(tensor_list[i], float):
                call_args.append(ctypes.c_float(tensor_list[i]))
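
The dispatch above can be summarized with a small plain-Python sketch (the `convert_scalar_arg` helper is hypothetical; the real wrapper performs this check inline in Cython):

import ctypes

def convert_scalar_arg(i, value, ptr_map):
    # Positions listed in ptr_map carry raw device addresses
    # (e.g. torch.Tensor.data_ptr()) and are passed as void pointers;
    # every other integer is a dynamic shape value passed through as-is.
    if i in ptr_map:
        return ctypes.c_void_p(value)
    return value
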
@@ -119,6 +127,7 @@ cdef class CythonKernelWrapper:
        # Check buffer device
        for param, (buffer_idx, device) in self.buffer_device_map.items():
            if isinstance(tensor_list[buffer_idx], torch.Tensor):
                tensor_device = tensor_list[buffer_idx].device
                # Compare device types and indices separately to handle both string and torch.device objects
                if (tensor_device.type != device.type or
@@ -127,11 +136,13 @@
        # Check buffer dtype map
        for param, (buffer_idx, torch_dtype) in self.buffer_dtype_map.items():
            if isinstance(tensor_list[buffer_idx], torch.Tensor):
                if tensor_list[buffer_idx].dtype != torch_dtype:
                    raise ValueError(
                        f"Buffer dtype mismatch for parameter {param}: expected {torch_dtype}, got {tensor_list[buffer_idx].dtype}")
        # Check static shape map
        for param, (buffer_idx, shape_list) in self.static_shape_map.items():
            if isinstance(tensor_list[buffer_idx], torch.Tensor):
                for shape_idx, shape in shape_list:
                    if tensor_list[buffer_idx].shape[shape_idx] != shape:
                        raise ValueError(
                            f"Static shape mismatch for parameter {param}: expected {shape} at index {shape_idx}, got {tensor_list[buffer_idx].shape}")
......
@@ -155,6 +155,7 @@ class TLCUDASourceWrapper(object):
                f"Parameter {param} is not in the buffer map of the primary function.")
        # Add dynamic symbols as integer arguments, skipping any that already
        # appear as explicit function arguments
        for dyn_sym in dynamic_symbolic_set:
            if dyn_sym not in [arg["name"] for arg in function_args]:
                function_args.append({"name": dyn_sym, "type": "int"})

        function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"})
......
@@ -6,6 +6,7 @@ from typing import Optional
# tir script
from tvm.script.parser.tir import *
from tilelang.layout import Layout, Fragment # noqa: F401
from .proxy import Buffer, Tensor, ptr # noqa: F401
from .parallel import Parallel # noqa: F401
from .pipeline import Pipelined # noqa: F401
from .frame import has_let_value, get_let_value # noqa: F401
......
"""The language interface for tl programs."""
from __future__ import annotations
from typing import Optional
from tvm import tir
from tvm.tir import Var, PrimExpr
from tvm.script.ir_builder.tir import buffer, handle, match_buffer
class BufferProxy:
    """Buffer proxy class for constructing tir buffer."""

    # Index via T.Buffer(...)
    def __call__(
        self,
        shape,
        dtype="float32",
        data=None,
        strides=None,
        elem_offset=None,
        scope="global",
        align=0,
        offset_factor=0,
        buffer_type="",
        axis_separators=None,
    ) -> tir.Buffer:
        return buffer(
            shape,
            dtype=dtype,
            data=data,
            strides=strides,
            elem_offset=elem_offset,
            scope=scope,
            align=align,
            offset_factor=offset_factor,
            buffer_type=buffer_type,
            axis_separators=axis_separators,
        )

    # Index via T.Buffer[...]
    def __getitem__(self, keys) -> tir.Buffer:
        if not isinstance(keys, tuple):
            return self(keys)
        if len(keys) >= 2 and not isinstance(keys[1], str):
            return self(keys)
        return self(*keys)  # type: ignore[attr-defined] # pylint: disable=no-member

    def from_ptr(self, ptr: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32") -> tir.Buffer:
        return match_buffer(ptr, shape, dtype=dtype)
class TensorProxy:
    """Tensor proxy class for constructing tir buffers (tensors)."""

    # Index via T.Tensor(...)
    def __call__(
        self,
        shape,
        dtype="float32",
        data=None,
        strides=None,
        elem_offset=None,
        scope="global",
        align=0,
        offset_factor=0,
        buffer_type="",
        axis_separators=None,
    ) -> tir.Buffer:
        return buffer(
            shape,
            dtype=dtype,
            data=data,
            strides=strides,
            elem_offset=elem_offset,
            scope=scope,
            align=align,
            offset_factor=offset_factor,
            buffer_type=buffer_type,
            axis_separators=axis_separators,
        )

    # Index via T.Tensor[...]
    def __getitem__(self, keys) -> tir.Buffer:
        if not isinstance(keys, tuple):
            return self(keys)
        if len(keys) >= 2 and not isinstance(keys[1], str):
            return self(keys)
        return self(*keys)  # type: ignore[attr-defined] # pylint: disable=no-member

    def from_ptr(self, ptr: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32") -> tir.Buffer:
        return match_buffer(ptr, shape, dtype=dtype)
Buffer = BufferProxy() # pylint: disable=invalid-name
Tensor = TensorProxy() # pylint: disable=invalid-name
def ptr(dtype: Optional[str] = None,
        storage_scope: str = "global",
        *,
        is_size_var: bool = False) -> Var:
    """Create a TIR var that represents a pointer.

    Parameters
    ----------
    dtype: str
        The data type of the pointer.

    storage_scope: str
        The storage scope of the pointer.

    is_size_var: bool
        Whether or not to return a SizeVar instead of Var.

    Returns
    -------
    res : Var
        The new tir.Var with type handle.
    """
    return handle(dtype=dtype, storage_scope=storage_scope, is_size_var=is_size_var)
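
A minimal sketch of how `T.ptr` and `T.Tensor.from_ptr` fit together, mirroring the test at the top of this commit (the kernel name is made up and `T.evaluate(0)` is only a placeholder body):

import tilelang.language as T

@T.prim_func
def view_only(a_ptr: T.ptr, m: T.int32, n: T.int32):
    # Bind the incoming handle to a 2-D float16 view with runtime extents.
    A = T.Tensor.from_ptr(a_ptr, (m, n), "float16")
    T.evaluate(0)  # placeholder; a real kernel would open T.Kernel(...) here
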
@@ -10,3 +10,4 @@ from .language import (
    is_local,  # noqa: F401
    array_reduce,  # noqa: F401
)
from .deprecated import deprecated  # noqa: F401
def deprecated(
    method_name: str,
    new_method_name: str,
):
    """A decorator to indicate that a method is deprecated

    Parameters
    ----------
    method_name : str
        The name of the method to deprecate

    new_method_name : str
        The name of the new method to use instead
    """
    import functools  # pylint: disable=import-outside-toplevel
    import warnings  # pylint: disable=import-outside-toplevel

    def _deprecate(func):

        @functools.wraps(func)
        def _wrapper(*args, **kwargs):
            warnings.warn(
                f"{method_name} is deprecated, use {new_method_name} instead",
                DeprecationWarning,
                stacklevel=2,
            )
            return func(*args, **kwargs)

        return _wrapper

    return _deprecate
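
A hedged usage sketch of the relocated decorator, using the export added in `tilelang/utils/__init__.py` above (`legacy_gemm` and `gemm_v2` are hypothetical names):

from tilelang.utils import deprecated

def gemm_v2(x):
    return x  # stand-in for the replacement API

@deprecated("legacy_gemm", "gemm_v2")
def legacy_gemm(x):
    # Calling legacy_gemm emits a DeprecationWarning that points to gemm_v2.
    return gemm_v2(x)
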
@@ -3,6 +3,7 @@ from __future__ import annotations
from enum import Enum
import torch
from tvm.runtime import ndarray
from tvm import tir
from torch.utils.dlpack import to_dlpack
import numpy as np
@@ -58,6 +59,14 @@ def get_tensor_supply(supply_type: TensorSupplyType):
                f"TensorType must have a shape, but got {type(param)}, "
                "likely you are trying to generate a random tensor with a dynamic symbolic shape.")

        # Reject dynamic symbolic shapes: random inputs can only be generated
        # for parameters whose extents are static integers
        for shape in param.shape:
            if isinstance(shape, tir.Var):
                raise ValueError(
                    f"TensorType must have a static shape, but got {shape}, "
                    "likely you are trying to generate a random tensor with a dynamic symbolic shape.")

        shape = list(map(int, param.shape))
        if supply_type == TensorSupplyType.Auto:
            is_unsigned = param.is_unsigned()
......