Commit 927e50d9 authored by Lei Wang, committed by LeiWang1999
Browse files

[Language] Enhance alias to support blockwise memory load (#261)

* [Enhancement] Introduce caching control and frame management in TileLang

- Added cache control functions (`enable_cache`, `disable_cache`, `is_cache_enabled`) in `env.py` to manage kernel caching behavior.
- Updated `kernel_cache.py` to utilize the cache state, preventing unnecessary kernel compilation when caching is disabled.
- Introduced a new `frame.py` module to manage LetFrame instances, including a stack for variable-value mapping and enhanced frame management.
- Updated imports in various modules to accommodate new caching and frame functionalities, improving overall organization and clarity.

* [Refactor] Clean up and enhance caching and frame management in TileLang

- Added spacing for improved readability in `env.py` and `frame.py`.
- Refactored `LetFrame` class to enhance clarity in buffer region assignment.
- Ensured consistent formatting and organization across caching control and frame management functions.

* [Feature] Add matrix multiplication functionality in TileLang

- Introduced a new test file `test_tilelang_language_alias.py` that implements a matrix multiplication function using TileLang's primitives.
- The `matmul` function defines a kernel for performing tile-level GEMM operations, with support for customizable block sizes and data types.
- Added a `run_matmul` function to compile and execute the kernel, along with a test function to validate the implementation.
- Updated `gemm.py` to allow `tir.Buffer` or `tir.Var` as valid argument types for the `gemm` function, enhancing flexibility in argument handling.

* [Refactor] Improve formatting and readability in test_tilelang_language_alias.py

- Adjusted spacing and alignment in the `matmul` and `run_matmul` functions for better readability.
- Cleaned up unnecessary blank lines and ensured consistent formatting throughout the file.
- Enhanced overall code clarity without altering functionality.
parent 0430cfe7
Subproject commit ed1cb8dd61d81193ab33da03b9fcc9c4a04c3b60 Subproject commit 9ddb7a1753b7af7a0917fb1914563fddb9794879
import tilelang
import tilelang.language as T
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    """Build a TileLang prim_func for a tiled matrix multiplication.

    The returned function takes ``A`` of shape ``(M, K)``, ``B`` of shape
    ``(N, K)`` and ``C`` of shape ``(M, N)``; the GEMM is issued with
    ``transpose_B=True``. ``A`` is staged through ``X_shared``, an alias
    created by slicing ``A_shared``, so the kernel exercises alias support in
    ``T.copy`` and ``T.gemm``.

    Args:
        M, N, K: Problem sizes.
        block_M, block_N, block_K: Tile sizes used for the shared/fragment
            allocations and for the launch grid.
        dtype: Element type of ``A``, ``B``, ``C`` and the shared tiles.
        accum_dtype: Element type of the accumulator fragment.

    Returns:
        The ``T.prim_func``-decorated ``main`` function.
    """

    # add decorator @tilelang.jit if you want to return a torch function
    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        # Initialize Kernel Context: one block per (block_N, block_M) output tile.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Zero the accumulator before the K loop.
            T.clear(C_local)

            # Alias covering the whole A_shared tile; this is the block-wise
            # load that the test is meant to exercise.
            X_shared = A_shared[:block_M, :block_K]
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                # Copy tile of A
                # This is a sugar syntax for parallelized copy
                T.copy(A[by * block_M, ko * block_K], X_shared)

                # Demonstrate parallelized copy from global to shared for B
                T.copy(B[bx * block_N, ko * block_K], B_shared[:block_N, :block_K])

                # Perform a tile-level GEMM on the shared buffers
                # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs
                T.gemm(X_shared, B_shared, C_local, transpose_B=True)

            # Copy result back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
def run_matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    """Build the matmul program, compile it for CUDA and execute it once.

    The program is compiled with ``out_idx=[2]`` and ``target="cuda"``; the
    resulting kernel's ``run_once`` is then invoked. Nothing is returned.
    """
    kernel = tilelang.compile(
        matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype),
        out_idx=[2],
        target="cuda",
    )
    kernel.run_once()
def test_matmul():
    """Run the aliasing matmul on a 1024^3 problem with 128x128x32 tiles."""
    run_matmul(M=1024, N=1024, K=1024, block_M=128, block_N=128, block_K=32)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    test_matmul()
...@@ -79,6 +79,7 @@ def deprecated(reason): ...@@ -79,6 +79,7 @@ def deprecated(reason):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from .env import SKIP_LOADING_TILELANG_SO from .env import SKIP_LOADING_TILELANG_SO
from .env import enable_cache, disable_cache, is_cache_enabled # noqa: F401
import tvm import tvm
import tvm._ffi.base import tvm._ffi.base
......
...@@ -13,7 +13,7 @@ import threading ...@@ -13,7 +13,7 @@ import threading
import cloudpickle import cloudpickle
import logging import logging
from tilelang.env import TILELANG_CACHE_DIR from tilelang.env import TILELANG_CACHE_DIR, is_cache_enabled
KERNEL_PATH = "kernel.cu" KERNEL_PATH = "kernel.cu"
WRAPPED_KERNEL_PATH = "warpped_kernel.cu" WRAPPED_KERNEL_PATH = "warpped_kernel.cu"
...@@ -89,6 +89,17 @@ class KernelCache: ...@@ -89,6 +89,17 @@ class KernelCache:
Returns: Returns:
JITKernel: The compiled kernel, either freshly compiled or from cache JITKernel: The compiled kernel, either freshly compiled or from cache
""" """
if not is_cache_enabled():
return JITKernel(
func,
out_idx=out_idx,
execution_backend=execution_backend,
target=target,
target_host=target_host,
verbose=verbose,
pass_configs=pass_configs,
)
key = self._generate_key(func, out_idx, execution_backend, args, target, target_host) key = self._generate_key(func, out_idx, execution_backend, args, target, target_host)
with self._lock: # TODO: use filelock with self._lock: # TODO: use filelock
# Attempt to load from disk # Attempt to load from disk
......
...@@ -122,6 +122,27 @@ if os.environ.get("TL_TEMPLATE_PATH", None) is None: ...@@ -122,6 +122,27 @@ if os.environ.get("TL_TEMPLATE_PATH", None) is None:
else: else:
logger.warning(TL_TEMPLATE_NOT_FOUND_MESSAGE) logger.warning(TL_TEMPLATE_NOT_FOUND_MESSAGE)
# Cache control
# Module-wide switch consulted through is_cache_enabled(); caching is on by default.
_ENABLE_TILELANG_KERNEL_CACHE = True  # Default cache state


def _set_cache_enabled(flag: bool) -> None:
    """Store ``flag`` as the global kernel-cache state."""
    global _ENABLE_TILELANG_KERNEL_CACHE
    _ENABLE_TILELANG_KERNEL_CACHE = flag


def enable_cache():
    """Turn kernel caching on for the whole process."""
    _set_cache_enabled(True)


def disable_cache():
    """Turn kernel caching off for the whole process."""
    _set_cache_enabled(False)


def is_cache_enabled() -> bool:
    """Report whether kernel caching is currently switched on."""
    return _ENABLE_TILELANG_KERNEL_CACHE
__all__ = [ __all__ = [
"CUTLASS_INCLUDE_DIR", "CUTLASS_INCLUDE_DIR",
"TVM_PYTHON_PATH", "TVM_PYTHON_PATH",
...@@ -129,4 +150,7 @@ __all__ = [ ...@@ -129,4 +150,7 @@ __all__ = [
"TILELANG_TEMPLATE_PATH", "TILELANG_TEMPLATE_PATH",
"CUDA_HOME", "CUDA_HOME",
"TILELANG_CACHE_DIR", "TILELANG_CACHE_DIR",
"enable_cache",
"disable_cache",
"is_cache_enabled",
] ]
...@@ -8,6 +8,7 @@ from tvm.script.parser.tir import * ...@@ -8,6 +8,7 @@ from tvm.script.parser.tir import *
from tilelang.layout import Layout, Fragment # noqa: F401 from tilelang.layout import Layout, Fragment # noqa: F401
from .parallel import Parallel # noqa: F401 from .parallel import Parallel # noqa: F401
from .pipeline import Pipelined # noqa: F401 from .pipeline import Pipelined # noqa: F401
from .frame import has_let_value, get_let_value # noqa: F401
from .kernel import ( from .kernel import (
Kernel, # noqa: F401 Kernel, # noqa: F401
KernelLaunchFrame, # noqa: F401 KernelLaunchFrame, # noqa: F401
......
"""The language interface for tl programs.""" """The language interface for tl programs."""
from typing import Union, List, Optional from typing import Union, List, Optional
from tvm import tir from tilelang import language as T
from tvm.script import tir as T from tvm import ir, tir
import tvm.ir
def region(buffer: tir.BufferLoad, access_type: str, *args: tir.PrimExpr): def region(buffer: tir.BufferLoad, access_type: str, *args: tir.PrimExpr):
...@@ -33,9 +32,11 @@ def copy( ...@@ -33,9 +32,11 @@ def copy(
coalesced_width: Optional[int] = None, coalesced_width: Optional[int] = None,
): ):
if isinstance(src, tir.Buffer) and isinstance(dst, tir.Buffer): if isinstance(src, tir.Buffer) and isinstance(dst, tir.Buffer):
tvm.ir.assert_structural_equal(src.shape, dst.shape) ir.assert_structural_equal(src.shape, dst.shape)
def get_extent(data): def get_extent(data):
if isinstance(data, tir.Var) and T.has_let_value(data):
data = T.get_let_value(data)
if isinstance(data, tir.Buffer): if isinstance(data, tir.Buffer):
return data.shape return data.shape
elif isinstance(data, tir.BufferRegion): elif isinstance(data, tir.BufferRegion):
...@@ -54,6 +55,8 @@ def copy( ...@@ -54,6 +55,8 @@ def copy(
raise TypeError("Can't deduce copy extents from args") raise TypeError("Can't deduce copy extents from args")
def _to_region(data, access_type): def _to_region(data, access_type):
if isinstance(data, tir.Var) and T.has_let_value(data):
data = T.get_let_value(data)
if isinstance(data, tir.Buffer): if isinstance(data, tir.Buffer):
return buffer_to_tile_region(data, access_type) return buffer_to_tile_region(data, access_type)
elif isinstance(data, tir.BufferRegion): elif isinstance(data, tir.BufferRegion):
......
"""Override the TVMScript LetFrame to track let-bound values and to turn block-wise buffer loads into buffer regions."""
from tvm._ffi import register_object as _register_object
from tvm.tir import Var, PrimExpr, BufferLoad, BufferRegion
from tvm.ir import Range
from tvm import DataType
from tvm.script.ir_builder.tir.frame import TIRFrame
from collections import deque
from typing import Optional
class FrameStack:
    """
    A stack-like wrapper around a deque that provides push, pop, and top methods,
    along with a var-value mapping functionality.

    Any pushed item that exposes both ``var`` and ``value`` attributes is
    registered in the mapping for as long as it stays on the stack. If a newer
    item binds a var that is already bound, the older binding is restored when
    the newer item is popped.
    """

    def __init__(self):
        self._stack = deque()
        self._var_value_map = {}
        # Parallel to ``_stack``: for each item either None (item was not
        # registered in the mapping) or ``(var, was_bound, previous_value)``
        # describing the binding that the item replaced.
        self._shadowed = deque()

    def push(self, item):
        """Pushes an item onto the top of the stack."""
        self._stack.append(item)
        # Store the var-value mapping if it's a LetFrame
        if hasattr(item, 'var') and hasattr(item, 'value'):
            var = item.var
            was_bound = var in self._var_value_map
            previous = self._var_value_map[var] if was_bound else None
            self._shadowed.append((var, was_bound, previous))
            self._var_value_map[var] = item.value
        else:
            self._shadowed.append(None)

    def pop(self):
        """
        Pops and returns the top of the stack.

        Raises:
            IndexError: If the stack is empty.
        """
        if not self._stack:
            raise IndexError(f"{self.__class__.__name__} is empty")
        item = self._stack.pop()
        record = self._shadowed.pop()
        # Undo only what this item's push did: items that were never
        # registered must not disturb bindings owned by other frames.
        if record is not None:
            var, was_bound, previous = record
            if was_bound:
                self._var_value_map[var] = previous
            else:
                self._var_value_map.pop(var, None)
        return item

    def get_value(self, var):
        """Get the value associated with a variable, or None if it is not bound."""
        return self._var_value_map.get(var)

    def has_value(self, var):
        """Check if a variable has an associated value."""
        return var in self._var_value_map

    def top(self):
        """
        Returns the item on the top of the stack without removing it.

        Raises:
            IndexError: If the stack is empty.
        """
        if self._stack:
            return self._stack[-1]
        raise IndexError(f"{self.__class__.__name__} is empty")

    def __len__(self):
        """Returns the number of items in the stack."""
        return len(self._stack)

    def __bool__(self):
        """
        Allows truthy checks on the stack object itself,
        e.g., 'if stack: ...'
        """
        return bool(self._stack)
# Global stack for LetFrame instances
# Shared by LetFrame.__enter__/__exit__ and by has_let_value/get_let_value.
_let_frame_stack = FrameStack()
def _index_to_range(index) -> Range:
    """Convert one buffer index of a block load into a ``Range``.

    A multi-lane index contributes ``[base, base + lanes)``; a scalar index
    contributes a range of extent 1 starting at the index itself.
    """
    if DataType(index.dtype).lanes > 1:
        # ``Range(begin, end)`` would give an extent of ``lanes - base``, which
        # is only right when base is 0, so build from (min, extent) instead.
        # NOTE(review): like the original code this assumes the vector index
        # is a unit-stride Ramp (has ``base``/``lanes``) - confirm for strided slices.
        return Range.from_min_extent(index.base, index.lanes)
    return Range.from_min_extent(index, 1)


@_register_object("script.ir_builder.tir.LetFrame")
class LetFrame(TIRFrame):
    """LetFrame registered under ``script.ir_builder.tir.LetFrame``.

    On entry, a ``BufferLoad`` value with a multi-lane index in any dimension
    other than the last is treated as a block-wise load and replaced by the
    equivalent ``BufferRegion``. The frame is then pushed on the module-level
    let-frame stack so the value bound to ``self.var`` can be looked up.
    """

    def __enter__(self) -> Var:
        super().__enter__()
        if isinstance(self.value, BufferLoad):
            indices = self.value.indices
            # A multi-lane index only in the last dimension is left as a
            # BufferLoad; anything multi-lane before that marks a block load.
            is_block_load = any(DataType(index.dtype).lanes > 1 for index in indices[:-1])
            if is_block_load:
                self.value = BufferRegion(self.value.buffer,
                                          [_index_to_range(index) for index in indices])

        _let_frame_stack.push(self)
        return self.var

    def __exit__(self, ptype, value, trace):
        # Guard the empty case so a bookkeeping problem cannot raise before the
        # underlying frame is closed by super().__exit__.
        if _let_frame_stack and _let_frame_stack.top() is self:
            _let_frame_stack.pop()
        super().__exit__(ptype, value, trace)

    @classmethod
    def Current(cls) -> "LetFrame":
        """
        Returns the topmost (current) LetFrame from the stack if it exists,
        or raises IndexError if the stack is empty.
        """
        return _let_frame_stack.top()

    @staticmethod
    def get_value(var: Var):
        """
        Get the value associated with a variable.
        Returns None if the variable is not found.
        """
        return _let_frame_stack.get_value(var)

    @staticmethod
    def has_value(var: Var) -> bool:
        """
        Check if a variable has an associated value.
        """
        return _let_frame_stack.has_value(var)
def has_let_value(var: Var) -> bool:
    """Tell whether ``var`` is currently bound by a frame on the let frame stack."""
    is_bound = _let_frame_stack.has_value(var)
    return is_bound
def get_let_value(var: Var) -> Optional[PrimExpr]:
    """Look up the value bound to ``var`` on the let frame stack.

    Yields None when no frame on the stack binds ``var``. The stored value is
    whatever the binding frame holds in ``value``; for block-wise loads
    ``LetFrame.__enter__`` stores a ``BufferRegion`` there.
    """
    bound_value = _let_frame_stack.get_value(var)
    return bound_value
"""The language interface for tl programs.""" """The language interface for tl programs."""
from tilelang.primitives.gemm.base import GemmWarpPolicy from tilelang.primitives.gemm.base import GemmWarpPolicy
import tilelang.language as T
from tvm import tir from tvm import tir
from typing import Union
def gemm( def gemm(
A: tir.Buffer, A: Union[tir.Buffer, tir.Var],
B: tir.Buffer, B: Union[tir.Buffer, tir.Var],
C: tir.Buffer, C: Union[tir.Buffer, tir.Var],
transpose_A: bool = False, transpose_A: bool = False,
transpose_B: bool = False, transpose_B: bool = False,
policy: GemmWarpPolicy = GemmWarpPolicy.Square, policy: GemmWarpPolicy = GemmWarpPolicy.Square,
...@@ -20,6 +22,16 @@ def gemm( ...@@ -20,6 +22,16 @@ def gemm(
The number of k dimension that is packed into a single warp. The number of k dimension that is packed into a single warp.
please ref to mfma macro generator for the detail information. please ref to mfma macro generator for the detail information.
""" """
def legalize_arguments(arg: Union[tir.Buffer, tir.Var]):
if isinstance(arg, tir.Var) and T.has_let_value(arg):
return T.get_let_value(arg).buffer
else:
return arg
A = legalize_arguments(A)
B = legalize_arguments(B)
C = legalize_arguments(C)
M = C.shape[0] M = C.shape[0]
N = C.shape[1] N = C.shape[1]
K = A.shape[0] if transpose_A else A.shape[1] K = A.shape[0] if transpose_A else A.shape[1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment