Unverified Commit 058a670b authored by Lei Wang, committed by GitHub

[Doc] Optimize the quickstart guide for clarity and not just for CUDA (#858)

* Refactor matmul example to include ReLU activation and update batch size in benchmark script

* lint fix
parent bd168654
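
For reference, here is a condensed sketch of the quickstart example as it reads after this change, assembled from the hunks below. Shapes, block sizes, and API calls follow the diff; the accumulator clear (`T.clear(C_local)`) is assumed, since it sits in a collapsed region of the diff.

```python
import torch
import tilelang
import tilelang.language as T


@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def matmul_relu_kernel(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Clear the accumulator (assumed; this line falls in a collapsed part of the diff)
            T.clear(C_local)

            # Pipelined loop over K tiles: load A and B tiles, then accumulate
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)

            # Fused ReLU epilogue (the main addition in this commit)
            for i, j in T.Parallel(block_M, block_N):
                C_local[i, j] = T.max(C_local[i, j], 0)

            # Write the result tile back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return matmul_relu_kernel


kernel = matmul(1024, 1024, 1024, 128, 128, 32)
a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
c = torch.empty(1024, 1024, device="cuda", dtype=torch.float16)
kernel(a, b, c)
torch.testing.assert_close(c, torch.relu(a @ b), rtol=1e-2, atol=1e-2)
```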
@@ -123,35 +123,24 @@ Below is an example that demonstrates more advanced features: layout annotation,
```python
import tilelang
import tilelang.language as T
# `make_mma_swizzle_layout` is a Python-defined layout function
# designed specifically for MMA operations. It keeps shared-memory
# layouts consistent with the NVIDIA CUTLASS library to avoid bank
# conflicts and maximize performance.
from tilelang.intrinsics import (
make_mma_swizzle_layout as make_swizzle_layout,)
# Add the @tilelang.jit decorator if you want to return a torch function
# @tilelang.jit
# @tilelang.jit(target="cuda")
# The target can currently be "cuda", "hip", or "cpu";
# if not specified, it is inferred from the input tensors at compile time.
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
def matmul_relu_kernel(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Apply layout optimizations or define your own layout (Optional)
# If not specified, we will deduce the layout automatically
# T.annotate_layout({
# A_shared: make_swizzle_layout(A_shared),
# B_shared: make_swizzle_layout(B_shared),
# })
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Enable rasterization for better L2 cache locality (Optional)
# T.use_swizzle(panel_size=10, enable=True)
@@ -164,53 +153,58 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
# This is syntactic sugar for a parallelized copy
T.copy(A[by * block_M, ko * block_K], A_shared)
# Demonstrate parallelized copy from global to shared for B
for k, j in T.Parallel(block_K, block_N):
B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]
# Copy tile of B
T.copy(B[ko * block_K, bx * block_N], B_shared)
# Perform a tile-level GEMM on the shared buffers
# Currently we dispatch to cute/hip backends on NVIDIA/AMD GPUs
T.gemm(A_shared, B_shared, C_local)
# relu
for i, j in T.Parallel(block_M, block_N):
C_local[i, j] = T.max(C_local[i, j], 0)
# Copy result back to global memory
T.copy(C_local, C[by * block_M, bx * block_N])
return main
return matmul_relu_kernel
# 1. Define the kernel (matmul) with the desired dimensions
func = matmul(1024, 1024, 1024, 128, 128, 32)
M = 1024 # M = T.symbolic("m") if you want to use dynamic shape
N = 1024
K = 1024
block_M = 128
block_N = 128
block_K = 32
# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")
# 1. Define the kernel (matmul) and compile/lower it into an executable module
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)
# 3. Test the kernel in Python with PyTorch data
import torch
# Create random input tensors on the GPU
a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.empty(M, N, device="cuda", dtype=torch.float16)
# Run the kernel through the JIT-compiled function
c = jit_kernel(a, b)
# Run the compiled kernel
matmul_relu_kernel(a, b, c)
print(c)
# Reference computation using PyTorch
ref_c = a @ b
ref_c = torch.relu(a @ b)
# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 4. Retrieve and inspect the generated CUDA source (optional)
cuda_source = jit_kernel.get_kernel_source()
print("Generated CUDA kernel:\n", cuda_source)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)
# 5. Profile latency with the profiler
profiler = jit_kernel.get_profiler()
# 5. Profile latency with the kernel's profiler
profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
......
import tilelang
import tilelang.language as T
# `make_mma_swizzle_layout` is a Python-defined layout function
# designed specifically for MMA operations. It keeps shared-memory
# layouts consistent with the NVIDIA CUTLASS library to avoid bank
# conflicts and maximize performance.
from tilelang.intrinsics import (
make_mma_swizzle_layout as make_swizzle_layout,) # noqa: F401
# Add the @tilelang.jit decorator if you want to return a torch function
# @tilelang.jit
# @tilelang.jit(target="cuda")
# The target can currently be "cuda", "hip", or "cpu";
# if not specified, it is inferred from the input tensors at compile time.
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
def matmul_relu_kernel(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
@@ -25,13 +20,6 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Apply layout optimizations or define your own layout (Optional)
# If not specified, we will deduce the layout automatically
# T.annotate_layout({
# A_shared: make_swizzle_layout(A_shared),
# B_shared: make_swizzle_layout(B_shared),
# })
# Enable rasterization for better L2 cache locality (Optional)
# T.use_swizzle(panel_size=10, enable=True)
@@ -41,8 +29,6 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
# Copy tile of A
# This is syntactic sugar for a parallelized copy
# for i, k in T.Parallel(M, block_K):
# A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
T.copy(A[by * block_M, ko * block_K], A_shared)
# Copy tile of B
@@ -52,10 +38,14 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
# Currently we dispatch to cute/hip backends on NVIDIA/AMD GPUs
T.gemm(A_shared, B_shared, C_local)
# relu
for i, j in T.Parallel(block_M, block_N):
C_local[i, j] = T.max(C_local[i, j], 0)
# Copy result back to global memory
T.copy(C_local, C[by * block_M, bx * block_N])
return main
return matmul_relu_kernel
M = 1024 # M = T.symbolic("m") if you want to use dynamic shape
@@ -66,13 +56,7 @@ block_N = 128
block_K = 32
# 1. Define the kernel (matmul) and compile/lower it into an executable module
func = matmul(M, N, K, block_M, block_N, block_K)
# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)
# 3. Test the kernel in Python with PyTorch data
import torch
@@ -80,13 +64,14 @@ import torch
# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.empty(M, N, device="cuda", dtype=torch.float16)
# Run the compiled kernel
c = jit_kernel(a, b)
matmul_relu_kernel(a, b, c)
print(c)
# Reference computation using PyTorch
ref_c = a @ b
ref_c = torch.relu(a @ b)
# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
@@ -97,7 +82,7 @@ print("Kernel output matches PyTorch reference.")
# print("Generated CUDA kernel:\n", cuda_source)
# 5.Profile latency with kernel
profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
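# Optional follow-up (not part of this diff): report the measurement,
# assuming do_bench() returns the average latency in milliseconds.
print(f"Latency: {latency} ms")
print(f"Throughput: {2 * M * N * K / (latency * 1e-3) / 1e12:.3f} TFLOPS")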
......