Commit 927e50d9, authored by Lei Wang, committed by LeiWang1999
Browse files

[Language] Enhance alias to support blockwise memory load (#261)

* [Enhancement] Introduce caching control and frame management in TileLang

- Added cache control functions (`enable_cache`, `disable_cache`, `is_cache_enabled`) in `env.py` to manage kernel caching behavior.
- Updated `kernel_cache.py` to utilize the cache state, preventing unnecessary kernel compilation when caching is disabled.
- Introduced a new `frame.py` module to manage LetFrame instances, including a stack for variable-value mapping and enhanced frame management.
- Updated imports in various modules to accommodate new caching and frame functionalities, improving overall organization and clarity.

* [Refactor] Clean up and enhance caching and frame management in TileLang

- Added spacing for improved readability in `env.py` and `frame.py`.
- Refactored `LetFrame` class to enhance clarity in buffer region assignment.
- Ensured consistent formatting and organization across caching control and frame management functions.

* [Feature] Add matrix multiplication functionality in TileLang

- Introduced a new test file `test_tilelang_language_alias.py` that implements a matrix multiplication function using TileLang's primitives.
- The `matmul` function defines a kernel for performing tile-level GEMM operations, with support for customizable block sizes and data types.
- Added a `run_matmul` function to compile and execute the kernel, along with a test function to validate the implementation.
- Updated `gemm.py` to allow `tir.Buffer` or `tir.Var` as valid argument types for the `gemm` function, enhancing flexibility in argument handling.

* [Refactor] Improve formatting and readability in test_tilelang_language_alias.py

- Adjusted spacing and alignment in the `matmul` and `run_matmul` functions for better readability.
- Cleaned up unnecessary blank lines and ensured consistent formatting throughout the file.
- Enhanced overall code clarity without altering functionality.
parent 0430cfe7
Subproject commit ed1cb8dd61d81193ab33da03b9fcc9c4a04c3b60
Subproject commit 9ddb7a1753b7af7a0917fb1914563fddb9794879
import tilelang
import tilelang.language as T
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    """Build a TileLang prim_func for a tiled GEMM that uses an aliased shared tile.

    The kernel binds ``X_shared`` to a slice of ``A_shared`` and then passes
    that alias to ``T.copy`` and ``T.gemm`` instead of the buffer itself.

    Args:
        M, N, K: Problem sizes; A is declared (M, K), B is (N, K), C is (M, N).
        block_M, block_N, block_K: Tile sizes used for the shared/fragment
            allocations, the launch grid and the K loop extent.
        dtype: Element type of A, B, C and of the shared tiles.
        accum_dtype: Element type of the accumulator fragment ``C_local``.

    Returns:
        The ``T.prim_func``-decorated ``main`` function.
    """
    # add decorator @tilelang.jit if you want to return a torch function
    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        # Initialize Kernel Context: bx ranges over ceildiv(N, block_N) and
        # by over ceildiv(M, block_M); threads=128 is passed to the launch.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            # Zero the accumulator before the K loop accumulates into it.
            T.clear(C_local)
            # Alias under test: a full-extent slice of A_shared bound to a name.
            X_shared = A_shared[:block_M, :block_K]
            # NOTE(review): num_stages=0 presumably disables software pipelining — confirm.
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                # Copy tile of A into the alias (i.e. into A_shared)
                # This is a sugar syntax for parallelized copy
                T.copy(A[by * block_M, ko * block_K], X_shared)
                # Demonstrate parallelized copy from global to shared for B,
                # here with the destination written as an explicit slice
                T.copy(B[bx * block_N, ko * block_K], B_shared[:block_N, :block_K])
                # Perform a tile-level GEMM on the shared buffers
                # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs
                T.gemm(X_shared, B_shared, C_local, transpose_B=True)
            # Copy result back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])
    return main
def run_matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    """Compile the aliased-tile matmul program for CUDA and invoke it once.

    The program comes from ``matmul`` with the same arguments; it is compiled
    with ``out_idx=[2]`` and ``target="cuda"`` and then ``run_once()`` is
    called on the resulting kernel. No numerical check is performed here.
    """
    kernel = tilelang.compile(
        matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype),
        out_idx=[2],
        target="cuda",
    )
    kernel.run_once()
def test_matmul():
    """Smoke-test the aliased matmul on a 1024^3 problem with 128x128x32 tiles."""
    run_matmul(M=1024, N=1024, K=1024, block_M=128, block_N=128, block_K=32)
# Allow running this test file directly as a script.
if __name__ == "__main__":
    test_matmul()
......@@ -79,6 +79,7 @@ def deprecated(reason):
logger = logging.getLogger(__name__)
from .env import SKIP_LOADING_TILELANG_SO
from .env import enable_cache, disable_cache, is_cache_enabled # noqa: F401
import tvm
import tvm._ffi.base
......
......@@ -13,7 +13,7 @@ import threading
import cloudpickle
import logging
from tilelang.env import TILELANG_CACHE_DIR
from tilelang.env import TILELANG_CACHE_DIR, is_cache_enabled
KERNEL_PATH = "kernel.cu"
WRAPPED_KERNEL_PATH = "warpped_kernel.cu"
......@@ -89,6 +89,17 @@ class KernelCache:
Returns:
JITKernel: The compiled kernel, either freshly compiled or from cache
"""
if not is_cache_enabled():
return JITKernel(
func,
out_idx=out_idx,
execution_backend=execution_backend,
target=target,
target_host=target_host,
verbose=verbose,
pass_configs=pass_configs,
)
key = self._generate_key(func, out_idx, execution_backend, args, target, target_host)
with self._lock: # TODO: use filelock
# Attempt to load from disk
......
......@@ -122,6 +122,27 @@ if os.environ.get("TL_TEMPLATE_PATH", None) is None:
else:
logger.warning(TL_TEMPLATE_NOT_FOUND_MESSAGE)
# Cache control: a single module-level switch read through is_cache_enabled().
_ENABLE_TILELANG_KERNEL_CACHE = True  # Default cache state


def _set_cache_enabled(enabled: bool) -> None:
    """Store *enabled* as the module-wide kernel-cache flag."""
    global _ENABLE_TILELANG_KERNEL_CACHE
    _ENABLE_TILELANG_KERNEL_CACHE = enabled


def enable_cache():
    """Turn kernel caching on for the whole process."""
    _set_cache_enabled(True)


def disable_cache():
    """Turn kernel caching off for the whole process."""
    _set_cache_enabled(False)


def is_cache_enabled() -> bool:
    """Report whether kernel caching is currently switched on."""
    return _ENABLE_TILELANG_KERNEL_CACHE
__all__ = [
"CUTLASS_INCLUDE_DIR",
"TVM_PYTHON_PATH",
......@@ -129,4 +150,7 @@ __all__ = [
"TILELANG_TEMPLATE_PATH",
"CUDA_HOME",
"TILELANG_CACHE_DIR",
"enable_cache",
"disable_cache",
"is_cache_enabled",
]
......@@ -8,6 +8,7 @@ from tvm.script.parser.tir import *
from tilelang.layout import Layout, Fragment # noqa: F401
from .parallel import Parallel # noqa: F401
from .pipeline import Pipelined # noqa: F401
from .frame import has_let_value, get_let_value # noqa: F401
from .kernel import (
Kernel, # noqa: F401
KernelLaunchFrame, # noqa: F401
......
"""The language interface for tl programs."""
from typing import Union, List, Optional
from tvm import tir
from tvm.script import tir as T
import tvm.ir
from tilelang import language as T
from tvm import ir, tir
def region(buffer: tir.BufferLoad, access_type: str, *args: tir.PrimExpr):
......@@ -33,9 +32,11 @@ def copy(
coalesced_width: Optional[int] = None,
):
if isinstance(src, tir.Buffer) and isinstance(dst, tir.Buffer):
tvm.ir.assert_structural_equal(src.shape, dst.shape)
ir.assert_structural_equal(src.shape, dst.shape)
def get_extent(data):
if isinstance(data, tir.Var) and T.has_let_value(data):
data = T.get_let_value(data)
if isinstance(data, tir.Buffer):
return data.shape
elif isinstance(data, tir.BufferRegion):
......@@ -54,6 +55,8 @@ def copy(
raise TypeError("Can't deduce copy extents from args")
def _to_region(data, access_type):
if isinstance(data, tir.Var) and T.has_let_value(data):
data = T.get_let_value(data)
if isinstance(data, tir.Buffer):
return buffer_to_tile_region(data, access_type)
elif isinstance(data, tir.BufferRegion):
......
"""Override TVM's LetFrame to record let-bound values and turn block-wise buffer loads into buffer regions."""
from tvm._ffi import register_object as _register_object
from tvm.tir import Var, PrimExpr, BufferLoad, BufferRegion
from tvm.ir import Range
from tvm import DataType
from tvm.script.ir_builder.tir.frame import TIRFrame
from collections import deque
from typing import Optional
class FrameStack:
    """
    A stack-like wrapper around a deque that provides push, pop, and top methods,
    along with a var-value mapping functionality.

    Any pushed item exposing both ``var`` and ``value`` attributes is treated
    as a binding: ``value`` is recorded under ``var`` while the item is on the
    stack and dropped again when that item is popped. Items lacking either
    attribute are stacked without touching the mapping.

    Note: the mapping keeps one value per variable. If two live frames bind
    the same variable, popping the inner one removes the mapping entirely
    rather than restoring the outer value.
    """

    def __init__(self):
        self._stack = deque()
        self._var_value_map = {}

    @staticmethod
    def _binds_value(item) -> bool:
        """Return True when *item* carries a ``var``/``value`` pair to track."""
        return hasattr(item, "var") and hasattr(item, "value")

    def push(self, item):
        """Pushes an item onto the top of the stack."""
        self._stack.append(item)
        # Store the var-value mapping if it's a LetFrame-like item
        if self._binds_value(item):
            self._var_value_map[item.var] = item.value

    def pop(self):
        """
        Pops and returns the top of the stack.

        Raises:
            IndexError: If the stack is empty.
        """
        if not self._stack:
            raise IndexError(f"{self.__class__.__name__} is empty")
        item = self._stack.pop()
        # Only undo what push() recorded: an item that has ``var`` but no
        # ``value`` never registered a binding, so it must not erase the
        # binding another frame registered for the same variable.
        if self._binds_value(item):
            self._var_value_map.pop(item.var, None)
        return item

    def get_value(self, var):
        """Get the value associated with a variable, or None if there is none."""
        return self._var_value_map.get(var)

    def has_value(self, var):
        """Check if a variable has an associated value."""
        return var in self._var_value_map

    def top(self):
        """
        Returns the item on the top of the stack without removing it.

        Raises:
            IndexError: If the stack is empty.
        """
        if not self._stack:
            raise IndexError(f"{self.__class__.__name__} is empty")
        return self._stack[-1]

    def __len__(self):
        """Returns the number of items in the stack."""
        return len(self._stack)

    def __bool__(self):
        """
        Allows truthy checks on the stack object itself,
        e.g., 'if stack: ...'
        """
        return bool(self._stack)
# Global stack for LetFrame instances. LetFrame pushes itself here on entry
# and pops itself on exit; has_let_value/get_let_value read its var-value map.
_let_frame_stack = FrameStack()
@_register_object("script.ir_builder.tir.LetFrame")
class LetFrame(TIRFrame):
    """LetFrame override that records its binding on the module-level frame stack.

    On entry, a ``BufferLoad`` value in which any index other than the last
    is a vector (lanes > 1) is treated as a block-wise load and replaced by
    the equivalent ``BufferRegion``. The frame then pushes itself onto the
    stack so its ``var -> value`` pair can be looked up while it is active.
    """

    @staticmethod
    def _index_to_range(index) -> Range:
        """Convert one BufferLoad index into the Range it covers.

        Vector indices contribute ``lanes`` elements starting at ``base``;
        scalar indices contribute a single element.
        """
        if DataType(index.dtype).lanes > 1:
            # Range(begin, end) would read ``lanes`` as an end point, which is
            # only right when base is zero; from_min_extent is explicit.
            # NOTE(review): the stride of the vector index is ignored, so this
            # assumes unit-stride slices — confirm against callers.
            return Range.from_min_extent(index.base, index.lanes)
        return Range.from_min_extent(index, 1)

    def __enter__(self) -> Var:
        super().__enter__()
        if isinstance(self.value, BufferLoad):
            indices = self.value.indices
            # Only the leading indices decide whether this is a block load;
            # the last index is deliberately left out of the check.
            is_block_load = any(DataType(index.dtype).lanes > 1 for index in indices[:-1])
            if is_block_load:
                self.value = BufferRegion(
                    self.value.buffer,
                    [self._index_to_range(index) for index in indices],
                )
        _let_frame_stack.push(self)
        return self.var

    def __exit__(self, ptype, value, trace):
        # Guard the empty case so a stray exit cannot raise IndexError from
        # top() and mask the exception being propagated.
        if _let_frame_stack and _let_frame_stack.top() is self:
            _let_frame_stack.pop()
        super().__exit__(ptype, value, trace)

    @classmethod
    def Current(cls) -> "LetFrame":
        """
        Returns the topmost (current) LetFrame from the stack if it exists,
        or raises IndexError if the stack is empty.
        """
        return _let_frame_stack.top()

    @staticmethod
    def get_value(var: Var):
        """
        Get the value associated with a variable.
        Returns None if the variable is not found.
        """
        return _let_frame_stack.get_value(var)

    @staticmethod
    def has_value(var: Var) -> bool:
        """
        Check if a variable has an associated value.
        """
        return _let_frame_stack.has_value(var)
def has_let_value(var: Var) -> bool:
    """Tell whether *var* is currently bound by a frame on the let frame stack."""
    is_bound = _let_frame_stack.has_value(var)
    return is_bound
def get_let_value(var: Var) -> Optional[PrimExpr]:
    """Look up the value bound to *var* on the let frame stack.

    Returns:
        The recorded value, or None when *var* has no binding.
    """
    bound_value = _let_frame_stack.get_value(var)
    return bound_value
"""The language interface for tl programs."""
from tilelang.primitives.gemm.base import GemmWarpPolicy
import tilelang.language as T
from tvm import tir
from typing import Union
def gemm(
A: tir.Buffer,
B: tir.Buffer,
C: tir.Buffer,
A: Union[tir.Buffer, tir.Var],
B: Union[tir.Buffer, tir.Var],
C: Union[tir.Buffer, tir.Var],
transpose_A: bool = False,
transpose_B: bool = False,
policy: GemmWarpPolicy = GemmWarpPolicy.Square,
......@@ -20,6 +22,16 @@ def gemm(
The number of k dimension that is packed into a single warp.
please ref to mfma macro generator for the detail information.
"""
def legalize_arguments(arg: Union[tir.Buffer, tir.Var]):
    """Resolve a let-bound alias to the buffer behind it.

    A ``tir.Var`` that has a recorded let value is replaced by the
    ``buffer`` attribute of that value; anything else is returned unchanged.
    """
    if not isinstance(arg, tir.Var):
        return arg
    if not T.has_let_value(arg):
        return arg
    return T.get_let_value(arg).buffer
A = legalize_arguments(A)
B = legalize_arguments(B)
C = legalize_arguments(C)
M = C.shape[0]
N = C.shape[1]
K = A.shape[0] if transpose_A else A.shape[1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment