"docs/source/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "bb8114a4a65145a1489436f98392007e12935ae9"
Commit 8ad53855 authored by Lei Wang, committed by LeiWang1999

[Language] Introduce `T.ptr` and `T.Tensor` (#276)

* [Refactor] Improve flash attention example and layout comparison logic

- Removed unnecessary annotation for `lse_local_split` in the flash attention example to streamline the code.
- Updated the handling of `lse_local_split` to utilize parallel processing for better performance.
- Refactored kernel compilation and profiling logic to enhance clarity and maintainability in the flash attention example.
- Added a condition in `FragmentNode::IsEqual` to handle broadcast cases, improving the robustness of layout comparisons.

* lint fix

* [Enhancement] Add support for shared memory scope in Fill operation

- Introduced handling for `shared.dyn` and `shared` memory scopes in the Fill operation.
- Implemented parallel operation and layout inference for improved performance in shared memory scenarios.
- Updated thread loop partitioning and vectorization logic to accommodate new memory scope handling.

* [Refactor] Remove deprecated decorator and enhance Cython kernel handling

- Removed the deprecated decorator from the main module and added a new implementation in the utils module for better organization.
- Introduced a pointer map in the Cython kernel adapter to manage pointer arguments, improving runtime shape resolution.
- Updated the Cython kernel wrapper to utilize the new pointer map for handling kernel arguments.
- Enhanced error checking in the tensor utility functions to ensure static shapes are enforced.
- Added a new proxy module for buffer and tensor handling, streamlining the interface for TIR programs.

* [Feature] Add matrix multiplication test and kernel implementation

- Introduced a new test file `test_tilelang_language_ptr.py` that implements a matrix multiplication function using TileLang's primitives.
- The `matmul_test` function defines a kernel for performing tile-level GEMM operations with customizable block sizes and data types.
- Added a `run_matmul` function to compile and execute the kernel, along with a test function to validate the implementation.
- Updated the `proxy.py` file to enhance type handling for buffer and tensor proxies, ensuring compatibility with TIR programs.
- Minor formatting improvements in `deprecated.py` for better readability.

* lint fix
parent 18f29277

test_tilelang_language_ptr.py (new file):
import torch
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import tilelang.language as T
from tilelang.utils import map_torch_type
def matmul_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            a_ptr: T.ptr,
            b_ptr: T.ptr,
            c_ptr: T.ptr,
            m: T.int32,
            n: T.int32,
            k: T.int32,
    ):
        # Reinterpret the raw pointers as shaped, typed views.
        A = T.Tensor.from_ptr(a_ptr, (m, k), dtype)
        B = T.Tensor.from_ptr(b_ptr, (n, k), dtype)  # b is (N, K) row-major; gemm uses transpose_B
        C = T.Tensor.from_ptr(c_ptr, (m, n), accum_dtype)
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(k, block_K), num_stages=3):
                # Copy tiles of A and B into shared memory
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[bx * block_N, ko * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
def run_matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    program = matmul_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
    jit_kernel = tl.compile(program, target="cuda", execution_backend="cython")

    def ref_program(a, b):
        # Reference GEMM, computed in the accumulation dtype.
        return (a @ b.T).to(map_torch_type(accum_dtype))

    a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype))
    b = torch.randn(N, K, device="cuda", dtype=map_torch_type(dtype))
    c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype))

    # The kernel takes raw device pointers plus the runtime shapes m, n, k.
    jit_kernel(a.data_ptr(), b.data_ptr(), c.data_ptr(), M, N, K)

    ref_c = ref_program(a, b)
    torch.testing.assert_close(c, ref_c, atol=1e-2, rtol=1e-2)


def test_matmul():
    run_matmul(1024, 1024, 1024, 128, 128, 32)


if __name__ == "__main__":
    tilelang.testing.main()
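
For comparison, a minimal sketch of the same kernel written with statically shaped `T.Tensor` parameters (illustrative only, not part of this commit; it follows the usual tilelang matmul style). With static shapes the kernel is called with torch tensors rather than `data_ptr()` integers:

# Illustrative sketch: static shapes baked in at compile time.
def matmul_static(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[bx * block_N, ko * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main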
@@ -2,8 +2,6 @@ import sys
 import os
 import ctypes
-import warnings
-import functools
 import logging
 from tqdm import tqdm
@@ -53,29 +51,6 @@ def _init_logger():
 _init_logger()
-
-def deprecated(reason):
-    """
-    This is a decorator which can be used to mark functions as deprecated.
-    It will result in a warning being emitted when the function is used.
-    """
-
-    def decorator(func):
-
-        @functools.wraps(func)
-        def new_func(*args, **kwargs):
-            warnings.warn(
-                f"Call to deprecated function {func.__name__} ({reason}).",
-                category=DeprecationWarning,
-                stacklevel=2,
-            )
-            return func(*args, **kwargs)
-
-        return new_func
-
-    return decorator
-
 logger = logging.getLogger(__name__)
 from .env import SKIP_LOADING_TILELANG_SO
@@ -109,6 +84,7 @@ from .cache import cached  # noqa: F401
 from .utils import (
     TensorSupplyType,  # noqa: F401
+    deprecated,  # noqa: F401
 )
 from .layout import (
     Layout,  # noqa: F401
...
@@ -139,6 +139,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
     wrapped_source: Optional[str] = None  # Generated C++ wrapper code
     # Maps symbolic variables to their corresponding buffer and shape indices
     dynamic_symbolic_map: Optional[Dict[tir.Var, Tuple[int, int]]] = None
+    # Maps pointer argument indices to their parameter names
+    ptr_map: Optional[Dict[int, str]] = None
     # Maps buffer variables to their corresponding dtypes
     buffer_dtype_map: Optional[Dict[tir.Var, Tuple[int, torch.dtype]]] = None
     # Maps buffer variables to their corresponding static shapes
@@ -183,6 +185,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
         self.dynamic_symbolic_map = self._process_dynamic_symbolic()
         self.buffer_dtype_map = self._process_buffer_dtype()
+        self.ptr_map = self._process_ptr_map()
         self.static_shape_map = self._process_static_shape()
         self.buffer_device_map = self._process_buffer_device()
@@ -211,6 +214,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
         self.cython_wrapper.set_buffer_dtype_map(self.buffer_dtype_map)
         self.cython_wrapper.set_static_shape_map(self.static_shape_map)
         self.cython_wrapper.set_buffer_device_map(self.buffer_device_map)
+        self.cython_wrapper.set_ptr_map(self.ptr_map)
         self._post_init()

     @classmethod
@@ -240,6 +244,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
         adapter.dynamic_symbolic_map = adapter._process_dynamic_symbolic()
         adapter.buffer_dtype_map = adapter._process_buffer_dtype()
         adapter.static_shape_map = adapter._process_static_shape()
+        adapter.ptr_map = adapter._process_ptr_map()
         adapter.buffer_device_map = adapter._process_buffer_device()
         adapter.verbose = verbose
@@ -258,6 +263,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
         adapter.cython_wrapper.set_buffer_dtype_map(adapter.buffer_dtype_map)
         adapter.cython_wrapper.set_static_shape_map(adapter.static_shape_map)
         adapter.cython_wrapper.set_buffer_device_map(adapter.buffer_device_map)
+        adapter.cython_wrapper.set_ptr_map(adapter.ptr_map)
         adapter._post_init()
         return adapter
@@ -275,7 +282,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
             if param in buffer_map:
                 buffer = buffer_map[param]
                 for j, shape in enumerate(buffer.shape):
-                    if isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map):
+                    if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and
+                            (shape not in params)):
                         dynamic_symbolic_map[shape] = (i, j)
         return dynamic_symbolic_map
@@ -295,6 +303,20 @@ class CythonKernelAdapter(BaseKernelAdapter):
             buffer_dtype_map[name] = (i, map_torch_type(dtype))
         return buffer_dtype_map

+    def _process_ptr_map(self) -> Dict[int, str]:
+        """Extract information about pointer arguments from the TIR function.
+
+        Maps pointer argument indices to their parameter names so integer
+        arguments can be reinterpreted as raw device pointers at call time.
+        """
+        func = self.prim_func
+        params = func.params
+        ptr_map = {}
+        for i, param in enumerate(params):
+            if param.dtype == 'handle':
+                ptr_map[i] = param.name
+        return ptr_map
+
     def _process_static_shape(self) -> Dict[tir.Var, List[Tuple[int, int]]]:
         """Extract information about static shapes from the TIR function.
...
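
For the test kernel above, `_process_ptr_map` keys the three `handle`-typed parameters by position, while `m`, `n`, `k` stay on the ordinary integer path. A sketch of the resulting map (contents inferred from the code above, not printed from a run):

# For main(a_ptr, b_ptr, c_ptr, m, n, k):
ptr_map = {0: "a_ptr", 1: "b_ptr", 2: "c_ptr"}  # params whose dtype is "handle"
# m, n, k are int32 scalars, so they are absent here; and because they are
# explicit parameters, the updated _process_dynamic_symbolic() also skips them
# (the new "shape not in params" condition).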
@@ -15,6 +15,7 @@ cdef class CythonKernelWrapper:
     object buffer_device_map  # Maps buffer variables to their corresponding devices
     object buffer_dtype_map  # Maps buffer variables to their corresponding dtypes
     object static_shape_map  # Maps buffer variables to their corresponding static shapes
+    object ptr_map  # Maps pointer argument indices to their parameter names
     list result_idx  # Indices of output tensors in the params list
     list params  # List of parameter specifications (includes both inputs and outputs)
     object lib  # Reference to the compiled library containing the kernel
@@ -54,6 +55,10 @@ cdef class CythonKernelWrapper:
         self.static_shape_map = static_shape_map
         return self

+    def set_ptr_map(self, ptr_map):
+        self.ptr_map = ptr_map
+        return self
+
     def set_buffer_device_map(self, buffer_device_map):
         self.buffer_device_map = buffer_device_map
         return self
@@ -109,7 +114,10 @@ cdef class CythonKernelWrapper:
                 call_args.append(ctypes.c_void_p(tensor_list[i].data_ptr()))
             elif isinstance(tensor_list[i], int):
                 # Dynamic symbolics which are passed as integer arguments
-                call_args.append(tensor_list[i])
+                if i in self.ptr_map:
+                    call_args.append(ctypes.c_void_p(tensor_list[i]))
+                else:
+                    call_args.append(tensor_list[i])
             elif isinstance(tensor_list[i], float):
                 call_args.append(ctypes.c_float(tensor_list[i]))
             elif isinstance(tensor_list[i], bool):
@@ -119,27 +127,30 @@ cdef class CythonKernelWrapper:
         # Check buffer device
         for param, (buffer_idx, device) in self.buffer_device_map.items():
-            tensor_device = tensor_list[buffer_idx].device
-            # Compare device types and indices separately to handle both string and torch.device objects
-            if (tensor_device.type != device.type or
-                (tensor_device.index is not None and device.index is not None and tensor_device.index != device.index)):
-                raise ValueError(f"Buffer device mismatch for parameter {param}: expected {device}, got {tensor_device}")
+            if isinstance(tensor_list[buffer_idx], torch.Tensor):
+                tensor_device = tensor_list[buffer_idx].device
+                # Compare device types and indices separately to handle both string and torch.device objects
+                if (tensor_device.type != device.type or
+                    (tensor_device.index is not None and device.index is not None and tensor_device.index != device.index)):
+                    raise ValueError(f"Buffer device mismatch for parameter {param}: expected {device}, got {tensor_device}")

         # Check buffer dtype map
         for param, (buffer_idx, torch_dtype) in self.buffer_dtype_map.items():
-            if tensor_list[buffer_idx].dtype != torch_dtype:
-                raise ValueError(f"Buffer dtype mismatch for parameter {param}: expected {torch_dtype}, got {tensor_list[buffer_idx].dtype}")
+            if isinstance(tensor_list[buffer_idx], torch.Tensor):
+                if tensor_list[buffer_idx].dtype != torch_dtype:
+                    raise ValueError(f"Buffer dtype mismatch for parameter {param}: expected {torch_dtype}, got {tensor_list[buffer_idx].dtype}")

         # Check static shape map
         for param, (buffer_idx, shape_list) in self.static_shape_map.items():
-            for shape_idx, shape in shape_list:
-                if tensor_list[buffer_idx].shape[shape_idx] != shape:
-                    raise ValueError(f"Static shape mismatch for parameter {param}: expected {shape} at index {shape_idx}, got {tensor_list[buffer_idx].shape}")
+            if isinstance(tensor_list[buffer_idx], torch.Tensor):
+                for shape_idx, shape in shape_list:
+                    if tensor_list[buffer_idx].shape[shape_idx] != shape:
+                        raise ValueError(f"Static shape mismatch for parameter {param}: expected {shape} at index {shape_idx}, got {tensor_list[buffer_idx].shape}")

         # Add dynamic dimension values to kernel arguments
         for _, (buffer_idx, shape_idx) in self.dynamic_symbolic_map.items():
             call_args.append(tensor_list[buffer_idx].shape[shape_idx])

         # Add CUDA stream to kernel arguments
         call_args.append(ctypes.c_void_p(stream))
...
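
The runtime effect of `ptr_map`, modeled in plain Python (a simplified sketch of the Cython dispatch above, not the actual wrapper):

import ctypes

def marshal_args(args, ptr_map):
    """Toy model of the CythonKernelWrapper argument marshalling."""
    call_args = []
    for i, arg in enumerate(args):
        if isinstance(arg, int) and i in ptr_map:
            # Integers at pointer positions (e.g. tensor.data_ptr()) are
            # reinterpreted as void* before being handed to the kernel.
            call_args.append(ctypes.c_void_p(arg))
        else:
            # Dynamic-shape scalars and other values pass through unchanged here.
            call_args.append(arg)
    return call_args

# marshal_args([a.data_ptr(), b.data_ptr(), c.data_ptr(), 1024, 1024, 1024],
#              {0: "a_ptr", 1: "b_ptr", 2: "c_ptr"})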
@@ -155,7 +155,8 @@ class TLCUDASourceWrapper(object):
                     f"Parameter {param} is not in the buffer map of the primary function.")
         # Add dynamic symbols as integer arguments
         for dyn_sym in dynamic_symbolic_set:
-            function_args.append({"name": dyn_sym, "type": "int"})
+            if dyn_sym not in [arg["name"] for arg in function_args]:
+                function_args.append({"name": dyn_sym, "type": "int"})

         function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},)
...
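
The guard matters for `T.ptr` kernels because a symbolic dimension such as `m` is now also an explicit kernel parameter, so without the membership check it would be emitted into the generated argument list twice. A small self-contained illustration (the argument dicts are hypothetical):

function_args = [
    {"name": "a_ptr", "type": "void*"},
    {"name": "m", "type": "int"},  # already present as an explicit parameter
]
dynamic_symbolic_set = ["m"]  # "m" also appears as a symbolic buffer dimension

for dyn_sym in dynamic_symbolic_set:
    if dyn_sym not in [arg["name"] for arg in function_args]:
        function_args.append({"name": dyn_sym, "type": "int"})

assert sum(arg["name"] == "m" for arg in function_args) == 1  # no duplicate "m"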
@@ -6,6 +6,7 @@ from typing import Optional
 # tir script
 from tvm.script.parser.tir import *
 from tilelang.layout import Layout, Fragment  # noqa: F401
+from .proxy import Buffer, Tensor, ptr  # noqa: F401
 from .parallel import Parallel  # noqa: F401
 from .pipeline import Pipelined  # noqa: F401
 from .frame import has_let_value, get_let_value  # noqa: F401
...
"""The language interface for tl programs."""
from __future__ import annotations
from typing import Optional
from tvm import tir
from tvm.tir import Var, PrimExpr
from tvm.script.ir_builder.tir import buffer, handle, match_buffer
class BufferProxy:
"""Buffer proxy class for constructing tir buffer."""
# Index via T.Buffer(...)
def __call__(
self,
shape,
dtype="float32",
data=None,
strides=None,
elem_offset=None,
scope="global",
align=0,
offset_factor=0,
buffer_type="",
axis_separators=None,
) -> tir.Buffer:
return buffer(
shape,
dtype=dtype,
data=data,
strides=strides,
elem_offset=elem_offset,
scope=scope,
align=align,
offset_factor=offset_factor,
buffer_type=buffer_type,
axis_separators=axis_separators,
)
# Index via T.Buffer[...]
def __getitem__(self, keys) -> tir.Buffer:
if not isinstance(keys, tuple):
return self(keys)
if len(keys) >= 2 and not isinstance(keys[1], str):
return self(keys)
return self(*keys) # type: ignore[attr-defined] # pylint: disable=no-member
def from_ptr(self, ptr: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32") -> Buffer:
return match_buffer(ptr, shape, dtype=dtype)
class TensorProxy:
"""Buffer proxy class for constructing tir buffer."""
# Index via T.Tensor(...)
def __call__(
self,
shape,
dtype="float32",
data=None,
strides=None,
elem_offset=None,
scope="global",
align=0,
offset_factor=0,
buffer_type="",
axis_separators=None,
) -> tir.Buffer:
return buffer(
shape,
dtype=dtype,
data=data,
strides=strides,
elem_offset=elem_offset,
scope=scope,
align=align,
offset_factor=offset_factor,
buffer_type=buffer_type,
axis_separators=axis_separators,
)
# Index via T.Tensor[...]
def __getitem__(self, keys) -> tir.Buffer:
if not isinstance(keys, tuple):
return self(keys)
if len(keys) >= 2 and not isinstance(keys[1], str):
return self(keys)
return self(*keys) # type: ignore[attr-defined] # pylint: disable=no-member
def from_ptr(self, ptr: Var, shape: tuple[PrimExpr, ...], dtype: str = "float32") -> tir.Buffer:
return match_buffer(ptr, shape, dtype=dtype)
Buffer = BufferProxy() # pylint: disable=invalid-name
Tensor = TensorProxy() # pylint: disable=invalid-name
def ptr(dtype: Optional[str] = None,
storage_scope: str = "global",
*,
is_size_var: bool = False) -> Var:
"""Create a TIR var that represents a pointer.
Parameters
----------
dtype: str
The data type of the pointer.
storage_scope: str
The storage scope of the pointer.
is_size_var: bool
Whether or not to return a SizeVar instead of Var.
Returns
-------
res : PrimExpr
The new tir.Var with type handle or casted expression with type handle.
"""
return handle(dtype=dtype, storage_scope=storage_scope, is_size_var=is_size_var)
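
Taken together, the proxies give three entry points. A hedged usage sketch (the kernel below is illustrative only; `from_ptr` lowers to `match_buffer`, so it belongs inside a `@T.prim_func` body):

import tilelang.language as T

@T.prim_func
def example(
        x: T.Tensor((128, 256), "float16"),  # statically shaped buffer parameter
        y_ptr: T.ptr,                        # opaque handle parameter
        m: T.int32,
):
    # Bind the raw pointer to a shaped, typed view over the same memory.
    Y = T.Tensor.from_ptr(y_ptr, (m, 256), "float16")
    T.evaluate(0)  # placeholder body; a real kernel would open T.Kernel(...) here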
@@ -10,3 +10,4 @@ from .language import (
     is_local,  # noqa: F401
     array_reduce,  # noqa: F401
 )
+from .deprecated import deprecated  # noqa: F401
tilelang/utils/deprecated.py (new file):

def deprecated(
    method_name: str,
    new_method_name: str,
):
    """A decorator to indicate that a method is deprecated

    Parameters
    ----------
    method_name : str
        The name of the method to deprecate
    new_method_name : str
        The name of the new method to use instead
    """
    import functools  # pylint: disable=import-outside-toplevel
    import warnings  # pylint: disable=import-outside-toplevel

    def _deprecate(func):

        @functools.wraps(func)
        def _wrapper(*args, **kwargs):
            warnings.warn(
                f"{method_name} is deprecated, use {new_method_name} instead",
                DeprecationWarning,
                stacklevel=2,
            )
            return func(*args, **kwargs)

        return _wrapper

    return _deprecate
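
Illustrative use of the relocated decorator (the function names here are hypothetical):

from tilelang.utils import deprecated

@deprecated("tilelang.old_matmul", "tilelang.language.gemm")
def old_matmul(*args, **kwargs):
    # The wrapper emits a DeprecationWarning naming the replacement, then runs this body.
    ...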
@@ -3,6 +3,7 @@ from __future__ import annotations
 from enum import Enum
 import torch
 from tvm.runtime import ndarray
+from tvm import tir
 from torch.utils.dlpack import to_dlpack
 import numpy as np
@@ -58,6 +59,14 @@ def get_tensor_supply(supply_type: TensorSupplyType):
                 f"TensorType must have a shape, but got {type(param)}, "
                 "likely you are trying to generate a random tensor with a dynamic symbolic shape.")

+        # Check if with dynamic symbolic shape
+        for shape in param.shape:
+            if isinstance(shape, tir.Var):
+                raise ValueError(
+                    f"TensorType must have a static shape, but got {shape}, "
+                    "likely you are trying to generate a random tensor with a dynamic symbolic shape."
+                )
+
         shape = list(map(int, param.shape))
         if supply_type == TensorSupplyType.Auto:
             is_unsigned = param.is_unsigned()
...
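
This guard is what `T.ptr`-style kernels hit when profiled with auto-generated inputs: their buffer shapes are symbolic (`m`, `n`, `k`), so no random tensor can be materialized, and the supplier now fails with a clear message instead of an opaque `int()` conversion error. The check in isolation (a minimal sketch, assuming `shape` is a list of TIR expressions):

from tvm import tir

def assert_static_shape(shape):
    # Mirrors the guard added above: any tir.Var dimension means the shape is
    # dynamic, so no concrete random tensor can be generated for it.
    for dim in shape:
        if isinstance(dim, tir.Var):
            raise ValueError(f"TensorType must have a static shape, but got {dim}")
    return [int(dim) for dim in shape]

# assert_static_shape([tir.Var("m", "int32"), 1024])  # raises ValueError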