Unverified Commit 667632cc authored by guchaoyang's avatar guchaoyang Committed by GitHub
Browse files

Merge branch 'main' into dcu

parents d6dd2ddf a874e4e8
......@@ -6,7 +6,8 @@ import tilelang as tl
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,)
TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
from tilelang.autotuner import autotune
import itertools
......@@ -48,22 +49,22 @@ def tl_matmul(
enable_rasteration=False,
):
assert in_dtype in [
"float16",
"int8",
T.float16,
T.int8,
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
T.float16,
T.float32,
T.int32,
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32":
if out_dtype == T.int32:
micro_size_k = 32
# This is a debug config
# chunk = 32 if in_dtype == "float16" else 64
# chunk = 32 if in_dtype == T.float16 else 64
shared_scope = "shared.dyn"
block_M = block_row_warps * warp_row_tiles
......@@ -103,12 +104,11 @@ def tl_matmul(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
......@@ -116,10 +116,12 @@ def tl_matmul(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
}
)
# Improve L2 Cache
T.use_swizzle(panel_size=10, enable=enable_rasteration)
......@@ -127,7 +129,6 @@ def tl_matmul(
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
......@@ -137,7 +138,6 @@ def tl_matmul(
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(A_local, A_shared, ki)
......@@ -194,9 +194,9 @@ def get_configs(args, kwargs):
M=M,
N=N,
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
in_dtype=T.float16,
out_dtype=T.float16,
accum_dtype=T.float16,
).with_arch(arch)
func = carve_template.equivalent_function()
......@@ -223,7 +223,6 @@ def get_configs(args, kwargs):
for config in configs:
print(config)
else:
iter_params = dict(
block_row_warps=[1, 2, 4],
block_col_warps=[1, 2, 4],
......@@ -233,9 +232,7 @@ def get_configs(args, kwargs):
stage=[0, 2],
enable_rasteration=[True, False],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
return configs
......@@ -247,14 +244,16 @@ def get_configs(args, kwargs):
ref_prog=ref_program,
skip_check=True,
)
@tl.jit(out_idx=[2],)
@tl.jit(
out_idx=[2],
)
def matmul(
M,
N,
K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
in_dtype=T.float16,
out_dtype=T.float16,
accum_dtype=T.float16,
with_roller=False,
block_row_warps=None,
block_col_warps=None,
......@@ -291,19 +290,14 @@ if __name__ == "__main__":
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument(
"--with_roller",
type=bool,
default=False,
help="Whether to use roller to deduce search spaces")
parser.add_argument(
"--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type")
parser.add_argument("--with_roller", type=bool, default=False, help="Whether to use roller to deduce search spaces")
parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "int8"], help="Input data type")
args = parser.parse_args()
M, N, K = args.m, args.n, args.k
in_dtype = args.dtype
out_dtype = "float32" if in_dtype == "int8" else "float16"
accum_dtype = "float32" if in_dtype == "int8" else "float16"
in_dtype = T.dtype(args.dtype)
out_dtype = T.float32 if in_dtype == T.int8 else T.float16
accum_dtype = T.float32 if in_dtype == T.int8 else T.float16
with_roller = args.with_roller
with_roller = True
# Compute total floating-point operations
......
......@@ -9,7 +9,7 @@ import tilelang.language as T
from tilelang.autotuner import autotune
from tilelang import jit
from tilelang.contrib import nvcc
from tilelang.layout import make_metadata_layout
from tilelang.layout import make_cutlass_metadata_layout
# Configure logger
logger = logging.getLogger(__name__)
......@@ -70,7 +70,8 @@ def get_configs(M, N, K):
thread_num,
policy,
enable_rasterization,
))
)
)
configs = [
{
......@@ -81,12 +82,13 @@ def get_configs(M, N, K):
"thread_num": c[4],
"policy": c[5],
"enable_rasterization": c[6], # keep param name for backward-compat
} for c in _configs
}
for c in _configs
]
return configs
def matmul_sp(M, N, K, accum_dtype):
def matmul_sp(M, N, K, in_dtype, accum_dtype):
"""
Create an autotuned matrix multiplication kernel for matrices of shape:
- A: (M, K)
......@@ -126,7 +128,9 @@ def matmul_sp(M, N, K, accum_dtype):
warmup=3,
rep=20,
)
@jit(out_idx=[2],)
@jit(
out_idx=[2],
)
def kernel(
block_M=None,
block_N=None,
......@@ -161,15 +165,14 @@ def matmul_sp(M, N, K, accum_dtype):
"""
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "float16"
e_factor, e_dtype = ARCH_INFO[arch]
@T.prim_func
def main(
A_sparse: T.Tensor((M, K // 2), dtype),
E: T.Tensor((M, K // e_factor), e_dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), accum_dtype),
A_sparse: T.Tensor((M, K // 2), in_dtype),
E: T.Tensor((M, K // e_factor), e_dtype),
B: T.Tensor((K, N), in_dtype),
C: T.Tensor((M, N), accum_dtype),
):
"""
The compiled TVM function for block-level matrix multiplication.
......@@ -183,13 +186,11 @@ def matmul_sp(M, N, K, accum_dtype):
"""
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K // 2), dtype)
A_shared = T.alloc_shared((block_M, block_K // 2), in_dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
B_shared = T.alloc_shared((block_K, block_N), dtype)
B_shared = T.alloc_shared((block_K, block_N), in_dtype)
# Allocate shared memory for E sub-block of shape (block_M, block_K // E_factor)
E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
# Allocate a local fragment for intermediate accumulation
......@@ -202,14 +203,12 @@ def matmul_sp(M, N, K, accum_dtype):
T.disable_warp_group_reg_alloc()
T.use_swizzle(panel_size=10, enable=enable_rasterization)
T.annotate_layout({
E:
make_metadata_layout(
E, mma_dtype="float16", backend="cutlass", block_k=block_K),
E_shared:
make_metadata_layout(
E_shared, mma_dtype="float16", backend="cutlass", block_k=block_K),
})
T.annotate_layout(
{
E: make_cutlass_metadata_layout(E, mma_dtype=in_dtype, block_k=block_K),
E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=in_dtype, block_k=block_K),
}
)
# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
# Load a sub-block of A from global memory into A_shared
......@@ -220,7 +219,7 @@ def matmul_sp(M, N, K, accum_dtype):
T.copy(B[k * block_K, bx * block_N], B_shared)
# Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared
T.gemm_sp(
T.gemm_sp_v2(
A_shared,
E_shared,
B_shared,
......@@ -244,18 +243,13 @@ if __name__ == "__main__":
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument("--disable_cache", action="store_true")
parser.add_argument(
"--accum_dtype",
type=str,
default="float",
choices=["float", "float16"],
help="Accumulation datatype")
parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype")
parser.add_argument(
"--bench_torch_sparse",
type=str,
choices=['cutlass', 'cusparselt'],
choices=["cutlass", "cusparselt"],
default=None,
help="Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported"
help="Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported",
)
args = parser.parse_args()
......@@ -268,7 +262,7 @@ if __name__ == "__main__":
total_flops = 2 * M * N * K
# matmul(...) returns (best_latency, best_config, ref_latency)
best_result = matmul_sp(M, N, K, args.accum_dtype)
best_result = matmul_sp(M, N, K, T.float16, args.accum_dtype)
best_latency = best_result.latency
best_config = best_result.config
A = torch.randn(M, K, dtype=torch.float16, device="cuda")
......@@ -277,7 +271,8 @@ if __name__ == "__main__":
if args.bench_torch_sparse is not None:
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
if args.bench_torch_sparse == 'cutlass':
if args.bench_torch_sparse == "cutlass":
SparseSemiStructuredTensor._FORCE_CUTLASS = True
A_sp = to_sparse_semi_structured(A, transposed=False)
torch_sparse_latency = do_bench(lambda: A_sp @ B)
......@@ -288,8 +283,6 @@ if __name__ == "__main__":
print(f"Best config: {best_config}")
if args.bench_torch_sparse is not None:
print(
f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}"
)
print(f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}")
print(f"Reference Dense TFlops: {total_flops / ref_latency * 1e-9:.3f}")
import argparse
import itertools
import torch
import logging
import tilelang
import tilelang.language as T
......@@ -62,9 +63,9 @@ def get_configs(args, kwargs):
M=M,
N=N,
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float",
in_dtype=T.float16,
out_dtype=T.float16,
accum_dtype=T.float32,
).with_arch(arch)
func = carve_template.equivalent_function()
......@@ -99,12 +100,11 @@ def get_configs(args, kwargs):
block_K=[64, 128],
num_stages=[0, 1, 2, 3],
thread_num=[128, 256],
k_pack=[1, 2],
policy=[T.GemmWarpPolicy.Square],
enable_rasteration=[True, False],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
return configs
......@@ -114,7 +114,9 @@ def get_configs(args, kwargs):
warmup=3,
rep=20,
)
@jit(out_idx=[2],)
@jit(
out_idx=[2],
)
def matmul(
M,
N,
......@@ -125,6 +127,7 @@ def matmul(
block_K=None,
num_stages=None,
thread_num=None,
k_pack=None,
policy=None,
enable_rasteration=None,
):
......@@ -156,14 +159,14 @@ def matmul(
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "float8_e4m3"
accum_dtype = "float"
dtype = T.float8_e4m3fnuz if torch.version.hip is not None else T.float8_e4m3fn
accum_dtype = T.float32
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
"""
The compiled TVM function for block-level matrix multiplication.
......@@ -178,7 +181,6 @@ def matmul(
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
......@@ -210,6 +212,7 @@ def matmul(
C_local,
transpose_B=True,
policy=policy,
k_pack=k_pack,
)
# Write back the results from C_local to the global memory C
T.copy(C_local, C_shared)
......
......@@ -3,12 +3,15 @@
set(TVM_BUILD_FROM_SOURCE TRUE)
set(TVM_SOURCE ${CMAKE_SOURCE_DIR}/3rdparty/tvm)
if(DEFINED $ENV{TVM_ROOT})
if(DEFINED ENV{TVM_ROOT})
if(EXISTS $ENV{TVM_ROOT}/cmake/config.cmake)
set(TVM_SOURCE $ENV{TVM_ROOT})
message(STATUS "Using TVM_ROOT from environment variable: ${TVM_SOURCE}")
endif()
endif()
message(STATUS "Using TVM source: ${TVM_SOURCE}")
set(TVM_INCLUDES
${TVM_SOURCE}/include
${TVM_SOURCE}/src
......
# Locate the Z3 solver bundled with the `z3-solver` Python package and expose
# it as the imported target `z3::libz3` (plus Z3_*_INCLUDE_DIRS variables).
if(Z3_FOUND)
  return()
endif()

find_package(Python3 COMPONENTS Interpreter REQUIRED)

# Ask the Python interpreter where the z3 package lives; the z3-solver wheel
# ships the native headers and shared library inside the package directory.
execute_process(
  COMMAND "${Python3_EXECUTABLE}" -c "import z3; print(z3.__path__[0])"
  OUTPUT_VARIABLE Z3_PATH
  OUTPUT_STRIP_TRAILING_WHITESPACE
  RESULT_VARIABLE Z3_PYTHON_RESULT
)
if(NOT Z3_PYTHON_RESULT EQUAL 0 OR Z3_PATH STREQUAL "")
  message(FATAL_ERROR "Failed to locate z3 Python package. Ensure z3-solver>=4.13.0 is installed.")
endif()
message(STATUS "Find Z3 in path: ${Z3_PATH}")

find_path(Z3_INCLUDE_DIR NO_DEFAULT_PATH NAMES z3++.h PATHS ${Z3_PATH}/include)
find_library(Z3_LIBRARY NO_DEFAULT_PATH NAMES z3 libz3 PATHS ${Z3_PATH}/bin ${Z3_PATH}/lib ${Z3_PATH}/lib64)

# Fail fast BEFORE creating the imported target: set_target_properties with an
# empty IMPORTED_LOCATION / INTERFACE_INCLUDE_DIRECTORIES would otherwise emit
# a confusing configure-time error instead of this explicit message.
if(NOT Z3_INCLUDE_DIR OR NOT Z3_LIBRARY)
  message(FATAL_ERROR "Could not find Z3 library or include directory")
endif()
message(STATUS "Found Z3 include dir: ${Z3_INCLUDE_DIR}")
message(STATUS "Found Z3 library: ${Z3_LIBRARY}")

add_library(z3::libz3 SHARED IMPORTED GLOBAL)
set_target_properties(z3::libz3
  PROPERTIES
  IMPORTED_LOCATION ${Z3_LIBRARY}
  INTERFACE_INCLUDE_DIRECTORIES ${Z3_INCLUDE_DIR}
)

set(Z3_CXX_INCLUDE_DIRS ${Z3_INCLUDE_DIR})
set(Z3_C_INCLUDE_DIRS ${Z3_INCLUDE_DIR})
set(Z3_FOUND TRUE)
......@@ -9,23 +9,43 @@ ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential git wget \
libgtest-dev libprotobuf-dev protobuf-compiler libgflags-dev libsqlite3-dev llvm-dev \
rocm-dev rocm-libs hip-dev hipblas-dev rocblas-dev \
&& apt-get clean autoclean && rm -rf /var/lib/apt/lists/{cache,log} /tmp/* /var/tmp/*
ENV PATH="/opt/conda/bin:${PATH}"
ENV LIBGL_ALWAYS_INDIRECT=1
ENV USE_ROCM=1
ENV USE_CUDA=0
ENV ROCM_HOME=/opt/rocm
ENV HIP_PLATFORM=amd
ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942"
RUN conda run -n py_3.10 conda install pip cmake -y && \
conda run -n py_3.10 conda install -c conda-forge libstdcxx-ng=12 -y && \
conda clean --all
RUN apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
RUN apt-get update && apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev && \
apt-get clean autoclean && rm -rf /var/lib/apt/lists/{cache,log} /tmp/* /var/tmp/*
RUN git clone https://github.com/tile-ai/tilelang.git --recursive -b main tilelang && \
conda run -n py_3.10 bash -c "cd tilelang && USE_ROCM=1 pip install -e . -v"
# Copy local tilelang directory instead of cloning from git
# Build from tilelang root: docker build -f docker/Dockerfile.rocm -t mi300:latest .
COPY . /root/tilelang
RUN conda init bash
RUN mv /opt/conda/envs/py_3.10/compiler_compat /opt/conda/envs/py_3.10/compiler_compat.bak || true && \
conda run -n py_3.10 bash -c "export USE_ROCM=1 USE_CUDA=0 && pip install 'numpy<2.0' --force-reinstall" && \
conda run -n py_3.10 bash -c "cd /root/tilelang && \
# Backup and modify pyproject.toml to remove torch from dependencies \
cp pyproject.toml pyproject.toml.bak && \
sed -i '/^[[:space:]]*\"torch/d' pyproject.toml && \
# Install tilelang with all dependencies except torch \
USE_ROCM=1 USE_CUDA=0 pip install -e . -v && \
# Restore original pyproject.toml \
mv pyproject.toml.bak pyproject.toml"
RUN conda init bash && \
echo "conda activate py_3.10" >> /root/.bashrc
SHELL ["/bin/bash", "-l", "-c"]
CMD ["bash", "-c", "source ~/.bashrc && conda activate py_3.10 && exec bash"]
\ No newline at end of file
ENTRYPOINT ["/bin/bash", "--login", "-i"]
/* Reduce the displayed size of the sidebar logo in Furo.
   Caps the rendered height; `width: auto` keeps the aspect ratio intact. */
.sidebar-logo {
    max-height: 125px;
    width: auto;
}
/* Optional: keep container from growing too tall due to spacing.
   `line-height: 0` collapses the inline whitespace around the logo <img>. */
.sidebar-logo-container {
    line-height: 0;
}
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
# Tensor Checks (Host-Side Auto-Validation)
This page explains the host-side checks that TileLang automatically inserts into the generated host stub for kernels. When you pass `torch.Tensor` or any DLPack-compatible object to a TileLang kernel, the host stub validates argument count, pointer kinds, dtype, shape, strides, device, and more — so you don’t need to handwrite Python checks. This keeps the ABI stable and significantly reduces Python overhead compared to doing equivalent checks in Python or via pybind.
## Why Host-Side Checks
- ABI stability: the entry is based on TVM FFI + DLPack, consistently accepting tensors and scalars.
- Lower overhead: shifting checks from Python into C reduces interpreter/property-access costs; the call overhead is lower than pybind-based approaches.
- Focused error reporting: assertions are raised close to the call site with precise “which field failed” messages.
## How To Inspect Host Source
You can inspect the auto-generated host source (with all checks and the final device-kernel call) for debugging:
```python
print(matmul_relu_kernel.get_host_source())
```
---
## What The Host Checks
### 1) Argument count and pointer kind
- `num_args` must match the number of formal parameters; otherwise the kernel returns `-1` with an error message.
- Each argument’s FFI type must be a pointer kind (for DLTensor/handle) or a valid scalar type; otherwise you’ll see errors like `Expect arg[i] to be pointer` or a scalar type error.
### 2) Tensor checks (per tensor, after nullability decision)
- Nullability
- If the tensor is “statically reachable/used” by the function body, the handle must be non-NULL; otherwise: `xxx is expected to have non-NULL pointer`.
- If an input tensor is not used by the function (statically unreachable), NULL is allowed; other field checks are executed only when `handle != NULL`.
- Rank (`ndim`)
- Runtime `ndim` must equal the compile-time rank.
- Data type (`dtype`)
- Match the triple `(code, bits, lanes)` with tolerance:
- `float8_e4m3`: accept `e4m3`, `e4m3fn`, `e4m3fnuz`.
- `float8_e5m2`: accept `e5m2`, `e5m2fnuz`.
- `bool`: accept `int8/uint8` with `bits=8` (same lanes), `kDLBool(code=6, bits=1 or 8)`, and any `bitwidth=1` (lanes must match).
- For packed-bit dtypes (e.g., `Int(1)`, `Int(4)`, `UInt(4)`), strict dtype checking is skipped.
- Shape
- Each runtime dimension is bound to the compile-time shape (constants or symbols) and checked for consistency.
- Linear equations among symbolic dims can be solved on the fly (when there’s only one unknown at a given check point), enabling cross-tensor constraints.
- Strides
- If `buffer_type = AutoBroadcast`: allow `strides == NULL` and derive strides from `shape`. If explicit `strides` is present, bind to compile-time constraints and check for equality.
- Otherwise: check per-dimension; if `strides == NULL`, derive from `shape` and compare (e.g., contiguous: `strides[-1] == 1`, `strides[-2] == shape[-1]`).
- `byte_offset`
- Must be 0 (non-zero raises an error) to keep addressing simple and aligned.
- Device info
- Assert `device_type == target backend` (CUDA/ROCM/Metal/OneAPI/WebGPU/CPU, etc.). Error messages include a DLPack code legend.
- When multiple tensors participate, assert that `device_id` matches across them.
- Data pointer
- Must be non-NULL when the tensor is required to be non-null by the nullability rule.
### 3) Scalar checks
- `T.int*` family: require integer; error: `Expect arg[i] to be int`.
- `T.bool`: require boolean; error: `Expect arg[i] to be boolean`.
---
## Shapes and Symbolic Equations: Linear Solving
When shapes are symbolic, the host binds and (when possible) solves linear relations at runtime (only one unknown per check point). Example:
```python
@T.prim_func
def main(
A: T.Tensor((m,), dtype),
B: T.Tensor((m + n,), dtype),
C: T.Tensor((n * k,), dtype),
):
...
```
This enables enforcing cross-tensor relationships like `len(B) == m + n` and `len(C) == n * k` at runtime.
---
## Nullability Rules and Examples
Which tensors may be NULL?
- Rule: If an input tensor is not used by the function under static analysis (i.e., the access is statically unreachable), it is considered Nullable; otherwise it must be non-NULL.
- Examples:
1) Must be non-NULL (used)
```python
@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
A[0] = 1
```
Passing `None` raises: `main.A_handle is expected to have non-NULL pointer`.
2) Still must be non-NULL (constant-true branch)
```python
some_cond: bool = True
@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
if some_cond:
A[0] = 1
```
3) Nullable (constant-false branch, statically unreachable)
```python
some_cond: bool = False
@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
if some_cond:
A[0] = 1
```
4) Must be non-NULL (runtime condition)
```python
@T.prim_func
def main(A: T.Tensor((M, K), dtype), some_cond: T.bool):
if some_cond:
A[0] = 1
```
Since `some_cond` is only known at runtime, static analysis cannot prove `A` is unused; `A` is thus non-nullable.
---
## Device Type Codes (DLPack)
Supported and referenced device codes in error messages: `1=CPU, 2=CUDA, 7=Vulkan, 8=Metal, 10=ROCM, 14=OneAPI, 15=WebGPU`.
Kernels assert that `device_type` matches the target backend, and require `device_id` consistency across tensors.
---
## Common Error Examples (What you’ll see)
- Argument count mismatch (num_args)
- Trigger: missing/extra argument
- Error: `<kernel>: num_args should be N; expected: <num_args>, got: N`
- Pointer-typed argument expected
- Trigger: scalar passed where a tensor is expected
- Error: `<kernel>: Expect arg[i] to be pointer`
- Rank (ndim) mismatch
- Trigger: runtime rank differs from compile-time rank
- Error: `<kernel>.<name>.ndim is expected to equal R, but got mismatched ndim`
- Dtype mismatch
- Trigger: dtype not equal to the compiled dtype and not within the tolerance set
- Error: `<kernel>.<name>.dtype is expected to be <dtype>, but got incompatible dtype`
- Shape constraint violation
- Trigger: a dimension doesn’t match a constant/symbol binding
- Error: `Argument <kernel>.<name>.shape[i] has an unsatisfied constraint: ... == <expected>`
- Strides check failed (e.g., non-contiguous layout)
- Trigger: transposed/sliced tensors that violate expected strides
- Error: `Argument <kernel>.<name>.strides[j] has an unsatisfied constraint: ... == <expected>`
- Device type mismatch
- Trigger: calling a CUDA kernel with CPU tensors, etc.
- Error: `<kernel>.<name>.device_type mismatch [expected: <code> (<name>)] ...`
- Device id mismatch
- Trigger: mixing tensors from different GPUs
- Error: `Argument <kernel>.<name>.device_id has an unsatisfied constraint: ... == ...`
- NULL data pointer
- Trigger: tensor required to be non-null has a NULL data pointer
- Error: `<kernel>.<name> is expected to have non-NULL data pointer, but got NULL`
- Scalar type mismatch
- Trigger: passing float to `T.int32`, or non-boolean to `T.bool`
- Error: `<kernel>: Expect arg[i] to be int/boolean`
---
## Troubleshooting Tips
- Print the host source: `print(fn.get_host_source())` to see the exact assertion and expected vs. actual fields.
- Fix strides: call `.contiguous()` for non-contiguous tensors, or avoid generating transposed/sliced layouts that break assumptions.
- Align devices: ensure all participating tensors share the same `device_type` and `device_id`.
- Align dtype: use `.to(<dtype>)` or construct tensors with the correct dtype; pay attention to `float8` and `bool` tolerance.
- Dynamic shapes: ensure cross-tensor linear relations can be uniquely determined at the check point (only one unknown at a time).
---
## FAQ
- Can I disable the checks?
- Not recommended and usually not supported. Checks are done on the host to preserve ABI stability and fail early close to the device call.
- Is the overhead noticeable?
- The checks are lightweight (branches and field reads). Compared to Python-side checks, it’s faster; the dominating cost remains the Python→C boundary. Overall it’s cheaper than equivalent checks in Python.
---
## Reference Example (Matmul + ReLU)
```python
@T.prim_func
def matmul_relu_kernel(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
T.copy(A[by * block_M, ko * block_K], A_shared)
T.copy(B[ko * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[by * block_M, bx * block_N])
# For debugging, print the host source
print(matmul_relu_kernel.get_host_source())
```
The host will insert all checks described above for this example.
---
## Quick Error Reference (Short List)
- Argument count
- Trigger: missing/extra args; Error: `num_args should be N; expected: <num_args>, got: N`.
- Pointer kind
- Trigger: scalar passed to tensor arg; Error: `Expect arg[i] to be pointer`.
- Rank (ndim)
- Trigger: runtime rank != compile-time; Error: `ndim ... expected to equal R`.
- Dtype
- Trigger: mismatch and not tolerated; Error: `dtype ... expected to be <dtype>`.
- Shape
- Trigger: constant/symbol binding violated; Error: `shape[i] ... == <expected>`.
- Strides
- Trigger: layout mismatch; Error: `strides[j] ... == <expected>`.
- Device type
- Trigger: wrong backend device; Error: `device_type mismatch [expected: ...]`.
- Device id
- Trigger: tensors on different GPUs; Error: `device_id ... == ...`.
- Data pointer
- Trigger: required non-NULL but NULL; Error: `non-NULL data pointer`.
- Scalar types
- Trigger: wrong scalar type; Error: `Expect arg[i] to be int/boolean`.
---
## Host Error Troubleshooting (Minimal Repros)
Below are minimal repro snippets for common host-side errors, assuming a CUDA-targeted kernel like `matmul_relu_kernel` with:
```python
# Convention:
# A: float16 [M, K]
# B: float16 [K, N]
# C: float16 [M, N]
# Target: CUDA (device_type=2)
fn = matmul_relu_kernel # your compiled function
M = N = K = 1024
```
Adjust dtype/device if your kernel differs.
### 0. Tip: print the host source
```python
print(fn.get_host_source())
```
### 1. num_args mismatch
```python
import torch
A = torch.empty((M, K), device='cuda', dtype=torch.float16)
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
# Missing C
fn(A, B)
```
Expected: `<kernel>: num_args should be 3; expected: <num_args>, got: 2`.
Fix: pass all arguments per the signature.
### 2. Expect pointer (tensor) but got scalar
```python
import torch
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(1, B, C)
```
Expected: `<kernel>: Expect arg[0] to be pointer`.
Fix: pass a DLPack-compatible tensor (e.g., torch.Tensor).
### 3. ndim mismatch
```python
import torch
A = torch.empty((M, K, 1), device='cuda', dtype=torch.float16) # rank=3
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A, B, C)
```
Expected: `<kernel>.A_handle.ndim is expected to equal 2, but got mismatched ndim`.
Fix: ensure runtime rank equals compiled rank.
### 4. dtype mismatch
```python
import torch
A = torch.empty((M, K), device='cuda', dtype=torch.float32) # should be float16
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A, B, C)
```
Expected: `<kernel>.A_handle.dtype is expected to be float16, but got incompatible dtype`.
Fix: `A = A.to(torch.float16)` or create with the correct dtype.
### 5. Shape constant/symbol mismatch
```python
import torch
A = torch.empty((M, K + 1), device='cuda', dtype=torch.float16) # K mismatched
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A, B, C)
```
Expected: `Argument <kernel>.A_handle.shape[i] has an unsatisfied constraint: ... == <expected>`.
Fix: satisfy linear constraints and constants across tensors.
### 6. Strides check failure (non-contiguous)
```python
import torch
A = torch.empty((M, K), device='cuda', dtype=torch.float16)
A_nc = A.t() # transpose -> non-contiguous
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A_nc, B, C)
```
Expected: `Argument <kernel>.A_handle.strides[1] has an unsatisfied constraint: ... == 1`.
Fix: pass `A_nc.contiguous()` or align the layout expectation in the kernel.
### 7. device_type mismatch
```python
import torch
A = torch.empty((M, K), device='cpu', dtype=torch.float16)
B = torch.empty((K, N), device='cpu', dtype=torch.float16)
C = torch.empty((M, N), device='cpu', dtype=torch.float16)
fn(A, B, C) # CUDA-targeted kernel
```
Expected: `<kernel>.A_handle.device_type mismatch [expected: 2 (cuda)] ...`.
Fix: move tensors to the CUDA device.
### 8. device_id mismatch (multi-GPU)
```python
import torch
A = torch.empty((M, K), device='cuda:0', dtype=torch.float16)
B = torch.empty((K, N), device='cuda:1', dtype=torch.float16)
C = torch.empty((M, N), device='cuda:0', dtype=torch.float16)
fn(A, B, C)
```
Expected: `Argument <kernel>.B_handle.device_id has an unsatisfied constraint: ... == ...`.
Fix: place all tensors on the same GPU (e.g., `cuda:0`).
### 9. NULL data pointer (advanced)
This usually comes from hand-constructed DLTensor/NDArray, or external frameworks passing unallocated/freed storage. Regular `torch.Tensor` allocations rarely hit this.
Expected: `<kernel>.<name> is expected to have non-NULL data pointer, but got NULL`.
Fix: ensure valid underlying storage; in PyTorch scenarios, avoid constructing tensors from invalid external handles.
### 10. Scalar type mismatch (int / bool)
```python
import tilelang.language as T
@T.prim_func
def scalar_check(x: T.int32, flag: T.bool()):
T.evaluate(0)
scalar_check(1.0, True) # x is float -> Expect arg[0] to be int
scalar_check(1, 2.5) # flag is float -> Expect arg[1] to be boolean
```
Fix: pass correct scalar types, e.g., `scalar_check(1, True)`.
---
## Closing Notes
- Cross-check “shape / strides / device / dtype” against the kernel signature to localize issues efficiently.
- For complex symbolic relations, print the host source to confirm binding/solving order, then adjust runtime shapes/layouts accordingly.
# General information about the project.
project = "Tile Language <br>"
project = "TileLang <br>"
author = "Tile Lang Contributors"
copyright = f"2025-2025, {author}"
......@@ -20,33 +20,27 @@ extensions = [
"autoapi.extension",
]
autoapi_type = 'python'
autoapi_dirs = ['../tilelang']
autoapi_type = "python"
autoapi_dirs = ["../tilelang"]
autoapi_options = [
'members',
'undoc-members',
'show-inheritance',
'show-module-summary',
'special-members',
"members",
"undoc-members",
"show-inheritance",
"show-module-summary",
"special-members",
]
autoapi_keep_files = False # Useful for debugging the generated rst files
autoapi_generate_api_docs = True
autodoc_typehints = 'description'
autodoc_typehints = "description"
autoapi_ignore = ["*language/ast*", "*version*", "*libinfo*", "*parser*"]
source_suffix = {
'.rst': 'restructuredtext',
'.md': 'markdown',
}
source_suffix = {".rst": "restructuredtext", ".md": "markdown"}
myst_enable_extensions = [
"colon_fence",
"deflist",
]
myst_enable_extensions = ["colon_fence", "deflist"]
redirects = {"get_started/try_out": "../index.html#getting-started"}
......@@ -62,13 +56,11 @@ todo_include_todos = False
html_theme = "furo"
templates_path = []
html_static_path = ["_static"]
footer_copyright = "© 2025-2025 Tile Language"
html_css_files = ["custom.css"]
footer_copyright = "© 2025-2026 TileLang"
footer_note = " "
html_theme_options = {
"light_logo": "img/logo-row.svg",
"dark_logo": "img/logo-row.svg",
}
html_theme_options = {"light_logo": "img/logo-v2.png", "dark_logo": "img/logo-v2.png"}
header_links = [
("Home", "https://github.com/tile-ai/tilelang"),
......
......@@ -24,7 +24,7 @@ Please note that this tutorial does not delve deeply into the design principles
## Elementwise add in TileLang
```python
def elementwise_add(N, threads=256, dtype="bfloat16"):
def elementwise_add(N, threads=256, dtype=T.bfloat16):
@T.prim_func
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
......@@ -43,7 +43,7 @@ Those familiar with CUDA programming might wonder where `threadIdx` fits into th
The program can be compiled using the following code:
```python
program = elementwise_add(1024, threads=256, dtype="bfloat16")
program = elementwise_add(1024, threads=256, dtype=T.bfloat16)
kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")
```
Launching the kernel is straightforward, just call it directly like a function:
......@@ -89,7 +89,7 @@ def elementwise_add(
In the compilation process above, a fixed shape was used. However, in practical usage, we often want the kernel to support dynamic shapes. So, how can we compile a kernel in TileLang to handle dynamic shapes? In TileLang, we can replace the target size with a dynamic symbolic value, making the dimension dynamic. The following example illustrates this:
```python
program = elementwise_add(T.dynamic("N"), threads=256, dtype="bfloat16")
program = elementwise_add(T.dynamic("N"), threads=256, dtype=T.bfloat16)
kernel = tilelang.compile(program, out_idx=-1, target="cuda", execution_backend="cython")
```
......@@ -102,7 +102,7 @@ TileLang automatically incorporates boundary-checking conditions; however, this
When compiling the example below, let's set `N` to 2047:
```python
def elementwise_add(N, num_per_thread=8, threads=256, dtype="bfloat16"):
def elementwise_add(N, num_per_thread=8, threads=256, dtype=T.bfloat16):
@T.prim_func
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
......@@ -176,7 +176,7 @@ While TileLang incorporates various optimizations for the aforementioned case, i
In such scenarios, explicitly specifying the number of elements computed per thread can help "guide" TileLang's code generation process, leading to implementations that are more closely aligned with the intended design.
```python
def elementwise_add(N, num_per_thread=8, threads=256, dtype="bfloat16"):
def elementwise_add(N, num_per_thread=8, threads=256, dtype=T.bfloat16):
@T.prim_func
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
......@@ -212,7 +212,7 @@ Aha, this CUDA code aligns closely with conventional programming practices, maki
But what happens if we provide additional hints to TileLang? For instance, by explicitly specifying register copies using the `T.copy(...)` operation. The example below demonstrates a vector addition implementation. Unlike the previous examples, this code explicitly loads data into registers before performing computations.
```python
def elementwise_add(N, NUM_ELE_PER_THREAD=8, threads=256, dtype="bfloat16"):
def elementwise_add(N, NUM_ELE_PER_THREAD=8, threads=256, dtype=T.bfloat16):
@T.prim_func
def main(A: T.Tensor((N), dtype), B: T.Tensor((N), dtype), C: T.Tensor((N), dtype)):
......
# Sparse Matrix-Matrix Multiplication with Tile Library
<div style="text-align: left;">
<em>Author:</em> <a href="https://github.com/botbw">botbw</a>
</div>
:::{warning}
This document is still **experimental** and may be incomplete.
This feature is still **experimental** and needs further optimization.
Suggestions and improvements are highly encouraged—please submit a PR!
:::
:::{tip}
It's suggested to go through `docs/deeplearning_operators/matmul.md` first.
Example code can be found at `examples/gemm_sp`.
:::
## Structured sparsity in the NVIDIA Ampere architecture
Since the Ampere architecture (sm80 and above), sparsity support has been integrated into Tensor Cores. This allows a 2:4 (or 1:2 for 32-bit data types) semi-structured matrix to be compressed into its non-zero values along with associated metadata, which can then be fed into the Tensor Core. This enables up to **2x throughput** compared to the equivalent dense computation.
:::{warning}
This tutorial primarily focuses on CUDA, as this feature is not yet supported on ROCm. However, AMD provides a similar capability in the matrix cores of GPUs such as the MI300X.
:::
```{figure} ../_static/img/sparse_mma_storage_example.png
:align: center
Figure: Sparse MMA storage example (from PTX doc)
```
## Compress a dense tensor
To utilize sparse Tensor Cores, a dense tensor must first be **compressed** into its non-zero values along with the corresponding metadata.
Both `PyTorch` and `vLLM` use `CUTLASS` as their computation backend (see references [here](https://github.com/pytorch/pytorch/blob/a8d6afb511a69687bbb2b7e88a3cf67917e1697e/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu#L47) and [here](https://github.com/vllm-project/vllm/blob/a5dd03c1ebc5e4f56f3c9d3dc0436e9c582c978f/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh#L116)), leveraging `CUTLASS`’s built-in compressor (or reimplementing it in `PyTorch`).
A set of **CUTLASS-compatible** compressors is provided in `tilelang.utils.sparse`, where a dense tensor—along with other required arguments (e.g., block_K for sm90, transpose options)—can be passed in to perform the compression.
```python
from tilelang.utils.sparse import compress
A_sparse, E = compress(A, transposed=trans_A, block_k=block_K)
```
Here, `A_sparse` contains all the non-zero elements of `A`, while `E` stores the corresponding metadata (indexing information) required to reconstruct the original sparse pattern.
> NOTE: When using CUTLASS compressor, there is no naive position correspondence between the positions in `A_sparse`/`A` and `E`. (i.e. the 4-element group at [n, k] doesn't match the 4-bit metadata at [n, k] if you consider metadata as int4 tensor)
The metadata is reordered internally to optimize memory access patterns (e.g., for ldsm instructions and vectorized loads).
For more information, see **A note on `gemm_sp` and `gemm_sp_v2`**.
## `T.gemm_sp` with CUTLASS's compressor
:::{warning}
It is strongly recommended to use T.gemm_sp_v2 due to its greater flexibility and faster compilation time.
:::
A 2:4 sparse GEMM kernel is similar to its dense counterpart, except that it also requires handling the associated metadata.
Check the comments in the kernel code below for the required modifications.
```python
def matmul_sp_sm80(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
trans_A,
trans_B,
):
is_8_bit = "8" in in_dtype
metadata_dtype = 'int32' if is_8_bit else 'int16'
E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype] # Calculate shape for given datatypes
A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
B_shape = (K, N) if not trans_B else (N, K)
A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)
import tilelang.language as T
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), metadata_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
E_shared = T.alloc_shared((block_M, block_K // E_factor), metadata_dtype) # Allocate smem for metadata
C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({ # Annotate reordered cutlass metadata layout
E:
make_cutlass_metadata_layout(E, mma_dtype=in_dtype, arch="8.0"),
E_shared:
make_cutlass_metadata_layout(
E_shared, mma_dtype=in_dtype, arch="8.0"),
})
T.clear(C_frag)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
if trans_A:
T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared)
else:
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B) # Call gemm_sp with non-zero values and metadata
T.copy(C_frag, C[by * block_M, bx * block_N])
return main
```
Under the hood, `gemm_sp` invokes templates adapted from `CUTLASS`, and a compatible metadata layout must be specified using `T.annotate_layout`.
## `T.gemm_sp_v2` with a custom compressor
To migrate to `gemm_sp_v2`, simply replace occurrences of `gemm_sp`.
Unlike `gemm_sp`, `gemm_sp_v2` can operate without `T.annotate_layout`, and it also supports user-defined layouts and compressors.
The metadata is stored in a `(u)int8`/`(u)int16`/`(u)int32` tensor, where **each 4-bit chunk represents two 2-bit indices** of non-zero elements within four consecutive elements. Here, we start with an `int16` example, which is the **default dtype** for `bf16` and `fp16` on Ampere GPUs.
Suppose we have the following row vector:
```python
t = tensor([[0, 7, 0, 3], [1, 5, 0, 0], [0, 0, 2, 4], [9, 0, 9, 0]], dtype=torch.float16).flatten()
```
The non-zero elements and their corresponding indices are:
```python
t_sp = tensor([[7, 3], [1, 5], [2, 4], [9, 9]], dtype=torch.float16).flatten()
indices = tensor([[1, 3], [0, 1], [2, 3], [0, 2]], dtype=torch.float16).flatten()
```
The corresponding uint16 metadata is:
```python
# metadata_bits = tensor([0b1101, 0b0100, 0b1110, 0b1000])
# Note: storage uses little-endian order: tensor(0b1000111001001101, dtype=torch.int16)
# Note: the above code is not runnable in python as the interpreter won't take the binary
# as 2's complement
metadata_int16 = tensor(-29107)
```
You can decode an int16 metadata tensor using the following utility:
```python
def decode_metadata(meta: torch.Tensor) -> torch.Tensor:
assert meta.dtype is torch.int16
groups_per_meta = 16 // 4
out = []
for g in range(groups_per_meta):
group_bits = (meta >> (g * 4)) & 0xF
idx0 = group_bits & 0x3
idx1 = (group_bits >> 2) & 0x3
out.append(torch.stack([idx0, idx1], dim=-1))
return torch.concat(out, dim=-1).view(meta.shape[0], -1)
```
The compressor can be implemented at either the `PyTorch`/`NumPy` level or the kernel level.
For example, `PyTorch` provides an Ampere compressor [here](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L47-L179). Note that in this implementation, a [permutation](https://github.com/pytorch/pytorch/blob/267d0197bfca0232488d51dd1ff735d619adc2cf/torch/sparse/_semi_structured_conversions.py#L173-L175) is applied to match CUTLASS’s metadata layout. If you do not annotate a metadata layout when using `gemm_sp_v2`, your compressor should replicate the same behavior as the PyTorch example—but without using the `_calculate_meta_reordering_scatter_offsets` function.
If you want to use a custom metadata layout in your kernel, one approach is to define the layout in `TileLang` and then apply the same layout to both your compressor kernel and the matmul_sp kernel.
```python
@tilelang.jit(out_idx=[1, 2], pass_configs={
tilelang.PassConfigKey.TIR_DISABLE_VECTORIZE: True,
})
def compress_kernel(M, K, block_M, block_K, dtype, use_cutlass_layout):
e_factor, e_dtype = ARCH_INFO["8.0"]
e_K = K // e_factor
elem, group = 2, 4
assert M % block_M == 0, "M must be divisible by block_M"
assert K % block_K == 0, "K must be divisible by block_K"
assert K % e_factor == 0, "K must be divisible by e_factor"
assert block_K % e_factor == 0, "block_K must be divisible by e_factor"
@T.prim_func
def kernel(
A: T.Tensor((M, K), dtype),
A_sp: T.Tensor((M, K // 2), dtype),
E: T.Tensor((M, e_K), e_dtype),
):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=block_M) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
A_sp_shared = T.alloc_shared((block_M, block_K // 2), dtype)
E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
if use_cutlass_layout: # NOTE: Make sure compressor metadata layout
T.annotate_layout({ # is same with your computation kernel
E:
make_cutlass_metadata_layout(
E, mma_dtype="float16", arch="8.0", block_k=block_K),
E_shared:
make_cutlass_metadata_layout(
E_shared,
mma_dtype="float16",
arch="8.0",
block_k=block_K),
})
T.clear(A_sp_shared)
T.clear(E_shared)
non_zero_cnt = T.alloc_local((1, ), dtype="uint8")
non_zero_elt_log_idx = T.alloc_local((elem, ), dtype="uint8")
T.copy(A[bx * block_M, by * block_K], A_shared)
for tm in T.Parallel(block_M):
for g_i in range(0, block_K // group):
a_k = g_i * group
T.clear(non_zero_cnt)
T.clear(non_zero_elt_log_idx)
for i in range(group):
val = A_shared[tm, a_k + i]
if val != 0.0:
non_zero_elt_log_idx[non_zero_cnt[0]] = i
A_sp_shared[tm, a_k // 2 + non_zero_cnt[0]] = val
non_zero_cnt[0] += 1
if non_zero_cnt[0] == 1 and non_zero_elt_log_idx[0] == 3:
non_zero_elt_log_idx[0] = 0
non_zero_elt_log_idx[1] = 3
A_sp_shared[tm, a_k // 2 + 1] = A_sp_shared[tm, a_k // 2]
A_sp_shared[tm, a_k // 2] = 0.0
elif non_zero_cnt[0] == 1:
A_sp_shared[tm, a_k // 2 + 1] = 0
non_zero_elt_log_idx[1] = 3
for i in T.serial(elem):
val = non_zero_elt_log_idx[i]
E_shared[tm, a_k // e_factor] |= T.shift_left(val, 4 * (g_i % (e_factor // group)) + 2 * i)
T.copy(A_sp_shared, A_sp[bx * block_M, by * block_K // 2])
T.copy(E_shared, E[bx * block_M, by * block_K // e_factor])
return kernel
```
## A note on `gemm_sp` and `gemm_sp_v2`
Initially, `T.gemm_sp` followed the same design as `T.gemm`, lowering to a `CUTLASS` template. This inherently requires metadata to be reordered offline following a predetermined layout.
However, fixing a specific layout introduces several potential issues:
1. Painful debugging experience: Debugging a failed kernel becomes difficult due to the reordered indexing, including permutations and swizzling.
2. Limited flexibility: For example, concatenating two compressed tensors, such as `A_sparse_0` and `A_sparse_1`, into a new `A_sparse` makes sense. However, concatenating their metadata `E_0` and `E_1` may not be valid unless the layout allows it mathematically.
3. Alignment requirements: `CUTLASS` enforces strict alignment checks, and many hyperparameter configurations can lead to compilation errors. (For reference, sm8x was implemented in `CUTLASS 2`.)
`T.gemm_sp_v2` was designed to address these limitations, following the approach of `T.gemm_v2`. It lowers directly to PTX, removing the need for a fixed metadata layout.
\ No newline at end of file
......@@ -8,25 +8,25 @@
- **Python Version**: >= 3.8
- **CUDA Version**: 12.0 <= CUDA < 13
The easiest way to install **tile-lang** is directly from PyPI using pip. To install the latest version, run the following command in your terminal:
The easiest way to install tilelang is directly from PyPI using pip. To install the latest version, run the following command in your terminal:
```bash
pip install tilelang
```
Alternatively, you may choose to install **tile-lang** using prebuilt packages available on the Release Page:
Alternatively, you may choose to install tilelang using prebuilt packages available on the Release Page:
```bash
pip install tilelang-0.0.0.dev0+ubuntu.20.4.cu120-py3-none-any.whl
```
To install the latest version of **tile-lang** from the GitHub repository, you can run the following command:
To install the latest version of tilelang from the GitHub repository, you can run the following command:
```bash
pip install git+https://github.com/tile-ai/tilelang.git
```
After installing **tile-lang**, you can verify the installation by running:
After installing tilelang, you can verify the installation by running:
```bash
python -c "import tilelang; print(tilelang.__version__)"
......@@ -40,18 +40,18 @@ python -c "import tilelang; print(tilelang.__version__)"
- **Python Version**: >= 3.8
- **CUDA Version**: >= 10.0
```bash
docker run -it --rm --ipc=host nvcr.io/nvidia/pytorch:23.01-py3
```
If you prefer Docker, please skip to the [Install Using Docker](#install-using-docker) section. This section focuses on building from source on a native Linux environment.
To build and install **tile-lang** directly from source, follow these steps. This process requires certain pre-requisites from Apache TVM, which can be installed on Ubuntu/Debian-based systems using the following commands:
First, install the OS-level prerequisites on Ubuntu/Debian-based systems using the following commands:
```bash
apt-get update
apt-get install -y python3 python3-dev python3-setuptools gcc zlib1g-dev build-essential cmake libedit-dev
```
After installing the prerequisites, you can clone the **tile-lang** repository and install it using pip:
Then, clone the tilelang repository and install it using pip. The `-v` flag enables verbose output during the build process.
> **Note**: Use the `--recursive` flag to include necessary submodules. Tilelang currently depends on a customized version of TVM, which is included as a submodule. If you prefer [Building with Existing TVM Installation](#using-existing-tvm), you can skip cloning the TVM submodule (but still need other dependencies).
```bash
git clone --recursive https://github.com/tile-ai/tilelang.git
......@@ -59,13 +59,19 @@ cd tilelang
pip install . -v
```
If you want to install **tile-lang** in development mode, you can run the following command:
If you want to install tilelang in development mode, you can use the `-e` flag so that any changes to the Python files will be reflected immediately without reinstallation.
```bash
pip install -e . -v
```
If you prefer to work directly from the source tree via `PYTHONPATH`, make sure the native extension is built first:
> **Note**: changes to C++ files require rebuilding the tilelang C++ library. See [Faster Rebuild for Developers](#faster-rebuild-for-developers) below. A default `build` directory will be created if you use `pip install`, so you can also directly run `make` in the `build` directory to rebuild it as [Working from Source via PYTHONPATH](#working-from-source-via-pythonpath) suggested below.
(working-from-source-via-pythonpath)=
### Working from Source via `PYTHONPATH` (Recommended for Developers)
If you prefer to work directly from the source tree via `PYTHONPATH` instead of using pip, make sure the native extension (`libtilelang.so`) is built first:
```bash
mkdir -p build
......@@ -73,6 +79,14 @@ cd build
cmake .. -DUSE_CUDA=ON
make -j
```
We also recommend using `ninja` to speed up compilation:
```bash
cmake .. -DUSE_CUDA=ON -G Ninja
ninja
```
Then add the repository root to `PYTHONPATH` before importing `tilelang`, for example:
```bash
......@@ -85,17 +99,23 @@ Some useful CMake options you can toggle while configuring:
- `-DUSE_ROCM=ON` selects ROCm support when building on AMD GPUs.
- `-DNO_VERSION_LABEL=ON` disables the backend/git suffix in `tilelang.__version__`.
We currently provide four methods to install **tile-lang**:
(using-existing-tvm)=
### Building with Customized TVM Path
If you already have a TVM codebase, use the `TVM_ROOT` environment variable to specify the location of your existing TVM repository when building tilelang:
```bash
TVM_ROOT=<your-tvm-repo> pip install . -v
```
1. [Install Using Docker](#install-method-1) (Recommended)
2. [Install from Source (using the bundled TVM submodule)](#install-method-2)
3. [Install from Source (using your own TVM installation)](#install-method-3)
> **Note**: This will still rebuild the TVM-related libraries (stored in `TL_LIBS`). And this method often leads to some path issues. Check `env.py` to see some environment variables which are not set properly.
(install-method-1)=
(install-using-docker)=
### Method 1: Install Using Docker (Recommended)
## Install Using Docker
For users who prefer a containerized environment with all dependencies pre-configured, **tile-lang** provides Docker images for different CUDA versions. This method is particularly useful for ensuring consistent environments across different systems and is the **recommended approach** for most users.
For users who prefer a containerized environment with all dependencies pre-configured, tilelang provides Docker images for different CUDA versions. This method is particularly useful for ensuring consistent environments across different systems.
**Prerequisites:**
- Docker installed on your system
......@@ -142,82 +162,17 @@ docker run -itd \
- `--name tilelang_b200`: Assigns a name to the container for easy management
- `/bin/zsh`: Uses zsh as the default shell
4. **Access the Container**:
4. **Access the Container and Verify Installation**:
```bash
docker exec -it tilelang_b200 /bin/zsh
```
5. **Verify Installation**:
Once inside the container, verify that **tile-lang** is working correctly:
```bash
# Inside the container:
python -c "import tilelang; print(tilelang.__version__)"
```
You can now run TileLang examples and develop your applications within the containerized environment. The Docker image comes with all necessary dependencies pre-installed, including CUDA toolkit, TVM, and TileLang itself.
**Example Usage:**
After accessing the container, you can run TileLang examples:
```bash
cd /home/tilelang/examples
python elementwise/test_example_elementwise.py
```
This Docker-based installation method provides a complete, isolated environment that works seamlessly on systems with compatible NVIDIA GPUs like the B200, ensuring optimal performance for TileLang applications.
(install-method-2)=
### Method 2: Install from Source (Using the Bundled TVM Submodule)
If you prefer to use the bundled TVM submodule, follow these steps:
1. **Clone the Repository**:
```bash
git clone --recursive https://github.com/tile-ai/tilelang
cd tilelang
```
**Note**: Use the `--recursive` flag to include necessary submodules.
2. **Configure Build Options**:
Create a build directory and specify your existing TVM path:
```bash
pip install . -v
```
(install-method-3)=
### Method 3: Install from Source (Using Your Own TVM Installation)
If you already have your own TVM installation, follow these instructions:
1. **Clone the Repository**:
```bash
git clone --recursive https://github.com/tile-ai/tilelang
cd tilelang
```
**Note**: Ensure the `--recursive` flag is included to fetch submodules.
2. **Configure Build Options**:
Copy the configuration file and enable the desired backends (e.g., LLVM and CUDA):
```bash
TVM_ROOT=<your-tvm-repo> pip install . -v
```
## Install with Nightly Version
For users who want access to the latest features and improvements before official releases, we provide nightly builds of **tile-lang**.
For users who want access to the latest features and improvements before official releases, we provide nightly builds of tilelang.
```bash
pip install tilelang -f https://tile-ai.github.io/whl/nightly/cu121/
......@@ -252,24 +207,28 @@ Set `NO_TOOLCHAIN_VERSION=ON` to disable this.
### Run-time environment variables
<!-- TODO: tvm -->
Please refer to the `env.py` file for a full list of supported run-time environment variables.
## Other Tips
## IDE Configs
### IDE Configs
Building tilelang locally will automatically `compile_commands.json` file in `build` dir.
Building tilelang locally will automatically generate a `compile_commands.json` file in the `build` directory.
VSCode with clangd and [clangd extension](https://marketplace.visualstudio.com/items?itemName=llvm-vs-code-extensions.vscode-clangd) should be able to index that without extra configuration.
## Compile cache
### Compile Cache
`ccache` will be automatically used if found.
The default path of the compile cache is `~/.tilelang/cache`. `ccache` will be automatically used if found.
## Repairing wheels
### Repairing Wheels
If you plan to use your wheel in other environment,
it's recommend to use auditwheel (on Linux) or delocate (on Darwin)
it's recommended to use auditwheel (on Linux) or delocate (on Darwin)
to repair them.
## Faster rebuild for developers
(faster-rebuild-for-developers)=
### Faster Rebuild for Developers
`pip install` introduces extra [un]packaging and takes ~30 sec to complete,
even if no source has changed.
......@@ -278,8 +237,17 @@ Developers who needs to recompile frequently could use:
```bash
pip install -r requirements-dev.txt
# For first time compilation
pip install -e . -v --no-build-isolation
# Or manually compile with cmake/ninja. Remember to set PYTHONPATH properly.
mkdir build
cd build
cmake .. -G Ninja
ninja
# Rebuild when you change the cpp code
cd build; ninja
```
......
......@@ -2,10 +2,10 @@
[GitHub](https://github.com/tile-ai/tilelang)
Tile Language (tile-lang) is a concise domain-specific language designed to streamline
the development of high-performance GPU/CPU kernels (e.g., GEMM, Dequant GEMM, FlashAttention, LinearAttention).
By employing a Pythonic syntax with an underlying compiler infrastructure on top of TVM,
tile-lang allows developers to focus on productivity without sacrificing the
Tile Language (tile-lang) is a concise domain-specific language designed to streamline
the development of high-performance GPU/CPU kernels (e.g., GEMM, Dequant GEMM, FlashAttention, LinearAttention).
By employing a Pythonic syntax with an underlying compiler infrastructure on top of TVM,
tile-lang allows developers to focus on productivity without sacrificing the
low-level optimizations necessary for state-of-the-art performance.
:::{toctree}
......@@ -24,6 +24,19 @@ get_started/targets
tutorials/debug_tools_for_tilelang
tutorials/auto_tuning
tutorials/logging
:::
:::{toctree}
:maxdepth: 1
:caption: PROGRAMMING GUIDES
programming_guides/overview
programming_guides/language_basics
programming_guides/instructions
programming_guides/control_flow
programming_guides/autotuning
programming_guides/type_system
:::
:::{toctree}
......@@ -33,6 +46,7 @@ tutorials/auto_tuning
deeplearning_operators/elementwise
deeplearning_operators/gemv
deeplearning_operators/matmul
deeplearning_operators/matmul_sparse
deeplearning_operators/deepseek_mla
:::
......@@ -42,6 +56,7 @@ deeplearning_operators/deepseek_mla
compiler_internals/letstmt_inline
compiler_internals/inject_fence_proxy
compiler_internals/tensor_checks
:::
:::{toctree}
......
# Autotuning
TileLang includes a built‑in autotuner that searches configuration spaces
for the best performing kernel, compiles candidates in parallel, validates
correctness, benchmarks them, and caches the best result for reuse.
This guide covers two workflows:
- Decorator‑based: `@tilelang.autotune(configs=...)` stacked on `@tilelang.jit`
- Programmatic: `AutoTuner.from_kernel(...).set_*().run()`
It also explains input tensor supply, validation, caching, and environment
variables that affect parallelism and cache behavior.
## 1) Decorator‑based Autotune
Use `@tilelang.autotune` above `@tilelang.jit` and expose tunable parameters as
function arguments with defaults. The autotuner overrides these parameters with
values from your config space.
```python
import tilelang
import tilelang.language as T
def matmul_configs(M, N, K):
# Example space — tailor to your target
tiles = [64, 128]
stages = [2, 3]
threads = [128, 256]
return [
dict(block_M=BM, block_N=BN, block_K=BK, num_stages=S, threads=TH)
for BM in tiles
for BN in tiles
for BK in [32, 64]
for S in stages
for TH in threads
]
@tilelang.autotune(configs=matmul_configs, warmup=25, rep=100, timeout=60)
@tilelang.jit(out_idx=[-1])
def matmul(M: int, N: int, K: int,
block_M: int = 128, block_N: int = 128, block_K: int = 32,
threads: int = 128, num_stages: int = 3,
dtype: str = 'float16', accum_dtype: str = 'float32'):
@T.prim_func
def kernel(A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_s = T.alloc_shared((block_M, block_K), dtype)
B_s = T.alloc_shared((block_K, block_N), dtype)
C_f = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_f)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, ko * block_K], A_s)
T.copy(B[ko * block_K, bx * block_N], B_s)
T.gemm(A_s, B_s, C_f)
T.copy(C_f, C[by * block_M, bx * block_N])
return kernel
# Usage
# Provide inputs via context (recommended for reproducibility across configs)
import torch
M = N = K = 1024
A = torch.randn(M, K, device='cuda', dtype=torch.float16)
B = torch.randn(K, N, device='cuda', dtype=torch.float16)
C = torch.empty(M, N, device='cuda', dtype=torch.float16)
from tilelang.autotuner import set_autotune_inputs
with set_autotune_inputs(A, B, C):
tuned_kernel = matmul(M, N, K) # compiles, tunes, returns best kernel
tuned_kernel(A, B, C) # run best kernel
```
Notes
- `configs` can be a list of dicts or a callable `(args...) -> list[dict]`. Each
dict’s keys must match the tunable function arguments (e.g., `block_M`).
- The decorator returns a callable that runs autotune once per argument tuple
and caches the resulting best kernel in‑process.
- For explicit input control during tuning, wrap the call with
`set_autotune_inputs(...)`. Otherwise, `supply_type` (below) is used.
## 2) Programmatic Autotune
Use the `AutoTuner` class to manage configs and arguments more explicitly.
```python
from tilelang.autotuner import AutoTuner
kernel_factory = matmul # the function above (already @tilelang.jit)
tuner = AutoTuner.from_kernel(kernel_factory(M, N, K), configs=matmul_configs(M, N, K))
tuner.set_profile_args(
warmup=25, rep=100, timeout=60,
supply_type=tilelang.TensorSupplyType.Auto, # or provide supply_prog/ref_prog
ref_prog=lambda A, B, C: torch.allclose(C, (A @ B).to(C.dtype), rtol=1e-2, atol=1e-2),
)
tuner.set_compile_args(
target='auto', # or 'cuda'/'hip'/'metal'
execution_backend='auto', # resolves per-target
out_idx=[-1], # which outputs to return if multiple
pass_configs={ # optional TVM passes/flags
# tilelang.PassConfigKey.EXAMPLE_KEY: value,
},
)
artifact = tuner.run() # compiles + runs + validates all configs
best_kernel = artifact.kernel # JITKernel
best_latency = artifact.latency
best_config = artifact.config
# Reuse best kernel
best_kernel(A, B, C)
```
### Example Gallery (in repo)
- examples/gdn/example_chunk_delta_h.py:101 — uses `@autotune` to sweep configs
- examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py:451 — uses `@tilelang.autotune`
- examples/quickstart.py:84 — profiles a tuned kernel with `get_profiler`
- examples/hadamard_transform/example_hadamard.py:152 — profiler with custom warmup
- examples/dynamic_shape/example_dynamic.py:94 — profiler for dynamic shapes
- examples/gemm/example_gemm_persistent.py:135 — compare persistent vs non‑persistent
Click any path to open the code and compare patterns.
## Input Tensor Supply
The tuner needs inputs to compile and benchmark kernels. Provide them in one of
three ways (priority order):
1) Context manager (fixed inputs across configs)
```python
with set_autotune_inputs(A, B, C):
tuned = matmul(M, N, K)
```
2) Custom supplier program
```python
def supply_prog(signature):
# signature holds KernelParam objects describing shapes/dtypes
# Return a list of torch tensors matching the kernel’s arguments
return [A, B, C]
tuner.set_profile_args(supply_prog=supply_prog)
```
3) Built‑in generators via `supply_type`
- `TensorSupplyType.Auto` (default): heuristic per dtype (uniform ints / fp ranges)
- `Integer`, `Uniform`, `Normal`, `Randn`, `Zero`, `One`
Important
- Built‑in generators require static shapes; if your PrimFunc uses symbolic
dimensions (T.dyn), supply concrete inputs via (1) or (2).
- Float8 dtypes require PyTorch 2.1+ for `torch.float8_*` support.
## Correctness Checking and Tolerances
Use one of the following validation methods:
- `ref_prog`: Provide a reference program that receives the same inputs and
checks results. You can return a boolean or raise on mismatch.
- `manual_check_prog`: A callable that inspects outputs and raises on mismatch.
- `skip_check=True`: Skip correctness checks (faster, use with caution).
Control numeric drift via:
- `rtol` and `atol` (defaults 1e‑2)
- `max_mismatched_ratio` (default 1%)
## Configuration Spaces and Best Practices
What to tune
- Tile sizes: `block_M`, `block_N`, `block_K`
- Software pipelining: `num_stages`
- Threads per block: `threads` (or (x, y) tuple)
- Optional: dtype variants, epilogues, small scheduling knobs
Tips
- Start from a working baseline. Tune a small, meaningful space first.
- Respect hardware limits (shared memory bytes, registers per thread/block,
max threads per block). Eliminate impossible configs up‑front.
- Keep block sizes multiples of vector widths and warp sizes when relevant.
- Use `set_autotune_inputs` to ensure each config is measured on identical data.
- Record your best configs and bake them as defaults when stable.
## Parallel Compilation/Benchmarking and Timeouts
The tuner compiles configurations in parallel using a thread pool and benchmarks
them with a per‑config timeout. On CUDA, each worker sets the current device to
avoid context issues.
Notes
- `timeout` uses POSIX signals; on non‑Unix systems, it may not take effect.
- Logs are written to `autotuner.log` in the working directory.
## Caching
The autotuner caches best artifacts both in‑memory (per process) and on disk under
`$TILELANG_CACHE_DIR/autotuner`. The cache key includes:
- TileLang version, function source, closure free‑vars
- Config list, compile args, profile args
Disk cache contents (per key)
- Best config and latency: `best_config.json`, `latency.json`
- Kernel sources and library: `device_kernel.cu`, `host_kernel.cu`, `kernel_lib.so` (or `kernel.cubin`/`executable.so` depending on backend)
- Function and params: `function.pkl`, `params.pkl`
Control via env vars (tilelang.env)
- `TILELANG_CACHE_DIR` (default `~/.tilelang/cache`)
- `TILELANG_TMP_DIR` (default `$TILELANG_CACHE_DIR/tmp`)
- Disable all kernel caches: `TILELANG_DISABLE_CACHE=1`
- Disable autotune disk cache only: `TILELANG_AUTO_TUNING_DISABLE_CACHE=1`
CPU worker control
- `TILELANG_AUTO_TUNING_CPU_UTILITIES` (fraction, default 0.9)
- `TILELANG_AUTO_TUNING_CPU_COUNTS` (int, `-1` auto)
- `TILELANG_AUTO_TUNING_MAX_CPU_COUNT` (int, `-1` unlimited)
Backend notes
- NVRTC backend persists `.cubin` and a Python launcher.
- Torch/DLPack backend may not save artifacts to disk; in this case, only
in‑memory caching applies and a warning is logged.
## Alternative: Manual Sweeps with par_compile
If you prefer manual control, use `JITImpl.par_compile` to compile a batch of
configs and drive your own benchmarking:
```python
@tilelang.jit
def factory(M, N, K, block_M=128, block_N=128, block_K=32):
@T.prim_func
def k(A: T.Tensor((M, K), 'float16'),
B: T.Tensor((K, N), 'float16'),
C: T.Tensor((M, N), 'float16')):
...
return k
impl = factory # JITImpl
cfgs = [
dict(block_M=64, block_N=128, block_K=32),
dict(block_M=128, block_N=128, block_K=64),
]
kernels = impl.par_compile(cfgs, num_workers=4)
# Now benchmark kernels[i](A, B, C) yourself
```
## Recording and Reusing Best Configs
The programmatic path returns an `AutotuneResult` that can be saved and later
reloaded. This is useful for CI, multi‑host workflows, or shipping tuned configs.
```python
artifact = tuner.run() # AutotuneResult
# Save to disk
from pathlib import Path
save_dir = Path('out/best/matmul_1024')
artifact.save_to_disk(save_dir, verbose=True)
# Reload later
from tilelang.autotuner.param import AutotuneResult, CompileArgs
restored = AutotuneResult.load_from_disk(save_dir, CompileArgs())
best = restored.kernel
best(A, B, C)
```
Notes
- DLPack/Torch execution backend may not persist compiled binaries; in that
case, re‑compilation is needed on load or use a different backend.
- The directory contains human‑readable JSONs (best config/latency) and sources.
## Advanced: Config Space Callables
Derive config spaces from problem sizes to keep searches targeted and legal:
```python
def matmul_configs(M, N, K):
large = min(M, N, K) >= 1024
tiles = [128] if large else [64, 128]
for BM in tiles:
for BN in tiles:
for BK in [32, 64]:
for S in [2, 3]:
for TH in [128, 256]:
yield dict(block_M=BM, block_N=BN, block_K=BK,
num_stages=S, threads=TH)
```
## Device and Backend Selection
Tune compile‑time options explicitly:
- `target='auto'|'cuda'|'hip'|'metal'` (normalized to a TVM Target)
- `execution_backend='auto'|'tvm_ffi'|'ctypes'|'cython'|'nvrtc'|'torch'`
- `pass_configs={...}` to toggle TileLang/TVM passes for experiments
On CUDA with multiple GPUs, the tuner sets the current device per worker thread
to avoid context mixups.
## Troubleshooting
- “No configurations to tune”: Ensure `configs` is a non‑empty list or callable.
- Timeouts: Increase `timeout`; ensure inputs fit device memory; verify that
your reference check isn’t the bottleneck.
- Dynamic shapes: Provide concrete inputs via `set_autotune_inputs` or a custom
`supply_prog`.
- Disk cache disabled: Check `TILELANG_AUTO_TUNING_DISABLE_CACHE` and backend.
# Control Flow
This guide covers the control‑flow primitives in TileLang and how they lower to
efficient GPU code. You will use these to structure loops, handle boundaries,
and express pipelined compute.
## Overview
- Conditionals: `if` / `elif` / `else`, ternary (`x if c else y`)
- Loops: `T.serial`, `T.unroll`, `T.Parallel`, `T.Pipelined`
- While loops: `while` with a TIR condition
- Flow control: Python `break` / `continue`
- Safety: automatic OOB guards via the LegalizeSafeMemoryAccess pass
The examples assume `import tilelang.language as T`.
## Conditionals
Standard Python `if`/`elif`/`else` is supported inside `@T.prim_func` kernels.
Conditions should be TIR expressions (e.g., `i < N`). Python plain booleans are
treated as compile‑time constants and will be folded.
```python
for i in T.serial(N):
if i < N: # TIR condition
C[i] = A[i] + B[i]
else:
pass
# Ternary
x = (A[i] if i < N else 0)
```
Short‑circuit boolean ops are supported. For multi‑dimensional bounds, use
`T.any_of` / `T.all_of` for clarity:
```python
if T.all_of(i < M, j < N):
C[i, j] = A[i, j] + B[i, j]
```
Boundary handling note
- The LegalizeSafeMemoryAccess pass automatically inserts guards when an access
may be out‑of‑bounds, and elides them when proven safe. You can often omit
explicit `if` checks for simple edge handling, but keep them when you need
custom logic or clarity.
## Loops
### Serial
`T.serial` creates a plain for‑loop. Common forms:
```python
for i in T.serial(N):
... # 0..N-1
for i in T.serial(0, N, 2):
... # 0, 2, 4, ...
```
### Unroll
`T.unroll` requests loop unrolling for small trip counts.
```python
for k in T.unroll(K_TILE):
acc += a[k] * b[k]
```
Advanced: TileLang forwards unroll hints to TIR; factor/explicit knobs are
available for expert tuning.
### Parallel (elementwise)
`T.Parallel(ext0, ext1, ...)` builds nested loops that map well to elementwise
operations. The body receives all indices in one `for` header:
```python
for i, j in T.Parallel(M, N):
C[i, j] = A[i, j] + B[i, j]
```
Optional: `coalesced_width=` can hint memory coalescing for the innermost loop.
### Pipelined (software pipelining)
`T.Pipelined(iters, num_stages=...)` overlaps producer/consumer stages (e.g.,
Global→Shared copies with compute). This is the backbone of GEMM/attention
pipelines.
```python
for ko in T.Pipelined(T.ceildiv(K, BK), num_stages=3):
T.copy(A[by * BM, ko * BK], A_s) # stage: copy A tile
T.copy(B[ko * BK, bx * BN], B_s) # stage: copy B tile
T.gemm(A_s, B_s, C_f) # stage: compute
```
### Persistent (advanced)
`T.Persistent(domain, wave_size, index, group_size=...)` exposes persistent
thread‑block style looping. It is an advanced construct that TileLang lowers in
later passes and is typically used by specialized templates.
## While Loops
`while` is supported when the condition is a TIR expression. Avoid infinite
loops; TileLang will error if it detects a constant‑true condition.
```python
i = 0
while i < N:
...
if done:
break
i += 1
```
## Break and Continue
Use Python `break`/`continue` to exit or skip within `T.serial`/`T.unroll`/
`T.Parallel`/`while` loops. Keep the body clean after a `break`/`continue` for
readability; the compiler will ignore the dead path.
## Putting It Together: Residual Tile Handling
Below is a typical edge pattern for a 2D kernel. With LegalizeSafeMemoryAccess,
the explicit guard can be omitted when you don’t need a custom edge path.
```python
for i, j in T.Parallel(BM, BN):
    gi = by * BM + i
    gj = bx * BN + j
    if T.all_of(gi < M, gj < N):  # optional in many cases
        C[gi, gj] = A[gi, gj] + B[gi, gj]
```
## Debugging Conditions
Use `T.print` to inspect values under predicates. For buffers, TileLang prints
from a single thread to avoid duplicate outputs.
```python
if i == 0:
T.print(C, msg='C tile:')
```
# Instructions
This page summarizes the core TileLang “instructions” available at the DSL
level, how they map to hardware concepts, and how to use them correctly.
## Quick Categories
- Data movement: `T.copy`, `T.c2d_im2col`, staging Global ↔ Shared ↔ Fragment
- Compute primitives: `T.gemm`/`T.gemm_sp`, elementwise math (`T.exp`, `T.max`),
reductions (`T.reduce_sum`, `T.cumsum`, warp reducers)
- Control helpers: `T.clear`/`T.fill`, `T.reshape`/`T.view`
- Diagnostics: `T.print`, `T.device_assert`
- Advanced: atomics, memory barriers, warp‑group ops
## Data Movement
Use `T.copy(src, dst, coalesced_width=None, disable_tma=False, eviction_policy=None)`
to move tiles between memory scopes. It accepts `tir.Buffer`, `BufferLoad`, or
`BufferRegion`; extents are inferred or broadcast when possible.
```python
# Global → Shared tiles (extents inferred from dst)
T.copy(A[by * BM, ko * BK], A_s)
T.copy(B[ko * BK, bx * BN], B_s)
# Fragment/Register → Global (store result)
T.copy(C_f, C[by * BM, bx * BN])
```
Semantics
- Extents are deduced from arguments; missing sides broadcast to the other’s rank.
- Access patterns are legalized and coalesced during lowering. Explicit
vectorization is not required in HL mode.
- Safety: the LegalizeSafeMemoryAccess pass inserts boundary guards when an
access may be out‑of‑bounds and drops them when proven safe.
Other helpers
- `T.c2d_im2col(img, col, ...)`: convenience for conv‑style transforms.
## Compute Primitives
GEMM and sparse GEMM
- `T.gemm(A_shared, B_shared, C_fragment)`: computes a tile GEMM using shared
inputs and a fragment accumulator; lowered to target‑specific tensor cores.
- `T.gemm_sp(...)`: 2:4 sparse tensor core variant (see examples and README).
Reductions and scans
- `T.reduce_sum`, `T.reduce_max`, `T.reduce_min`, `T.cumsum`, plus warp
reducers (`T.warp_reduce_sum`, etc.).
- Allocate and initialize accumulators via `T.alloc_fragment` + `T.clear` or
`T.fill`.
Elementwise math
- Most math ops mirror TVM TIR: `T.exp`, `T.log`, `T.max`, `T.min`, `T.rsqrt`,
`T.sigmoid`, etc. Compose freely inside loops.
Reshape/view (no copy)
- `T.reshape(buf, new_shape)` and `T.view(buf, shape=None, dtype=None)` create
new views that share storage, with shape/dtype checks enforced.
## Synchronization (HL usage)
In HL pipelines, you usually don’t need to write explicit barriers. Passes such
as PipelinePlanning/InjectSoftwarePipeline/InjectTmaBarrier orchestrate
producer/consumer ordering and thread synchronization behind the scenes.
If you need debugging or explicit checks:
- `T.device_assert(cond, msg='')` emits device‑side asserts on CUDA targets.
- `T.print(obj, msg='...')` prints scalars or buffers safely from one thread.
## Putting It Together: GEMM Tile
```python
@T.prim_func
def gemm(
A: T.Tensor((M, K), 'float16'),
B: T.Tensor((K, N), 'float16'),
C: T.Tensor((M, N), 'float16'),
):
with T.Kernel(T.ceildiv(N, BN), T.ceildiv(M, BM), threads=128) as (bx, by):
A_s = T.alloc_shared((BM, BK), 'float16')
B_s = T.alloc_shared((BK, BN), 'float16')
C_f = T.alloc_fragment((BM, BN), 'float32')
T.clear(C_f)
for ko in T.Pipelined(T.ceildiv(K, BK), num_stages=3):
T.copy(A[by * BM, ko * BK], A_s) # Global → Shared
T.copy(B[ko * BK, bx * BN], B_s)
T.gemm(A_s, B_s, C_f) # compute into fragment
T.copy(C_f, C[by * BM, bx * BN]) # store back
```
## Instruction Reference (Concise)
Below is a concise list of TileLang instructions grouped by category. For full
signatures, behaviors, constraints, and examples, refer to API Reference
(`autoapi/tilelang/index`).
Data movement
- `T.copy(src, dst, ...)`: Move tiles between Global/Shared/Fragment.
- `T.c2d_im2col(img, col, ...)`: 2D im2col transform for conv.
Memory allocation and descriptors
- `T.alloc_shared(shape, dtype, scope='shared.dyn')`: Allocate shared buffer.
- `T.alloc_fragment(shape, dtype, scope='local.fragment')`: Allocate fragment.
- `T.alloc_var(dtype, [init], scope='local.var')`: Scalar var buffer (1 elem).
- `T.alloc_barrier(arrive_count)`: Shared barrier buffer.
- `T.alloc_tmem(shape, dtype)`: Tensor memory (TMEM) buffer (Hopper+).
- `T.alloc_reducer(shape, dtype, op='sum', replication=None)`: Reducer buf.
- `T.alloc_descriptor(kind, dtype)`: Generic descriptor allocator.
- `T.alloc_wgmma_desc(dtype='uint64')`
- `T.alloc_tcgen05_smem_desc(dtype='uint64')`
- `T.alloc_tcgen05_instr_desc(dtype='uint32')`
- `T.empty(shape, dtype='float32')`: Declare function output tensors.
Compute primitives
- `T.gemm(A_s, B_s, C_f)`: Tile GEMM into fragment accumulator.
- `T.gemm_sp(...)`: Sparse (2:4) tensor core GEMM.
- Reductions: `T.reduce_sum/max/min/abssum/absmax`, bitwise `and/or/xor`.
- Scans: `T.cumsum`, finalize: `T.finalize_reducer`.
- Warp reducers: `T.warp_reduce_sum/max/min/bitand/bitor`.
- Elementwise math: TIR ops (`T.exp`, `T.log`, `T.max`, `T.min`, `T.rsqrt`, ...).
- Fast math: `T.__log/__log2/__log10/__exp/__exp2/__exp10/__sin/__cos/__tan`.
- IEEE math: `T.ieee_add/sub/mul/fmaf` (configurable rounding).
- Helpers: `T.clear(buf)`, `T.fill(buf, value)`.
- Views: `T.reshape(buf, shape)`, `T.view(buf, shape=None, dtype=None)`.
Diagnostics
- `T.print(obj, msg='')`: Print scalar/buffer from one thread.
- `T.device_assert(cond, msg='')`: Device-side assert (CUDA).
Logical helpers
- `T.any_of(a, b, ...)`, `T.all_of(a, b, ...)`: Multi-term predicates.
Annotation helpers
- `T.use_swizzle(panel_size=..., enable=True)`: Rasterization hint.
- `T.annotate_layout({...})`: Attach explicit layouts to buffers.
- `T.annotate_safe_value(var, ...)`: Safety/const hints.
- `T.annotate_l2_hit_ratio(buf, ratio)`: Cache behavior hint.
Atomics
- `T.atomic_add(dst, value, memory_order=None, return_prev=False, use_tma=False)`.
- `T.atomic_addx2(dst, value, return_prev=False)`; `T.atomic_addx4(...)`.
- `T.atomic_max(dst, value, memory_order=None, return_prev=False)`.
- `T.atomic_min(dst, value, memory_order=None, return_prev=False)`.
- `T.atomic_load(dst)`, `T.atomic_store(dst, value)`.
Custom intrinsics
- `T.dp4a(A, B, C)`: 4‑element dot‑product accumulate.
- `T.clamp(x, lo, hi)`: Clamp to [lo, hi].
- `T.loop_break()`: Break from current loop via intrinsic.
Barriers, TMA, warp‑group
- Barriers: `T.create_list_of_mbarrier(...)`, `T.get_mbarrier(i)`.
- Parity ops: `T.mbarrier_wait_parity(barrier, parity)`, `T.mbarrier_arrive(barrier)`.
- Expect tx: `T.mbarrier_expect_tx(...)`; sugar: `T.barrier_wait(id, parity=None)`.
- TMA: `T.create_tma_descriptor(...)`, `T.tma_load(...)`,
`T.tma_store_arrive(...)`, `T.tma_store_wait(...)`.
- Proxy/fences: `T.fence_proxy_async(...)`, `T.warpgroup_fence_operand(...)`.
- Warp‑group: `T.warpgroup_arrive()`, `T.warpgroup_commit_batch()`,
`T.warpgroup_wait(num_mma)`, `T.wait_wgmma(id)`.
Lane/warp index
- `T.get_lane_idx(warp_size=None)`: Lane id in warp.
- `T.get_warp_idx_sync(warp_size=None)`: Canonical warp id (sync).
- `T.get_warp_idx(warp_size=None)`: Canonical warp id (no sync).
- `T.get_warp_group_idx(warp_size=None, warps_per_group=None)`: Group id.
Register control
- `T.set_max_nreg(reg_count, is_inc)`, `T.inc_max_nreg(n)`, `T.dec_max_nreg(n)`.
- `T.annotate_producer_reg_dealloc(n=24)`, `T.annotate_consumer_reg_alloc(n=240)`.
- `T.no_set_max_nreg()`, `T.disable_warp_group_reg_alloc()`.
## Notes on Dtypes
Dtypes accept three equivalent forms:
- String: `'float32'`
- TileLang dtype: `T.float32`
- Framework dtype: `torch.float32`
All are normalized internally. See Type System for details.
# Language Basics
This page introduces the core TileLang (tile‑lang) DSL that you’ll use to write
high‑performance kernels. It focuses on how to define a kernel, express
iteration, move data across memory scopes, and run it with JIT.
The examples use the conventional aliases:
```python
import tilelang
import tilelang.language as T
from tilelang import jit
```
## 1. Defining a Kernel with `@T.prim_func`
TileLang kernels are TIR (TVM IR) functions produced by the `@T.prim_func`
decorator. Arguments are annotated with shapes and dtypes via `T.Tensor` or
`T.Buffer`.
Note on dtypes
- You can pass dtypes as a string (e.g., 'float32'), a TileLang dtype (e.g., `T.float32`),
or a framework dtype (e.g., `torch.float32`). TileLang normalizes all of these.
See Type System for details.
```python
@T.prim_func
def add_kernel(
A: T.Tensor((N,), dtype), # dtype could be 'float32' | T.float32 | torch.float32
B: T.Tensor((N,), dtype),
C: T.Tensor((N,), dtype),
):
... # kernel body
```
- Shapes may be concrete integers or symbolic. For symbolic, you can pass
Python ints through the outer `@jit` wrapper (shown below), or annotate with
`T.dyn` when you want a named symbolic dimension.
```python
# Named symbolic dimension (optional)
K = T.dyn['K']
@T.prim_func
def uses_dyn(A: T.Tensor((K,), 'float32')):
...
```
### Dynamic symbolic dimensions: two ways
TileLang supports two complementary ways to introduce symbolic (dynamic) dims:
- Type-level annotations via `T.dyn[...]` (recommended for function signatures)
- Use in `T.Tensor((T.dyn['K'], ...), dtype)` or bind once then reuse (as above).
- Inside the kernel body, prefer reading from the buffer’s shape, e.g. `M = A.shape[0]`.
- Term-level variables via `T.dynamic(name, dtype)`
- Creates a TIR `tir.Var` you can use directly in expressions/loops.
- Handy when you need to reference the dimension symbol in the body.
```python
# 1) Annotation-only symbol; read the bound size via shape
K = T.dyn['K'] # dtype defaults to int32
@T.prim_func
def foo(A: T.Tensor((K,), 'float32')):
N = A.shape[0]
for i in T.serial(N):
...
# 2) Explicit Var symbol usable in the body
K = T.dynamic('K', 'int32') # or T.dynamic('K') defaults to int32
@T.prim_func
def bar(A: T.Tensor((K,), 'float32')):
for i in T.serial(K):
...
```
Notes
- `T.symbolic(name, dtype)` is a deprecated alias of `T.dynamic`; prefer `T.dynamic`.
- Under `@jit`, concrete sizes come from the actual tensor arguments at the first call.
- Symbols in annotations do not need to be separate kernel arguments; TileLang binds them from argument shapes.
## 2. Launching Work with `T.Kernel`
`with T.Kernel(...)` declares a launch context and creates block/thread
bindings. For GPU backends, specify a grid and threads per block.
```python
with T.Kernel(grid_x, grid_y, threads=128) as (bx, by):
... # bx/by are blockIdx.x/y
```
You rarely need raw thread indices; most kernels use structured loops
(`T.serial`, `T.unroll`, `T.Parallel`, `T.Pipelined`) inside a `T.Kernel`.
## 3. Loops and Control Flow
Core loop constructs map to familiar hardware patterns:
- `T.serial(start, stop[, step])`: plain for‑loop
- `T.unroll(start, stop[, step])`: unrolled loop
- `T.Parallel(ext0, ext1, ...)`: nested parallel loops (elementwise‑friendly)
- `T.Pipelined(iters, num_stages=N)`: software pipelining for producer/consumer
```python
for i in T.serial(N):
...
for i, j in T.Parallel(M, N):
C[i, j] = A[i, j] + B[i, j]
for k in T.Pipelined(T.ceildiv(K, BK), num_stages=3):
# overlap copy/compute across stages
...
```
Conditionals use standard Python `if`/`else`. Guard edges with predicates when
tile sizes do not divide problem sizes evenly.
## 4. Memory Scopes and Allocation
TileLang exposes key software‑managed scopes:
- Global: device memory (default for `T.Tensor` arguments)
- Shared: on‑chip, block‑visible (`T.alloc_shared(shape, dtype)`)
- Fragment and scalars: per‑thread register fragments and scalar variables,
  addressed through a shared‑style (tile‑level) view
  (`T.alloc_fragment`, `T.alloc_var`)
```python
A_shared = T.alloc_shared((BM, BK), 'float16')
B_shared = T.alloc_shared((BK, BN), 'float16')
C_local = T.alloc_fragment((BM, BN), 'float32')
T.clear(C_local) # zero accumulators
```
## 5. Moving Data: `T.copy`
Use `T.copy(src, dst)` to move tiles between scopes. It accepts buffers,
buffer regions, or buffer loads; extents are inferred or can be broadcast.
```python
# Global -> Shared (tile copy), extents inferred from dst
T.copy(A[by * BM, ko * BK], A_shared)
T.copy(B[ko * BK, bx * BN], B_shared)
# Fragment -> Global (store back)
T.copy(C_local, C[by * BM, bx * BN])
```
`T.copy` performs coalescing and scope‑specific lowering during compilation.
## 6. A Minimal End‑to‑End Example (Vector Add)
```python
import tilelang
import tilelang.language as T
from tilelang import jit
@jit # infers target from tensors at first call
def add(N: int, block: int = 256, dtype: str = 'float32'):
@T.prim_func
def add_kernel(
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
C: T.Tensor((N,), dtype),
):
with T.Kernel(T.ceildiv(N, block), threads=block) as bx:
for i in T.Parallel(block):
gi = bx * block + i
# Optional — LegalizeSafeMemoryAccess inserts a guard when an access may be OOB
C[gi] = A[gi] + B[gi]
return add_kernel
# Host side (PyTorch shown; NumPy/DLPack also supported)
import torch
N = 1 << 20
A = torch.randn(N, device='cuda', dtype=torch.float32)
B = torch.randn(N, device='cuda', dtype=torch.float32)
C = torch.empty(N, device='cuda', dtype=torch.float32)
kernel = add(N)
kernel(A, B, C) # runs on GPU
torch.testing.assert_close(C, A + B)
```
Notes
- The `@jit` wrapper returns a callable kernel after the first compilation.
- You can pass compile‑time tunables (tile sizes, dtypes) through the outer
Python function and bake them into the generated TIR.
## 7. Tiled GEMM Skeleton
Below is a minimal pattern for a tiled GEMM using shared memory staging and a
fragment accumulator. It mirrors the quickstart style found in the repository.
```python
@T.prim_func
def gemm(
A: T.Tensor((M, K), 'float16'),
B: T.Tensor((K, N), 'float16'),
C: T.Tensor((M, N), 'float16'),
):
with T.Kernel(T.ceildiv(N, BN), T.ceildiv(M, BM), threads=128) as (bx, by):
A_s = T.alloc_shared((BM, BK), 'float16')
B_s = T.alloc_shared((BK, BN), 'float16')
C_f = T.alloc_fragment((BM, BN), 'float32')
T.clear(C_f)
for ko in T.Pipelined(T.ceildiv(K, BK), num_stages=3):
T.copy(A[by * BM, ko * BK], A_s)
T.copy(B[ko * BK, bx * BN], B_s)
T.gemm(A_s, B_s, C_f) # lowered to tensor‑core/ISA specific kernels
T.copy(C_f, C[by * BM, bx * BN])
```
## 8. Debugging and Printing
Use `T.print` inside a kernel for quick introspection. TileLang emits printing
from a single thread for shared/fragment scopes to avoid floods.
```python
T.print(C_f, msg='accumulator:')
T.print(A_s, msg='A tile:')
T.print(C[0], msg='C[0] = ')
```
## 9. Where to Go Next
- Control flow details: see Programming Guides → Control Flow
- Memory topics: see Programming Guides → (removed cache/layout); basics are covered inline
- Autotuning tile sizes and mappings: Programming Guides → Autotuning
- Operator examples (GEMM, GEMV, attention): see Deep Learning Operators
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment