Commit 7de9ffb7 authored by Lei Wang, committed by GitHub

instruction update (#10)

parent a1a3e2e6
@@ -8,6 +8,9 @@ Tile Language (**tile-lang**) is a concise domain-specific language designed to
<img src=./images/MatmulExample.png />
## Latest News
- 01/20/2025 ✨: We are excited to announce that tile-lang, a DSL for high-performance AI workloads, is now open source and available to the public!
## Tested Devices
Although tile-lang aims to be portable across a range of devices, it has been specifically tested and validated on the following:
- **NVIDIA GPUs**: H100 (with Auto TMA/WGMMA support), A100, V100, RTX 4090, RTX 3090, and RTX A6000
- **AMD GPUs**: MI250 (with Auto MatrixCore support) and MI300X (with Async Copy support)
@@ -68,80 +71,7 @@ We currently provide three ways to install **tile-lang** from source:
In this section, you’ll learn how to write and execute a straightforward GEMM (matrix multiplication) kernel using tile-lang, followed by techniques for layout optimizations, pipelining, and L2-cache–friendly swizzling.
### Basic GEMM Example
Below is a minimal example showing how to define and run a matrix multiplication kernel in tile-lang. This serves as a gentle introduction to the language’s key concepts.
```python
import tilelang
from tilelang import Profiler
import tilelang.language as T


def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((K, N), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        # Define a GPU kernel launch configuration:
        #   - Grid dimension: (ceildiv(N, block_N), ceildiv(M, block_M))
        #   - Threads per block: 128
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            # Allocate on-chip memory (shared and fragment buffers)
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Initialize the accumulation buffer
            T.clear(C_local)

            # Primary compute loop, with pipelining across chunks of size block_K
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Copy a tile of A into shared memory
                T.copy(A[by * block_M, k * block_K], A_shared)
                # Copy a tile of B into shared memory
                T.copy(B[k * block_K, bx * block_N], B_shared)
                # Perform a tile-level GEMM on the shared buffers into C_local
                T.gemm(A_shared, B_shared, C_local)

            # Write the accumulated result from local memory back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


# 1. Define the kernel (matmul) and compile/lower it into an executable module
func = matmul(1024, 1024, 1024, 128, 128, 32)
rt_mod, params = tilelang.lower(func)

# 2. Create a Profiler object for running performance and correctness tests
profiler = Profiler(rt_mod, params, result_idx=[2])

# 3. Test the kernel in Python with PyTorch data
import torch

# Create random input tensors on the GPU
a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

# Run the kernel through the Profiler
c = profiler(a, b)

# Reference multiplication using PyTorch
ref_c = a @ b

# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")

# 4. Retrieve and inspect the generated CUDA source (optional)
cuda_source = rt_mod.imported_modules[0].get_source()
print("Generated CUDA kernel:\n", cuda_source)
```
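The Profiler object is invoked like a regular function in step 3 above, so the same call can also be wrapped in a plain CUDA-event timing loop if you want a rough latency and throughput number without relying on any additional Profiler API. The sketch below simply continues the script above (reusing `profiler`, `a`, and `b`); the warm-up and iteration counts are arbitrary:

```python
import torch

# Warm up once so one-time compilation/caching does not skew the measurement.
profiler(a, b)
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

iters = 100
start.record()
for _ in range(iters):
    profiler(a, b)
end.record()
torch.cuda.synchronize()

ms = start.elapsed_time(end) / iters
# A GEMM performs 2 * M * N * K floating-point operations (M = N = K = 1024 here).
tflops = 2 * 1024**3 / (ms * 1e-3) / 1e12
print(f"Average latency: {ms:.3f} ms, ~{tflops:.2f} TFLOPS")
```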
### GEMM Example with Annotations (Layout, L2 Cache Swizzling, and Pipelining, etc.)
Below is an example that demonstrates more advanced features: layout annotation, parallelized copy, and swizzling for improved L2 cache locality. This snippet shows how to adapt your kernel to maximize performance on modern hardware; a complete sketch assembling these fragments follows the diff excerpts below.
@@ -167,13 +97,14 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
```python
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Apply layout optimizations or define your own layout (Optional)
            # If not specified, we will deduce the layout automatically
            T.annotate_layout({
                A_shared: make_swizzle_layout(A_shared),
                B_shared: make_swizzle_layout(B_shared),
            })

            # Enable rasterization for better L2 cache locality (Optional)
            T.use_swizzle(panel_size=10, enable=True)

            # Clear local accumulation
```
@@ -181,6 +112,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
```python
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Copy tile of A
                # This is a sugar syntax for parallelized copy
                T.copy(A[by * block_M, k * block_K], A_shared)

                # Demonstrate parallelized copy from global to shared for B
```
@@ -188,6 +120,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
```python
                    B_shared[ko, j] = B[k * block_K + ko, bx * block_N + j]

                # Perform a tile-level GEMM on the shared buffers
                # Currently this dispatches to cute/hip backends on NVIDIA/AMD GPUs
                T.gemm(A_shared, B_shared, C_local)

            # Copy result back to global memory
```
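Since the diff above shows only fragments of the annotated kernel, here is a minimal sketch of how the pieces fit together. Treat it as an illustration rather than the exact source: the import path for `make_swizzle_layout` and the `T.Parallel` loop header around the B-tile copy are assumptions inferred from the fragments and comments above.

```python
import tilelang.language as T
# Assumed import path for the swizzle-layout helper used above.
from tilelang.intrinsics import make_swizzle_layout


def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((K, N), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Apply layout optimizations or define your own layout (Optional)
            T.annotate_layout({
                A_shared: make_swizzle_layout(A_shared),
                B_shared: make_swizzle_layout(B_shared),
            })

            # Enable rasterization for better L2 cache locality (Optional)
            T.use_swizzle(panel_size=10, enable=True)

            # Clear local accumulation
            T.clear(C_local)

            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Sugar syntax for a parallelized copy of an A tile
                T.copy(A[by * block_M, k * block_K], A_shared)

                # Explicit parallelized copy of a B tile; the T.Parallel loop
                # header is an assumption matching the loop body shown above
                for ko, j in T.Parallel(block_K, block_N):
                    B_shared[ko, j] = B[k * block_K + ko, bx * block_N + j]

                # Tile-level GEMM on the shared buffers
                T.gemm(A_shared, B_shared, C_local)

            # Copy result back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```

Lowering, profiling, and validation then proceed exactly as in the basic example above.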