{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "5e0deecc", "metadata": {}, "outputs": [], "source": [ "import sys\n", "from pathlib import Path\n", "\n", "sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n", "import tilelang\n", "import torch\n", "import tilelang.language as T" ] }, { "cell_type": "markdown", "id": "1ca2c56d", "metadata": {}, "source": [ "# Tilelang Lazy JIT" ] }, { "cell_type": "markdown", "id": "156e7370", "metadata": {}, "source": [ "## Tensor Annotation" ] }, { "cell_type": "markdown", "id": "b070c109", "metadata": {}, "source": [ "Tilelang Lazy JIT 将 jit 生成和调用的逻辑合并到一起\n", "\n", "函数签名的写法与 triton 相似,但做了大量增强,最主要的增强是允许对 Tensor 的标注:\n", "\n", "例如,下面的代码用 T.Tensor[[int, int], T.float16] 来标注了一个二维 Tensor\n", "1. 它的每个维度都是编译期常量,如果改变,会触发重新编译\n", "2. 它的类型必须是 T.float16\n", "\n", "DType 除了写确定的外,还可以写 Any 或者 None" ] }, { "cell_type": "code", "execution_count": 2, "id": "60bf8954", "metadata": {}, "outputs": [], "source": [ "@tilelang.lazy_jit\n", "def gemm(\n", " A: T.Tensor[[int, int], T.float16],\n", " B: T.Tensor[[int, int], T.float16],\n", " out_dtype: T.dtype = T.float32,\n", " block_M: int = 128,\n", " block_N: int = 128,\n", " block_K: int = 32,\n", "):\n", " M, K = A.shape\n", " K, N = B.shape\n", " C = T.empty((M, N), out_dtype)\n", " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", " A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n", " B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n", " C_local = T.alloc_fragment((block_M, block_N), out_dtype)\n", " T.clear(C_local)\n", " for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n", " T.copy(A[bx * block_M, k * block_K], A_shared)\n", " T.copy(B[k * block_K, by * block_N], B_shared)\n", " T.gemm(A_shared, B_shared, C_local)\n", " T.copy(C_local, C[bx * block_M, by * block_N])\n", " return C" ] }, { "cell_type": "markdown", "id": "28f868fe", "metadata": {}, "source": [ "直接将 Tensor 作为参数调用,即可触发完整的 jit 编译运行流程:" ] }, { "cell_type": "code", "execution_count": 3, "id": "ee13394a", "metadata": {}, "outputs": [], "source": [ "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "C = gemm(A, B)\n", "\n", "# check output is correct\n", "C_ref = (A @ B).float()\n", "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" ] }, { "cell_type": "markdown", "id": "c6705091", "metadata": {}, "source": [ "更改调用的参数,如果编译器参数发生了变化,会触发重新编译:" ] }, { "cell_type": "code", "execution_count": 4, "id": "d8aab5b7", "metadata": {}, "outputs": [], "source": [ "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", "B = torch.randn(512, 1024, dtype=torch.float16, device=\"cuda\")\n", "C = gemm(A, B, block_M=64, block_N=64)" ] }, { "cell_type": "markdown", "id": "ce6b7391", "metadata": {}, "source": [ "你也可以手动调用 compile 函数编译 kernel\n", "\n", "1. `ker.compile` 编译 kernel\n", "2. `ker.get_tir` 获取 tir\n", "3. `ker.par_compile` 并行编译" ] }, { "cell_type": "code", "execution_count": 5, "id": "f3cf3a2d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2025-11-25 17:29:46 [TileLang:tilelang.cache.kernel_cache:WARNING]: Found kernel in memory cache. 
{ "cell_type": "markdown", "id": "921761b5", "metadata": {}, "source": [ "## More Tensor Annotation" ] }, { "cell_type": "markdown", "id": "4539e54e", "metadata": {}, "source": [ "### Separating the implementation with a macro" ] }, { "cell_type": "markdown", "id": "ad96ba65", "metadata": {}, "source": [ "Next, we will implement a simple GEMM in several different ways. For convenience, we first write a macro that captures the core GEMM logic:" ] }, { "cell_type": "code", "execution_count": 6, "id": "171d4fe6", "metadata": {}, "outputs": [], "source": [ "@T.macro\n", "def gemm_impl(A, B, C, M, N, K, block_M, block_N, block_K):\n", "    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", "        A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n", "        B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n", "        C_local = T.alloc_fragment((block_M, block_N), C.dtype)\n", "        T.clear(C_local)\n", "        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n", "            T.copy(A[bx * block_M, k * block_K], A_shared)\n", "            T.copy(B[k * block_K, by * block_N], B_shared)\n", "            T.gemm(A_shared, B_shared, C_local)\n", "        T.copy(C_local, C[bx * block_M, by * block_N])" ] }, { "cell_type": "markdown", "id": "446a1acd", "metadata": {}, "source": [ "### Marking dynamic shapes with T.dyn\n", "\n", "When some dimensions are dynamic, mark them with `T.dyn`. `T.dyn` accepts an optional string argument naming the variable:" ] }, { "cell_type": "code", "execution_count": null, "id": "6a38aa95", "metadata": {}, "outputs": [], "source": [ "@tilelang.lazy_jit\n", "def gemm_dyn_K(\n", "    A: T.Tensor[[int, T.dyn[\"K\"]], T.float16],  # noqa: F821\n", "    B: T.Tensor[[T.dyn[\"K\"], int], T.float16],  # noqa: F821\n", "):\n", "    M, K = A.shape\n", "    K, N = B.shape\n", "    C = T.empty((M, N), T.float32)\n", "    gemm_impl(A, B, C, M, N, K, 128, 128, 32)\n", "    return C" ] }, { "cell_type": "markdown", "id": "c60fd346", "metadata": {}, "source": [ "Inspect the lazy_jit function signature: names carrying the suffix `$` are as-yet-unbound compile-time constants, and names carrying `$dyn` are runtime variables:" ] }, { "cell_type": "code", "execution_count": 8, "id": "c6992eb4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'A': TensorAnnot(shape=[A_shape_0$, K$dyn], strides=None, dtype=dtype('float16')),\n", " 'B': TensorAnnot(shape=[K$dyn, B_shape_1$], strides=None, dtype=dtype('float16'))}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gemm_dyn_K.func.annot" ] }, { "cell_type": "code", "execution_count": 9, "id": "fe6cfdc8", "metadata": {}, "outputs": [], "source": [ "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "C = gemm_dyn_K(A, B)\n", "C_ref = (A @ B).float()\n", "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" ] },
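{ "cell_type": "markdown", "id": "c3d4e5f6", "metadata": {}, "source": [ "Because `K` is marked with `T.dyn`, calls whose shapes differ only in `K` should reuse the same compiled kernel instead of recompiling. A quick check (the shape values below are illustrative):" ] }, { "cell_type": "code", "execution_count": null, "id": "d4e5f6a7", "metadata": {}, "outputs": [], "source": [ "# K changes from 512 to 768 while M and N stay fixed; since K is a runtime\n", "# variable in the signature, this call should not trigger a recompile.\n", "A2 = torch.randn(1024, 768, dtype=torch.float16, device=\"cuda\")\n", "B2 = torch.randn(768, 256, dtype=torch.float16, device=\"cuda\")\n", "C2 = gemm_dyn_K(A2, B2)\n", "torch.testing.assert_close(C2, (A2 @ B2).float(), rtol=1e-2, atol=1e-2)" ] },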
{ "cell_type": "markdown", "id": "2ee97bf7", "metadata": {}, "source": [ "### Annotating strided Tensors with T.StridedTensor\n", "\n", "The annotation is `T.StridedTensor[Shape, Stride, DType]`, where each Shape or Stride entry can be\n", "* `int`: a compile-time constant\n", "* `T.dyn`: a runtime value\n", "\n", "DType can be `None` or `Any`" ] }, { "cell_type": "code", "execution_count": 10, "id": "9dde1dae", "metadata": {}, "outputs": [], "source": [ "from typing import Any\n", "\n", "\n", "@tilelang.lazy_jit\n", "def as_contiguous(A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]):\n", "    M, N = A.shape\n", "    B = T.empty((M, N), A.dtype)\n", "    block_M = 128\n", "    block_N = 128\n", "    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", "        T.copy(\n", "            A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", "            B[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", "        )\n", "    return B" ] }, { "cell_type": "code", "execution_count": 11, "id": "dec2c0a7", "metadata": {}, "outputs": [], "source": [ "A = torch.randn(1024, 1024, device=\"cuda\")\n", "B = as_contiguous(A[::2, ::2])\n", "B_ref = A[::2, ::2].contiguous()\n", "torch.testing.assert_close(B, B_ref)" ] }, { "cell_type": "markdown", "id": "f5fb20d6", "metadata": {}, "source": [ "## More Annotation" ] }, { "cell_type": "markdown", "id": "890df0a2", "metadata": {}, "source": [ "### Annotating Tensors with T.ptr\n", "\n", "lazy_jit lets you declare a handle with `T.ptr`, but you must then bind its shape inside the function with `T.match_buffer`:" ] }, { "cell_type": "code", "execution_count": 12, "id": "0fc17af6", "metadata": {}, "outputs": [], "source": [ "@tilelang.lazy_jit\n", "def gemm_ptr(\n", "    A: T.ptr,\n", "    B: T.ptr,\n", "    M: int,\n", "    N: int,\n", "    K: int,\n", "):\n", "    A = T.match_buffer(A, (M, K), T.float16)\n", "    B = T.match_buffer(B, (K, N), T.float16)\n", "    C = T.empty((M, N), T.float32)\n", "    gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n", "    return C" ] }, { "cell_type": "code", "execution_count": 13, "id": "8e52a554", "metadata": {}, "outputs": [], "source": [ "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "C = gemm_ptr(A, B, 1024, 256, 512)\n", "C_ref = (A @ B).float()\n", "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" ] }, { "cell_type": "markdown", "id": "6b19ef90", "metadata": {}, "source": [ "### Annotating runtime variables with T.int32\n", "\n", "lazy_jit lets you declare runtime variables with `T.int32` or other scalar types. With these you can write a fully dynamic GEMM, much as you would in Triton:" ] }, { "cell_type": "code", "execution_count": 14, "id": "c1e7598a", "metadata": {}, "outputs": [], "source": [ "@tilelang.lazy_jit\n", "def gemm_ptr_dyn(\n", "    A: T.ptr,\n", "    B: T.ptr,\n", "    M: T.int32,\n", "    N: T.int32,\n", "    K: T.int32,\n", "):\n", "    A = T.match_buffer(A, (M, K), T.float16, strides=(K, 1))\n", "    B = T.match_buffer(B, (K, N), T.float16, strides=(N, 1))\n", "    C = T.empty((M, N), T.float32)\n", "    gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n", "    return C" ] }, { "cell_type": "code", "execution_count": 15, "id": "9e9a4c88", "metadata": {}, "outputs": [], "source": [ "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", "C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n", "C_ref = (A @ B).float()\n", "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" ] }, { "cell_type": "markdown", "id": "39166cb4", "metadata": {}, "source": [ "## Compilation and Parallel Compilation" ] }, { "cell_type": "markdown", "id": "8c6fbe08", "metadata": {}, "source": [ "Both lazy_jit and the original jit support parallel compilation.\n", "\n", "To avoid wasting memory on real torch tensors, use `T.Tensor` to create placeholders:" ] }, { "cell_type": "code", "execution_count": 16, "id": "7222e57b", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c6d7f05cdfff412e9a527332438f7aa2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Elaborating:   0%|          | 0/8 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "[<...>, <...>, <...>, <...>, <...>, <...>, <...>, <...>]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from itertools import product\n", "\n", "\n", "def get_configs():\n", "    return [\n", "        {\n", "            \"A\": T.Tensor((1024, 1024), T.float32),\n", "            \"B\": T.Tensor((1024, 1024), T.float32),\n", "            \"block_M\": block_M,\n", "            \"block_N\": block_N,\n", "            \"block_K\": block_K,\n", "        }\n", "        for block_M, block_N, block_K in product([32, 64], repeat=3)\n", "    ]\n", "\n", "\n", "gemm.par_compile(get_configs())" ] },
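{ "cell_type": "markdown", "id": "e5f6a7b8", "metadata": {}, "source": [ "The placeholders only carry shape and dtype, so they should be usable with a plain `compile` call as well. A sketch under that assumption:" ] }, { "cell_type": "code", "execution_count": null, "id": "f6a7b8c9", "metadata": {}, "outputs": [], "source": [ "# Sketch: compile a single configuration from T.Tensor placeholders without\n", "# allocating real CUDA memory. Assumes compile accepts placeholders the same\n", "# way par_compile configs do (not verified here).\n", "kernel_small = gemm.compile(\n", "    T.Tensor((1024, 1024), T.float32),\n", "    T.Tensor((1024, 1024), T.float32),\n", "    block_M=64,\n", "    block_N=64,\n", "    block_K=32,\n", ")" ] },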
block_N,\n", " \"block_K\": block_K,\n", " }\n", " for block_M, block_N, block_K in product([32, 64], repeat=3)\n", " ]\n", "\n", "\n", "gemm.par_compile(get_configs())" ] }, { "cell_type": "markdown", "id": "5160d2cc", "metadata": {}, "source": [ "## 更便利的 Macro" ] }, { "cell_type": "markdown", "id": "be44afc4", "metadata": {}, "source": [ "tilelang 的 macro 现在已经升级:\n", "\n", "1. 允许用 `T.Ref` 作为 annotation,这类似与 C++ 的引用传递\n", "2. 允许返回多个值\n", "3. 允许嵌套,递归" ] }, { "cell_type": "markdown", "id": "79575972", "metadata": {}, "source": [ "### T.Ref 传递引用\n", "\n", "T.Ref 传递的引用可以 var 也可以是 Buffer 的索引" ] }, { "cell_type": "code", "execution_count": null, "id": "90eaa6e5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "# from tvm.script import tir as T\n", "\n", "@T.prim_func\n", "def foo(x_handle: T.handle):\n", " x = T.match_buffer(x_handle, (2,), strides=(1,))\n", " # with T.block(\"root\"):\n", " bx = T.launch_thread(\"blockIdx.x\", 1)\n", " tx = T.launch_thread(\"threadIdx.x\", 128)\n", " ty = T.launch_thread(\"threadIdx.y\", 1)\n", " tz = T.launch_thread(\"threadIdx.z\", 1)\n", " with T.block(\"tilelang_root\"):\n", " T.reads()\n", " idx = T.Buffer((1,), \"int32\", scope=\"local.var\")\n", " T.writes(x[T.min(1, idx[0]):T.min(1, idx[0]) + (T.max(1, idx[0]) + 1 - T.min(1, idx[0]))])\n", " T.block_attr({\"tl.local_var_init\": {idx.data: 0}})\n", " idx = T.alloc_buffer((1,), \"int32\", data=idx.data, scope=\"local.var\")\n", " x[1] = T.float32(1.0)\n", " _tmp: T.int32 = idx[0]\n", " x[_tmp] = T.float32(1.0)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@T.macro\n", "def macro_with_ref(x: T.Ref):\n", " x = 1 # noqa: F841\n", "\n", "\n", "@T.prim_func\n", "def foo(x: T.Tensor((2,))):\n", " with T.Kernel(1) as _:\n", " # 支持常量 index\n", " macro_with_ref(x[1])\n", "\n", " # 也支持变量 index\n", " idx = T.alloc_var(T.int32, 0)\n", " macro_with_ref(x[idx])\n", "\n", "\n", "foo" ] }, { "cell_type": "markdown", "id": "7bb447a2", "metadata": {}, "source": [ "### 当作参数传递\n", "\n", "你可以把 macro 当做参数传递" ] }, { "cell_type": "code", "execution_count": 18, "id": "dc7bb779", "metadata": {}, "outputs": [], "source": [ "@tilelang.lazy_jit\n", "def element_wise(\n", " A: T.Tensor[[T.dyn], Any],\n", " fn,\n", "):\n", " (N,) = A.shape\n", " B = T.empty((N,), dtype=A.dtype)\n", " block_N = 128\n", " with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n", " for i in T.Parallel(block_N):\n", " idx = bx * block_N + i\n", " B[idx] = fn(A[idx])\n", " return B\n", "\n", "\n", "@T.macro\n", "def add_one(x):\n", " return x + 1" ] }, { "cell_type": "code", "execution_count": 19, "id": "a89fdb44", "metadata": {}, "outputs": [], "source": [ "A = torch.randn(1024, device=\"cuda\")\n", "B = element_wise(A, add_one)\n", "B_ref = A + 1\n", "torch.testing.assert_close(B, B_ref)" ] }, { "cell_type": "markdown", "id": "ef6e403a", "metadata": {}, "source": [ "### Macro 递归\n", "\n", "虽然不知道有没有这种需求,但 macro 是可以递归的,但要求终止条件编译期间确定" ] }, { "cell_type": "code", "execution_count": 20, "id": "7703cab5", "metadata": {}, "outputs": [], "source": [ "@T.macro\n", "def n31(x, var: T.Ref):\n", " if x == 1:\n", " pass\n", " elif x % 2 == 0:\n", " var = var // 2\n", " n31(x // 2, var)\n", " else:\n", " var = var * 3 + 1\n", " n31(x * 3 + 1, var)\n", "\n", "\n", "@tilelang.lazy_jit\n", "def foo(A: T.Tensor[[1], T.int32], n: int):\n", " with T.Kernel(1) as _:\n", " n31(n, A[0])" ] }, { "cell_type": "code", "execution_count": 21, "id": "542ddd4e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"tensor([18], device='cuda:0', dtype=torch.int32)" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "A = torch.tensor([100], dtype=torch.int32, device=\"cuda\")\n", "foo(A, 5)\n", "A" ] }, { "cell_type": "markdown", "id": "dc30c2d2", "metadata": {}, "source": [ "### Macro 返回多个值" ] }, { "cell_type": "code", "execution_count": null, "id": "d5a2388f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "# from tvm.script import tir as T\n", "\n", "@T.prim_func\n", "def foo():\n", " # with T.block(\"root\"):\n", " x = T.launch_thread(\"blockIdx.x\", 32)\n", " tx = T.launch_thread(\"threadIdx.x\", 128)\n", " ty = T.launch_thread(\"threadIdx.y\", 1)\n", " tz = T.launch_thread(\"threadIdx.z\", 1)\n", " with T.block(\"tilelang_root\"):\n", " T.reads()\n", " T.writes()\n", " s: T.int32 = T.sin(x)\n", " c: T.int32 = T.cos(x)\n", " a: T.int32 = s + c\n", " b: T.int32 = s - c\n", " T.evaluate(0)" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@T.macro\n", "def sincos(x):\n", " return T.sin(x), T.cos(x)\n", "\n", "\n", "@T.prim_func\n", "def foo():\n", " with T.Kernel(32) as x:\n", " s, c = sincos(x)\n", " a = s + c # noqa: F841\n", " b = s - c # noqa: F841\n", "\n", "\n", "foo" ] } ], "metadata": { "kernelspec": { "display_name": "tilelang-dev_0", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" } }, "nbformat": 4, "nbformat_minor": 5 }