{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "5e0deecc", "metadata": {}, "outputs": [], "source": [ "import sys\n", "from pathlib import Path\n", "\n", "sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n", "import tilelang\n", "import torch\n", "import tilelang.language as T" ] }, { "cell_type": "markdown", "id": "1ca2c56d", "metadata": {}, "source": [ "# Tilelang Lazy JIT" ] }, { "cell_type": "markdown", "id": "156e7370", "metadata": {}, "source": [ "## Tensor Annotation" ] }, { "cell_type": "markdown", "id": "b070c109", "metadata": {}, "source": [ "Tilelang Lazy JIT combines the jit generation and invocation logic.\n", "\n", "The function signature syntax is similar to triton but with significant enhancements, most notably allowing Tensor annotations:\n", "\n", "For example, the code below annotates a 2D Tensor with T.Tensor[[int, int], T.float16]\n", "1. Each dimension is a compile-time constant; changing it triggers recompilation\n", "2. Its dtype must be T.float16\n", "\n", "DType can also be Any or None in addition to a concrete type\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "60bf8954", "metadata": {}, "outputs": [], "source": [ "@tilelang.lazy_jit\n", "def gemm(\n", " A: T.Tensor[[int, int], T.float16],\n", " B: T.Tensor[[int, int], T.float16],\n", " out_dtype: T.dtype = T.float32,\n", " block_M: int = 128,\n", " block_N: int = 128,\n", " block_K: int = 32,\n", "):\n", " M, K = A.shape\n", " K, N = B.shape\n", " C = T.empty((M, N), out_dtype)\n", " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", " A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n", " B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n", " C_local = T.alloc_fragment((block_M, block_N), out_dtype)\n", " T.clear(C_local)\n", " for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n", " T.copy(A[bx * block_M, k * block_K], A_shared)\n", " T.copy(B[k * block_K, by * block_N], B_shared)\n", " T.gemm(A_shared, B_shared, C_local)\n", " T.copy(C_local, C[bx * block_M, by * block_N])\n", " return C" ] }, { "cell_type": "markdown", "id": "28f868fe", "metadata": {}, "source": [ "Call the Tensor directly as an argument to trigger the full jit compile-and-run workflow:" ] }, { "cell_type": "code", "execution_count": 3, "id": "ee13394a", "metadata": {}, "outputs": [], "source": [ "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "C = gemm(A, B)\n", "\n", "# check output is correct\n", "C_ref = (A @ B).float()\n", "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" ] }, { "cell_type": "markdown", "id": "c6705091", "metadata": {}, "source": [ "Change the call-site arguments; if the compiler parameters differ, it recompiles:" ] }, { "cell_type": "code", "execution_count": 4, "id": "d8aab5b7", "metadata": {}, "outputs": [], "source": [ "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", "B = torch.randn(512, 1024, dtype=torch.float16, device=\"cuda\")\n", "C = gemm(A, B, block_M=64, block_N=64)" ] }, { "cell_type": "markdown", "id": "ce6b7391", "metadata": {}, "source": [ "You can also manually call compile helpers to build a kernel\n", "\n", "1. `ker.compile` compiles the kernel\n", "2. `ker.get_tir` retrieves the tir\n", "3. 
`ker.par_compile` compiles in parallel" ] }, { "cell_type": "code", "execution_count": 5, "id": "f3cf3a2d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2025-11-25 17:29:46 [TileLang:tilelang.cache.kernel_cache:WARNING]: Found kernel in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching.\n" ] } ], "source": [ "kernel = gemm.compile(A, B, block_M=64, block_N=64)\n", "C = kernel(A, B)" ] }, { "cell_type": "markdown", "id": "921761b5", "metadata": {}, "source": [ "## More Tensor Annotation" ] }, { "cell_type": "markdown", "id": "4539e54e", "metadata": {}, "source": [ "### Separate the implementation with macros" ] }, { "cell_type": "markdown", "id": "ad96ba65", "metadata": {}, "source": [ "Next we'll implement a simple gemm in several ways. For convenience, first write a macro that captures the main gemm logic:" ] }, { "cell_type": "code", "execution_count": 6, "id": "171d4fe6", "metadata": {}, "outputs": [], "source": [ "@T.macro\n", "def gemm_impl(A, B, C, M, N, K, block_M, block_N, block_K):\n", "    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", "        A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n", "        B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n", "        C_local = T.alloc_fragment((block_M, block_N), C.dtype)\n", "        T.clear(C_local)\n", "        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n", "            T.copy(A[bx * block_M, k * block_K], A_shared)\n", "            T.copy(B[k * block_K, by * block_N], B_shared)\n", "            T.gemm(A_shared, B_shared, C_local)\n", "        T.copy(C_local, C[bx * block_M, by * block_N])" ] }, { "cell_type": "markdown", "id": "446a1acd", "metadata": {}, "source": [ "### Mark dynamic shapes with T.dyn\n", "\n", "When some dimensions are dynamic, mark them with T.dyn. 
T.dyn can take a string argument to name the variable." ] }, { "cell_type": "code", "execution_count": null, "id": "6a38aa95", "metadata": {}, "outputs": [], "source": [ "@tilelang.lazy_jit\n", "def gemm_dyn_K(\n", "    A: T.Tensor[[int, T.dyn[\"K\"]], T.float16],  # noqa: F821\n", "    B: T.Tensor[[T.dyn[\"K\"], int], T.float16],  # noqa: F821\n", "):\n", "    M, K = A.shape\n", "    K, N = B.shape\n", "    C = T.empty((M, N), T.float32)\n", "    gemm_impl(A, B, C, M, N, K, 128, 128, 32)\n", "    return C" ] }, { "cell_type": "markdown", "id": "c60fd346", "metadata": {}, "source": [ "Inspect the lazy_jit function signature: parameters with a `$` suffix are compile-time constants that may vary, and those with `$dyn` are runtime variables." ] }, { "cell_type": "code", "execution_count": 8, "id": "c6992eb4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'A': TensorAnnot(shape=[A_shape_0$, K$dyn], strides=None, dtype=dtype('float16')),\n", " 'B': TensorAnnot(shape=[K$dyn, B_shape_1$], strides=None, dtype=dtype('float16'))}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gemm_dyn_K.func.annot" ] }, { "cell_type": "code", "execution_count": 9, "id": "fe6cfdc8", "metadata": {}, "outputs": [], "source": [ "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "C = gemm_dyn_K(A, B)\n", "C_ref = (A @ B).float()\n", "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" ] }, { "cell_type": "markdown", "id": "2ee97bf7", "metadata": {}, "source": [ "### Use T.StridedTensor to annotate tensors with strides\n", "\n", "Annotation format: T.StridedTensor[Shape, Stride, DType]. Each Shape or Stride entry can be:\n", "* int: compile-time constant\n", "* T.dyn: runtime value\n", "\n", "DType can be None or Any." ] }, { "cell_type": "code", "execution_count": 10, "id": "9dde1dae", "metadata": {}, "outputs": [], "source": [ "from typing import Any\n", "\n", "\n", "@tilelang.lazy_jit\n", "def as_contiguous(A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]):\n", "    M, N = A.shape\n", "    B = T.empty((M, N), A.dtype)\n", "    block_M = 128\n", "    block_N = 128\n", "    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", "        T.copy(\n", "            A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", "            B[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", "        )\n", "    return B" ] }, { "cell_type": "code", "execution_count": 11, "id": "dec2c0a7", "metadata": {}, "outputs": [], "source": [ "A = torch.randn(1024, 1024, device=\"cuda\")\n", "B = as_contiguous(A[::2, ::2])\n", "B_ref = A[::2, ::2].contiguous()\n", "torch.testing.assert_close(B, B_ref)" ] }, { "cell_type": "markdown", "id": "f5fb20d6", "metadata": {}, "source": [ "## More Annotation" ] }, { "cell_type": "markdown", "id": "890df0a2", "metadata": {}, "source": [ "### Annotate tensors with T.ptr\n", "\n", "lazy_jit lets you declare a handle with T.ptr, but you must define its shape inside the function via T.match_buffer." ] }, { "cell_type": "code", "execution_count": 12, "id": "0fc17af6", "metadata": {}, "outputs": [], "source": [ "@tilelang.lazy_jit\n", "def gemm_ptr(\n", "    A: T.ptr,\n", "    B: T.ptr,\n", "    M: int,\n", "    N: int,\n", "    K: int,\n", "):\n", "    A = T.match_buffer(A, (M, K), T.float16)\n", "    B = T.match_buffer(B, (K, N), T.float16)\n", "    C = T.empty((M, N), T.float32)\n", "    gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n", "    return C" ] }, {
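"cell_type": "markdown", "id": "a3f1c9d2", "metadata": {}, "source": [ "The compile helpers listed earlier (`ker.compile`, `ker.get_tir`, `ker.par_compile`) also apply to lazy_jit functions like this one. The next cell is a minimal sketch, not executed in this notebook, and it assumes `get_tir` accepts the same arguments as `compile`; it retrieves the generated TIR without launching the kernel." ] }, { "cell_type": "code", "execution_count": null, "id": "b4e2d8c1", "metadata": {}, "outputs": [], "source": [ "# Sketch only (not run): assumes get_tir takes the same arguments as compile.\n", "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "tir_mod = gemm_ptr.get_tir(A, B, 1024, 256, 512)\n", "print(tir_mod)" ] }, {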
"cell_type": "code", "execution_count": 13, "id": "8e52a554", "metadata": {}, "outputs": [], "source": [ "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "C = gemm_ptr(A, B, 1024, 256, 512)\n", "C_ref = (A @ B).float()\n", "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" ] }, { "cell_type": "markdown", "id": "6b19ef90", "metadata": {}, "source": [ "### Use T.int32 to annotate runtime variables\n", "\n", "lazy_jit lets you define runtime variables with T.int32 or other types, enabling a fully dynamic gemm similar to triton" ] }, { "cell_type": "code", "execution_count": 14, "id": "c1e7598a", "metadata": {}, "outputs": [], "source": [ "@tilelang.lazy_jit\n", "def gemm_ptr_dyn(\n", " A: T.ptr,\n", " B: T.ptr,\n", " M: T.int32,\n", " N: T.int32,\n", " K: T.int32,\n", "):\n", " A = T.match_buffer(A, (M, K), T.float16, strides=(K, 1))\n", " B = T.match_buffer(B, (K, N), T.float16, strides=(N, 1))\n", " C = T.empty((M, N), T.float32)\n", " gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n", " return C" ] }, { "cell_type": "code", "execution_count": 15, "id": "9e9a4c88", "metadata": {}, "outputs": [], "source": [ "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n", "C_ref = (A @ B).float()\n", "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" ] }, { "cell_type": "markdown", "id": "39166cb4", "metadata": {}, "source": [ "## Compilation and parallel compilation" ] }, { "cell_type": "markdown", "id": "8c6fbe08", "metadata": {}, "source": [ "lazyjit and the original jit both support parallel compilation\n", "\n", "To avoid wasting memory with torch.tensor placeholders, use T.Tensor to create placeholders" ] }, { "cell_type": "code", "execution_count": 16, "id": "7222e57b", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c6d7f05cdfff412e9a527332438f7aa2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Elaborating: 0%| | 0/8 [00:00,\n", " ,\n", " ,\n", " ,\n", " ,\n", " ,\n", " ,\n", " ]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from itertools import product\n", "\n", "\n", "def get_configs():\n", " return [\n", " {\n", " \"A\": T.Tensor((1024, 1024), T.float32),\n", " \"B\": T.Tensor((1024, 1024), T.float32),\n", " \"block_M\": block_M,\n", " \"block_N\": block_N,\n", " \"block_K\": block_K,\n", " }\n", " for block_M, block_N, block_K in product([32, 64], repeat=3)\n", " ]\n", "\n", "\n", "gemm.par_compile(get_configs())" ] }, { "cell_type": "markdown", "id": "5160d2cc", "metadata": {}, "source": [ "## More convenient macros" ] }, { "cell_type": "markdown", "id": "be44afc4", "metadata": {}, "source": [ "tilelang macros are now upgraded:\n", "\n", "1. Allow `T.Ref` as an annotation, similar to C++ pass-by-reference\n", "2. Allow returning multiple values\n", "3. 
Allow nesting and recursion (both are shown below)" ] }, { "cell_type": "markdown", "id": "79575972", "metadata": {}, "source": [ "### Passing references with T.Ref\n", "\n", "A reference passed via T.Ref can target a variable or a buffer element." ] }, { "cell_type": "code", "execution_count": null, "id": "90eaa6e5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "# import tilelang.language as T\n", "\n", "@T.prim_func\n", "def foo(x_handle: T.handle):\n", "    x = T.match_buffer(x_handle, (2,), strides=(1,))\n", "    # with T.block(\"root\"):\n", "    bx = T.launch_thread(\"blockIdx.x\", 1)\n", "    tx = T.launch_thread(\"threadIdx.x\", 128)\n", "    ty = T.launch_thread(\"threadIdx.y\", 1)\n", "    tz = T.launch_thread(\"threadIdx.z\", 1)\n", "    with T.block(\"tilelang_root\"):\n", "        T.reads()\n", "        idx = T.Buffer((1,), \"int32\", scope=\"local.var\")\n", "        T.writes(x[T.min(1, idx[0]):T.min(1, idx[0]) + (T.max(1, idx[0]) + 1 - T.min(1, idx[0]))])\n", "        T.block_attr({\"tl.local_var_init\": {idx.data: 0}})\n", "        idx = T.alloc_buffer((1,), \"int32\", data=idx.data, scope=\"local.var\")\n", "        x[1] = T.float32(1.0)\n", "        _tmp: T.int32 = idx[0]\n", "        x[_tmp] = T.float32(1.0)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@T.macro\n", "def macro_with_ref(x: T.Ref):\n", "    x = 1  # noqa: F841\n", "\n", "\n", "@T.prim_func\n", "def foo(x: T.Tensor((2,))):\n", "    with T.Kernel(1) as _:\n", "        # Supports constant indices\n", "        macro_with_ref(x[1])\n", "\n", "        # Also supports variable indices\n", "        idx = T.alloc_var(T.int32, 0)\n", "        macro_with_ref(x[idx])\n", "\n", "\n", "foo" ] }, { "cell_type": "markdown", "id": "7bb447a2", "metadata": {}, "source": [ "### Pass macros as arguments\n", "\n", "You can also pass macros as parameters." ] }, { "cell_type": "code", "execution_count": 18, "id": "dc7bb779", "metadata": {}, "outputs": [], "source": [ "@tilelang.lazy_jit\n", "def element_wise(\n", "    A: T.Tensor[[T.dyn], Any],\n", "    fn,\n", "):\n", "    (N,) = A.shape\n", "    B = T.empty((N,), dtype=A.dtype)\n", "    block_N = 128\n", "    with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n", "        for i in T.Parallel(block_N):\n", "            idx = bx * block_N + i\n", "            B[idx] = fn(A[idx])\n", "    return B\n", "\n", "\n", "@T.macro\n", "def add_one(x):\n", "    return x + 1" ] }, { "cell_type": "code", "execution_count": 19, "id": "a89fdb44", "metadata": {}, "outputs": [], "source": [ "A = torch.randn(1024, device=\"cuda\")\n", "B = element_wise(A, add_one)\n", "B_ref = A + 1\n", "torch.testing.assert_close(B, B_ref)" ] }, { "cell_type": "markdown", "id": "ef6e403a", "metadata": {}, "source": [ "### Macro recursion\n", "\n", "Macros can be recursive (though this is rarely needed), as long as the termination condition is known at compile time." ] }, { "cell_type": "code", "execution_count": 20, "id": "7703cab5", "metadata": {}, "outputs": [], "source": [ "@T.macro\n", "def n31(x, var: T.Ref):\n", "    if x == 1:\n", "        pass\n", "    elif x % 2 == 0:\n", "        var = var // 2\n", "        n31(x // 2, var)\n", "    else:\n", "        var = var * 3 + 1\n", "        n31(x * 3 + 1, var)\n", "\n", "\n", "@tilelang.lazy_jit\n", "def foo(A: T.Tensor[[1], T.int32], n: int):\n", "    with T.Kernel(1) as _:\n", "        n31(n, A[0])" ] }, { "cell_type": "code", "execution_count": 21, "id": "542ddd4e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([18], device='cuda:0', dtype=torch.int32)" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "A = torch.tensor([100], dtype=torch.int32, device=\"cuda\")\n", "foo(A, 5)\n", "A" ] }, { "cell_type":
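"markdown", "id": "c5d3e7f9", "metadata": {}, "source": [ "### Macro nesting\n", "\n", "The upgrade list above also mentions nesting: a macro may call another macro. The cell below is a minimal sketch (not executed here) that reuses the `add_one` macro and the `element_wise` kernel defined earlier." ] }, { "cell_type": "code", "execution_count": null, "id": "d6e4f8a0", "metadata": {}, "outputs": [], "source": [ "# Minimal nesting sketch (not run here): a macro that calls another macro.\n", "@T.macro\n", "def add_two(x):\n", "    # add_one is the macro defined in the element_wise example above\n", "    return add_one(add_one(x))\n", "\n", "\n", "A = torch.randn(1024, device=\"cuda\")\n", "B = element_wise(A, add_two)\n", "torch.testing.assert_close(B, A + 2)" ] }, { "cell_type":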
"markdown", "id": "dc30c2d2", "metadata": {}, "source": [ "### Macro returning multiple values" ] }, { "cell_type": "code", "execution_count": null, "id": "d5a2388f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "# import tilelang.language as T\n", "\n", "@T.prim_func\n", "def foo():\n", " # with T.block(\"root\"):\n", " x = T.launch_thread(\"blockIdx.x\", 32)\n", " tx = T.launch_thread(\"threadIdx.x\", 128)\n", " ty = T.launch_thread(\"threadIdx.y\", 1)\n", " tz = T.launch_thread(\"threadIdx.z\", 1)\n", " with T.block(\"tilelang_root\"):\n", " T.reads()\n", " T.writes()\n", " s: T.int32 = T.sin(x)\n", " c: T.int32 = T.cos(x)\n", " a: T.int32 = s + c\n", " b: T.int32 = s - c\n", " T.evaluate(0)" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@T.macro\n", "def sincos(x):\n", " return T.sin(x), T.cos(x)\n", "\n", "\n", "@T.prim_func\n", "def foo():\n", " with T.Kernel(32) as x:\n", " s, c = sincos(x)\n", " a = s + c # noqa: F841\n", " b = s - c # noqa: F841\n", "\n", "\n", "foo" ] } ], "metadata": { "kernelspec": { "display_name": "tilelang-dev_0", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" } }, "nbformat": 4, "nbformat_minor": 5 }