Unverified Commit 0921328d authored by Kuris and committed by GitHub

[Language] Tilelang LazyJIT Experimental Version (#1337)



* initial step

* modify builder

* scratch version of new frontend

* write some tests

* add many tests

* add typing stub for tir.ir

* remove idents

* minor update

* minor update

* First version of jitv2 (renamed to LazyJIT)

* fix pre-commit error

* minor fix

* fix lint error

* fix lint error

* Fix conditional check for PrimFunc instance

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent 8d019eb9
......@@ -35,11 +35,7 @@ repos:
rev: v21.1.6 # sync with requirements-lint.txt
hooks:
- id: clang-format
exclude: |
(?ix)(
^.+\.(cu|cuh)$|
^.+\.json$
)
types_or: [c++, c]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.7 # sync with requirements-lint.txt
hooks:
......@@ -66,4 +62,4 @@ repos:
^.+\.(cpp|hpp|cxx|cc|c|h|cu|cuh)$|
^.+\.svg$|
^.*\brequirements\b.*\.txt$
)
)
\ No newline at end of file
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "5e0deecc",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"from pathlib import Path\n",
"sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n",
"import tilelang\n",
"import torch\n",
"import tilelang.language as T"
]
},
{
"cell_type": "markdown",
"id": "1ca2c56d",
"metadata": {},
"source": [
"# Tilelang Lazy JIT"
]
},
{
"cell_type": "markdown",
"id": "156e7370",
"metadata": {},
"source": [
"## Tensor Annotation"
]
},
{
"cell_type": "markdown",
"id": "b070c109",
"metadata": {},
"source": [
"Tilelang Lazy JIT combines the jit generation and invocation logic.\n",
"\n",
"The function signature syntax is similar to triton but with significant enhancements, most notably allowing Tensor annotations:\n",
"\n",
"For example, the code below annotates a 2D Tensor with T.Tensor[[int, int], T.float16]\n",
"1. Each dimension is a compile-time constant; changing it triggers recompilation\n",
"2. Its dtype must be T.float16\n",
"\n",
"DType can also be Any or None in addition to a concrete type\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "60bf8954",
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"def gemm(\n",
" A: T.Tensor[[int, int], T.float16],\n",
" B: T.Tensor[[int, int], T.float16],\n",
" out_dtype: T.dtype = T.float32,\n",
" block_M: int = 128,\n",
" block_N: int = 128,\n",
" block_K: int = 32\n",
"):\n",
" M, K = A.shape\n",
" K, N = B.shape\n",
" C = T.empty((M, N), out_dtype)\n",
" with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
" A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n",
" B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n",
" C_local = T.alloc_fragment((block_M, block_N), out_dtype)\n",
" T.clear(C_local)\n",
" for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n",
" T.copy(A[bx * block_M, k * block_K], A_shared)\n",
" T.copy(B[k * block_K, by * block_N], B_shared)\n",
" T.gemm(A_shared, B_shared, C_local)\n",
" T.copy(C_local, C[bx * block_M, by * block_N])\n",
" return C"
]
},
{
"cell_type": "markdown",
"id": "28f868fe",
"metadata": {},
"source": [
"Call the Tensor directly as an argument to trigger the full jit compile-and-run workflow:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ee13394a",
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n",
"B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n",
"C = gemm(A, B)\n",
"\n",
"# check output is correct\n",
"C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
]
},
{
"cell_type": "markdown",
"id": "c6705091",
"metadata": {},
"source": [
"Change the call-site arguments; if the compiler parameters differ, it recompiles:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d8aab5b7",
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n",
"B = torch.randn(512, 1024, dtype=torch.float16, device='cuda')\n",
"C = gemm(A, B, block_M=64, block_N=64)"
]
},
{
"cell_type": "markdown",
"id": "ce6b7391",
"metadata": {},
"source": [
"You can also manually call compile helpers to build a kernel\n",
"\n",
"1. `ker.compile` compiles the kernel\n",
"2. `ker.get_tir` retrieves the tir\n",
"3. `ker.par_compile` compiles in parallel"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f3cf3a2d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2025-11-25 17:29:46 [TileLang:tilelang.cache.kernel_cache:WARNING]: Found kernel in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching.\n"
]
}
],
"source": [
"kernel = gemm.compile(A, B, block_M=64, block_N=64)\n",
"C = kernel(A, B)"
]
},
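  {
   "cell_type": "markdown",
   "id": "lazyjit-get-tir-md",
   "metadata": {},
   "source": [
    "For example, `get_tir` elaborates the function into a TIR `PrimFunc` without compiling it. The cell below is a minimal sketch, assuming `get_tir` accepts the same call-site arguments as `compile`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "lazyjit-get-tir",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Elaborate only (no compilation); useful for inspecting the generated TIR\n",
    "tir_func = gemm.get_tir(A, B, block_M=64, block_N=64)\n",
    "print(type(tir_func))"
   ]
  },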
{
"cell_type": "markdown",
"id": "921761b5",
"metadata": {},
"source": [
"## More Tensor Annotation"
]
},
{
"cell_type": "markdown",
"id": "4539e54e",
"metadata": {},
"source": [
"### Separate the implementation with macros"
]
},
{
"cell_type": "markdown",
"id": "ad96ba65",
"metadata": {},
"source": [
"Next we'll implement a simple gemm in several ways. For convenience, first write a macro that captures the main gemm logic:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "171d4fe6",
"metadata": {},
"outputs": [],
"source": [
"@T.macro\n",
"def gemm_impl(A, B, C, M, N, K, block_M, block_N, block_K):\n",
" with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
" A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n",
" B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n",
" C_local = T.alloc_fragment((block_M, block_N), C.dtype)\n",
" T.clear(C_local)\n",
" for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n",
" T.copy(A[bx * block_M, k * block_K], A_shared)\n",
" T.copy(B[k * block_K, by * block_N], B_shared)\n",
" T.gemm(A_shared, B_shared, C_local)\n",
" T.copy(C_local, C[bx * block_M, by * block_N])"
]
},
{
"cell_type": "markdown",
"id": "446a1acd",
"metadata": {},
"source": [
"### Mark dynamic shapes with T.dyn\n",
"\n",
"When some dimensions are dynamic, mark them with T.dyn. T.dyn can take a string argument to name the variable"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a38aa95",
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"def gemm_dyn_K(\n",
" A: T.Tensor[[int, T.dyn['K']], T.float16], # noqa: F821\n",
" B: T.Tensor[[T.dyn['K'], int], T.float16], # noqa: F821\n",
"):\n",
" M, K = A.shape\n",
" K, N = B.shape\n",
" C = T.empty((M, N), T.float32)\n",
" gemm_impl(A, B, C, M, N, K, 128, 128, 32)\n",
" return C"
]
},
{
"cell_type": "markdown",
"id": "c60fd346",
"metadata": {},
"source": [
"Inspect the lazy_jit function signature: parameters with a `$` suffix are compile-time constants that may vary, and those with `$dyn` are runtime variables"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "c6992eb4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'A': TensorAnnot(shape=[A_shape_0$, K$dyn], strides=None, dtype=dtype('float16')),\n",
" 'B': TensorAnnot(shape=[K$dyn, B_shape_1$], strides=None, dtype=dtype('float16'))}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gemm_dyn_K.func.annot"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "fe6cfdc8",
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n",
"B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n",
"C = gemm_dyn_K(A, B)\n",
"C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
]
},
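  {
   "cell_type": "markdown",
   "id": "dyn-k-reuse-md",
   "metadata": {},
   "source": [
    "Because `K` is marked `T.dyn`, a call with a different `K` should reuse the already compiled kernel rather than recompile; only `M` and `N` are compile-time constants here (a minimal sketch):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dyn-k-reuse",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same M and N, different K: K is a runtime variable, so no recompilation is expected\n",
    "A2 = torch.randn(1024, 768, dtype=torch.float16, device='cuda')\n",
    "B2 = torch.randn(768, 256, dtype=torch.float16, device='cuda')\n",
    "C2 = gemm_dyn_K(A2, B2)\n",
    "torch.testing.assert_close(C2, (A2 @ B2).float(), rtol=1e-2, atol=1e-2)"
   ]
  },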
{
"cell_type": "markdown",
"id": "2ee97bf7",
"metadata": {},
"source": [
"### Use T.StridedTensor to annotate tensors with strides\n",
"\n",
"Annotation format: T.StridedTensor[Shape, Stride, DType]. Each Shape or Stride entry can be\n",
"* int: compile-time constant\n",
"* T.dyn: runtime value\n",
"\n",
"DType can be None or Any"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "9dde1dae",
"metadata": {},
"outputs": [],
"source": [
"from typing import Any\n",
"\n",
"@tilelang.lazy_jit\n",
"def as_contingious(\n",
" A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]\n",
"):\n",
" M, N = A.shape\n",
" B = T.empty((M, N), A.dtype)\n",
" block_M = 128\n",
" block_N = 128\n",
" with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
" T.copy(\n",
" A[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N],\n",
" B[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N]\n",
" )\n",
" return B"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "dec2c0a7",
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 1024, device='cuda')\n",
"B = as_contingious(A[::2, ::2])\n",
"B_ref = A[::2, ::2].contiguous()\n",
"torch.testing.assert_close(B, B_ref)"
]
},
{
"cell_type": "markdown",
"id": "f5fb20d6",
"metadata": {},
"source": [
"## More Annotation"
]
},
{
"cell_type": "markdown",
"id": "890df0a2",
"metadata": {},
"source": [
"### Annotate tensors with T.ptr\n",
"lazy_jit lets you declare a handle with T.ptr, but you must define its shape inside the function via T.match_buffer"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "0fc17af6",
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"def gemm_ptr(\n",
" A: T.ptr,\n",
" B: T.ptr,\n",
" M: int,\n",
" N: int,\n",
" K: int,\n",
"):\n",
" A = T.match_buffer(A, (M, K), T.float16)\n",
" B = T.match_buffer(B, (K, N), T.float16)\n",
" C = T.empty((M, N), T.float32)\n",
" gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n",
" return C"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "8e52a554",
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n",
"B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n",
"C = gemm_ptr(A, B, 1024, 256, 512)\n",
"C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
]
},
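  {
   "cell_type": "markdown",
   "id": "gemm-ptr-annot-md",
   "metadata": {},
   "source": [
    "The parsed annotations can be inspected here as well; `M`, `N` and `K` (annotated `int`) should appear as compile-time constants (a sketch reusing the inspection shown earlier):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "gemm-ptr-annot",
   "metadata": {},
   "outputs": [],
   "source": [
    "gemm_ptr.func.annot"
   ]
  },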
{
"cell_type": "markdown",
"id": "6b19ef90",
"metadata": {},
"source": [
"### Use T.int32 to annotate runtime variables\n",
"\n",
"lazy_jit lets you define runtime variables with T.int32 or other types, enabling a fully dynamic gemm similar to triton"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "c1e7598a",
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"def gemm_ptr_dyn(\n",
" A: T.ptr,\n",
" B: T.ptr,\n",
" M: T.int32,\n",
" N: T.int32,\n",
" K: T.int32,\n",
"):\n",
" A = T.match_buffer(A, (M, K), T.float16, strides=(K, 1))\n",
" B = T.match_buffer(B, (K, N), T.float16, strides=(N, 1))\n",
" C = T.empty((M, N), T.float32)\n",
" gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n",
" return C"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "9e9a4c88",
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n",
"B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n",
"C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n",
"C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
]
},
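  {
   "cell_type": "markdown",
   "id": "ptr-dyn-shapes-md",
   "metadata": {},
   "source": [
    "Since `M`, `N` and `K` are now `T.int32` runtime arguments, calling with different sizes should not trigger a recompilation (a minimal sketch):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ptr-dyn-shapes",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Different M, N, K: all three are runtime variables, so the cached kernel is expected to be reused\n",
    "A2 = torch.randn(2048, 256, dtype=torch.float16, device='cuda')\n",
    "B2 = torch.randn(256, 512, dtype=torch.float16, device='cuda')\n",
    "C2 = gemm_ptr_dyn(A2, B2, 2048, 512, 256)\n",
    "torch.testing.assert_close(C2, (A2 @ B2).float(), rtol=1e-2, atol=1e-2)"
   ]
  },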
{
"cell_type": "markdown",
"id": "39166cb4",
"metadata": {},
"source": [
"## Compilation and parallel compilation"
]
},
{
"cell_type": "markdown",
"id": "8c6fbe08",
"metadata": {},
"source": [
"lazyjit and the original jit both support parallel compilation\n",
"\n",
"To avoid wasting memory with torch.tensor placeholders, use T.Tensor to create placeholders"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "7222e57b",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c6d7f05cdfff412e9a527332438f7aa2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Elaborating: 0%| | 0/8 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "14836065a21b41ae8fc34e8763ae49fc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Parallel Compiling: 0%| | 0/8 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[<tilelang.jit.kernel.JITKernel at 0x7f29c0072ed0>,\n",
" <tilelang.jit.kernel.JITKernel at 0x7f29c00882f0>,\n",
" <tilelang.jit.kernel.JITKernel at 0x7f29c00735f0>,\n",
" <tilelang.jit.kernel.JITKernel at 0x7f29c0088890>,\n",
" <tilelang.jit.kernel.JITKernel at 0x7f29c01f94c0>,\n",
" <tilelang.jit.kernel.JITKernel at 0x7f29c0073fe0>,\n",
" <tilelang.jit.kernel.JITKernel at 0x7f29c0070ce0>,\n",
" <tilelang.jit.kernel.JITKernel at 0x7f29c00732f0>]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from itertools import product\n",
"\n",
"def get_configs():\n",
" return [\n",
" {\n",
" 'A': T.Tensor((1024, 1024), T.float32),\n",
" 'B': T.Tensor((1024, 1024), T.float32),\n",
" 'block_M': block_M,\n",
" 'block_N': block_N,\n",
" 'block_K': block_K,\n",
" }\n",
" for block_M, block_N, block_K in product([32, 64], repeat=3)\n",
" ]\n",
"\n",
"gemm.par_compile(get_configs())"
]
},
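  {
   "cell_type": "markdown",
   "id": "par-compile-call-md",
   "metadata": {},
   "source": [
    "Each returned `JITKernel` can be invoked directly, just like the result of `compile` above (a minimal sketch that re-runs `par_compile` to keep a handle on the returned list):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "par-compile-call",
   "metadata": {},
   "outputs": [],
   "source": [
    "kernels = gemm.par_compile(get_configs())\n",
    "A = torch.randn(1024, 1024, device='cuda')\n",
    "B = torch.randn(1024, 1024, device='cuda')\n",
    "C = kernels[0](A, B)  # the first config: block_M = block_N = block_K = 32\n",
    "torch.testing.assert_close(C, A @ B, rtol=1e-2, atol=1e-2)"
   ]
  },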
{
"cell_type": "markdown",
"id": "5160d2cc",
"metadata": {},
"source": [
"## More convenient macros"
]
},
{
"cell_type": "markdown",
"id": "be44afc4",
"metadata": {},
"source": [
"tilelang macros are now upgraded:\n",
"\n",
"1. Allow `T.Ref` as an annotation, similar to C++ pass-by-reference\n",
"2. Allow returning multiple values\n",
"3. Allow nesting and recursion"
]
},
{
"cell_type": "markdown",
"id": "79575972",
"metadata": {},
"source": [
"### Passing references with T.Ref\n",
"\n",
"The reference via T.Ref can target a var or a buffer element"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "90eaa6e5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"# from tvm.script import tir as T\n",
"\n",
"@T.prim_func\n",
"def foo(x_handle: T.handle):\n",
" x = T.match_buffer(x_handle, (2,), strides=(1,))\n",
" # with T.block(\"root\"):\n",
" bx = T.launch_thread(\"blockIdx.x\", 1)\n",
" tx = T.launch_thread(\"threadIdx.x\", 128)\n",
" ty = T.launch_thread(\"threadIdx.y\", 1)\n",
" tz = T.launch_thread(\"threadIdx.z\", 1)\n",
" with T.block(\"tilelang_root\"):\n",
" T.reads()\n",
" idx = T.Buffer((1,), \"int32\", scope=\"local.var\")\n",
" T.writes(x[T.min(1, idx[0]):T.min(1, idx[0]) + (T.max(1, idx[0]) + 1 - T.min(1, idx[0]))])\n",
" T.block_attr({\"tl.local_var_init\": {idx.data: 0}})\n",
" idx = T.alloc_buffer((1,), \"int32\", data=idx.data, scope=\"local.var\")\n",
" x[1] = T.float32(1.0)\n",
" _tmp: T.int32 = idx[0]\n",
" x[_tmp] = T.float32(1.0)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@T.macro\n",
"def macro_with_ref(x: T.Ref):\n",
" x = 1 # noqa: F841\n",
"\n",
"@T.prim_func\n",
"def foo(x: T.Tensor((2,))):\n",
" with T.Kernel(1) as _:\n",
" # Supports constant indices\n",
" macro_with_ref(x[1])\n",
"\n",
" # Also supports variable indices\n",
" idx = T.alloc_var(T.int32, 0)\n",
" macro_with_ref(x[idx])\n",
"\n",
"foo"
]
},
{
"cell_type": "markdown",
"id": "7bb447a2",
"metadata": {},
"source": [
"### Pass as arguments\n",
"\n",
"You can pass macros as parameters"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "dc7bb779",
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"def element_wise(\n",
" A: T.Tensor[[T.dyn], Any],\n",
" fn,\n",
"):\n",
" N, = A.shape\n",
" B = T.empty((N,), dtype=A.dtype)\n",
" block_N = 128\n",
" with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n",
" for i in T.Parallel(block_N):\n",
" idx = bx * block_N + i\n",
" B[idx] = fn(A[idx])\n",
" return B\n",
"@T.macro\n",
"def add_one(x):\n",
" return x + 1"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "a89fdb44",
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, device='cuda')\n",
"B = element_wise(A, add_one)\n",
"B_ref = A + 1\n",
"torch.testing.assert_close(B, B_ref)"
]
},
{
"cell_type": "markdown",
"id": "ef6e403a",
"metadata": {},
"source": [
"### Macro recursion\n",
"\n",
"Macro can be recursive, even if it's rarely needed, as long as the termination condition is known at compile time"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "7703cab5",
"metadata": {},
"outputs": [],
"source": [
"@T.macro\n",
"def n31(x, var: T.Ref):\n",
" if x == 1:\n",
" pass\n",
" elif x % 2 == 0:\n",
" var = var // 2\n",
" n31(x // 2, var)\n",
" else:\n",
" var = var * 3 + 1\n",
" n31(x * 3 + 1, var)\n",
"\n",
"@tilelang.lazy_jit\n",
"def foo(A: T.Tensor[[1], T.int32], n: int):\n",
" with T.Kernel(1) as _:\n",
" n31(n, A[0])\n"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "542ddd4e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([18], device='cuda:0', dtype=torch.int32)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"A = torch.tensor([100], dtype=torch.int32, device='cuda')\n",
"foo(A, 5)\n",
"A"
]
},
{
"cell_type": "markdown",
"id": "dc30c2d2",
"metadata": {},
"source": [
"### Macro returning multiple values"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5a2388f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"# from tvm.script import tir as T\n",
"\n",
"@T.prim_func\n",
"def foo():\n",
" # with T.block(\"root\"):\n",
" x = T.launch_thread(\"blockIdx.x\", 32)\n",
" tx = T.launch_thread(\"threadIdx.x\", 128)\n",
" ty = T.launch_thread(\"threadIdx.y\", 1)\n",
" tz = T.launch_thread(\"threadIdx.z\", 1)\n",
" with T.block(\"tilelang_root\"):\n",
" T.reads()\n",
" T.writes()\n",
" s: T.int32 = T.sin(x)\n",
" c: T.int32 = T.cos(x)\n",
" a: T.int32 = s + c\n",
" b: T.int32 = s - c\n",
" T.evaluate(0)"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@T.macro\n",
"def sincos(x):\n",
" return T.sin(x), T.cos(x)\n",
"\n",
"@T.prim_func\n",
"def foo():\n",
" with T.Kernel(32) as x:\n",
" s, c = sincos(x)\n",
" a = s + c # noqa: F841\n",
" b = s - c # noqa: F841\n",
"foo"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "tilelang-dev_0",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "5e0deecc",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"from pathlib import Path\n",
"sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n",
"import tilelang\n",
"import torch\n",
"import tilelang.language as T"
]
},
{
"cell_type": "markdown",
"id": "1ca2c56d",
"metadata": {},
"source": [
"# Tilelang Lazy JIT"
]
},
{
"cell_type": "markdown",
"id": "156e7370",
"metadata": {},
"source": [
"## Tensor Annotation"
]
},
{
"cell_type": "markdown",
"id": "b070c109",
"metadata": {},
"source": [
"Tilelang Lazy JIT 将 jit 生成和调用的逻辑合并到一起\n",
"\n",
"函数签名的写法与 triton 相似,但做了大量增强,最主要的增强是允许对 Tensor 的标注:\n",
"\n",
"例如,下面的代码用 T.Tensor[[int, int], T.float16] 来标注了一个二维 Tensor\n",
"1. 它的每个维度都是编译期常量,如果改变,会触发重新编译\n",
"2. 它的类型必须是 T.float16\n",
"\n",
"DType 除了写确定的外,还可以写 Any 或者 None"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "60bf8954",
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"def gemm(\n",
" A: T.Tensor[[int, int], T.float16],\n",
" B: T.Tensor[[int, int], T.float16],\n",
" out_dtype: T.dtype = T.float32,\n",
" block_M: int = 128,\n",
" block_N: int = 128,\n",
" block_K: int = 32\n",
"):\n",
" M, K = A.shape\n",
" K, N = B.shape\n",
" C = T.empty((M, N), out_dtype)\n",
" with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
" A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n",
" B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n",
" C_local = T.alloc_fragment((block_M, block_N), out_dtype)\n",
" T.clear(C_local)\n",
" for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n",
" T.copy(A[bx * block_M, k * block_K], A_shared)\n",
" T.copy(B[k * block_K, by * block_N], B_shared)\n",
" T.gemm(A_shared, B_shared, C_local)\n",
" T.copy(C_local, C[bx * block_M, by * block_N])\n",
" return C"
]
},
{
"cell_type": "markdown",
"id": "28f868fe",
"metadata": {},
"source": [
"直接将 Tensor 作为参数调用,即可触发完整的 jit 编译运行流程:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ee13394a",
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n",
"B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n",
"C = gemm(A, B)\n",
"\n",
"# check output is correct\n",
"C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
]
},
{
"cell_type": "markdown",
"id": "c6705091",
"metadata": {},
"source": [
"更改调用的参数,如果编译器参数发生了变化,会触发重新编译:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d8aab5b7",
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n",
"B = torch.randn(512, 1024, dtype=torch.float16, device='cuda')\n",
"C = gemm(A, B, block_M=64, block_N=64)"
]
},
{
"cell_type": "markdown",
"id": "ce6b7391",
"metadata": {},
"source": [
"你也可以手动调用 compile 函数编译 kernel\n",
"\n",
"1. `ker.compile` 编译 kernel\n",
"2. `ker.get_tir` 获取 tir\n",
"3. `ker.par_compile` 并行编译"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f3cf3a2d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2025-11-25 17:29:46 [TileLang:tilelang.cache.kernel_cache:WARNING]: Found kernel in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching.\n"
]
}
],
"source": [
"kernel = gemm.compile(A, B, block_M=64, block_N=64)\n",
"C = kernel(A, B)"
]
},
{
"cell_type": "markdown",
"id": "921761b5",
"metadata": {},
"source": [
"## More Tensor Annotation"
]
},
{
"cell_type": "markdown",
"id": "4539e54e",
"metadata": {},
"source": [
"### 用 macro 来分离实现"
]
},
{
"cell_type": "markdown",
"id": "ad96ba65",
"metadata": {},
"source": [
"接下来,我们会用各种方式来实现一个简单的 gemm,为了方便,我们先写一个 macro 把 gemm 的主要逻辑写出来:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "171d4fe6",
"metadata": {},
"outputs": [],
"source": [
"@T.macro\n",
"def gemm_impl(A, B, C, M, N, K, block_M, block_N, block_K):\n",
" with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
" A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n",
" B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n",
" C_local = T.alloc_fragment((block_M, block_N), C.dtype)\n",
" T.clear(C_local)\n",
" for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n",
" T.copy(A[bx * block_M, k * block_K], A_shared)\n",
" T.copy(B[k * block_K, by * block_N], B_shared)\n",
" T.gemm(A_shared, B_shared, C_local)\n",
" T.copy(C_local, C[bx * block_M, by * block_N])"
]
},
{
"cell_type": "markdown",
"id": "446a1acd",
"metadata": {},
"source": [
"### 用 T.dyn 标记动态 Shape\n",
"\n",
"当某些维度是动态的的时候,可以用 T.dyn 来标记。T.dyn 可以接受一个字符串参数,表示变量的名字"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a38aa95",
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"def gemm_dyn_K(\n",
" A: T.Tensor[[int, T.dyn['K']], T.float16], # noqa: F821\n",
" B: T.Tensor[[T.dyn['K'], int], T.float16], # noqa: F821\n",
"):\n",
" M, K = A.shape\n",
" K, N = B.shape\n",
" C = T.empty((M, N), T.float32)\n",
" gemm_impl(A, B, C, M, N, K, 128, 128, 32)\n",
" return C"
]
},
{
"cell_type": "markdown",
"id": "c60fd346",
"metadata": {},
"source": [
"查看 lazy_jit 的函数签名,其中带有后缀`$` 的是不确定的编译期常量,带有 `$dyn` 的是运行时的变量"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "c6992eb4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'A': TensorAnnot(shape=[A_shape_0$, K$dyn], strides=None, dtype=dtype('float16')),\n",
" 'B': TensorAnnot(shape=[K$dyn, B_shape_1$], strides=None, dtype=dtype('float16'))}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gemm_dyn_K.func.annot"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "fe6cfdc8",
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n",
"B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n",
"C = gemm_dyn_K(A, B)\n",
"C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
]
},
{
"cell_type": "markdown",
"id": "2ee97bf7",
"metadata": {},
"source": [
"### 用 T.StridedTensor 标记带 stride 的 Tensor\n",
"\n",
"标记方法:T.StridedTensor[Shape, Stride, DType],每个 Shape 或 Stride 可以写\n",
"* int: 表示编译期常量\n",
"* T.dyn:表示运行时常量\n",
"\n",
"DType 可以写 None 或 Any"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "9dde1dae",
"metadata": {},
"outputs": [],
"source": [
"from typing import Any\n",
"\n",
"@tilelang.lazy_jit\n",
"def as_contingious(\n",
" A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]\n",
"):\n",
" M, N = A.shape\n",
" B = T.empty((M, N), A.dtype)\n",
" block_M = 128\n",
" block_N = 128\n",
" with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
" T.copy(\n",
" A[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N],\n",
" B[bx * block_M: (bx + 1) * block_M, by * block_N: (by + 1) * block_N]\n",
" )\n",
" return B"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "dec2c0a7",
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 1024, device='cuda')\n",
"B = as_contingious(A[::2, ::2])\n",
"B_ref = A[::2, ::2].contiguous()\n",
"torch.testing.assert_close(B, B_ref)"
]
},
{
"cell_type": "markdown",
"id": "f5fb20d6",
"metadata": {},
"source": [
"## More Annotation"
]
},
{
"cell_type": "markdown",
"id": "890df0a2",
"metadata": {},
"source": [
"### 用 T.ptr 标注 Tensor\n",
"lazy_jit 允许你用 T.ptr 来声明一个 handle,但必须在函数内用 T.match_buffer 给它定义 shape"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "0fc17af6",
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"def gemm_ptr(\n",
" A: T.ptr,\n",
" B: T.ptr,\n",
" M: int,\n",
" N: int,\n",
" K: int,\n",
"):\n",
" A = T.match_buffer(A, (M, K), T.float16)\n",
" B = T.match_buffer(B, (K, N), T.float16)\n",
" C = T.empty((M, N), T.float32)\n",
" gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n",
" return C"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "8e52a554",
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n",
"B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n",
"C = gemm_ptr(A, B, 1024, 256, 512)\n",
"C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
]
},
{
"cell_type": "markdown",
"id": "6b19ef90",
"metadata": {},
"source": [
"### 用 T.int32 标注运行时变量\n",
"\n",
"lazy_jit 允许你用 T.int32 或其他类型来定义运行时变量,这样,你可以写一个完全动态的 gemm,这和 triton 非常相似"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "c1e7598a",
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"def gemm_ptr_dyn(\n",
" A: T.ptr,\n",
" B: T.ptr,\n",
" M: T.int32,\n",
" N: T.int32,\n",
" K: T.int32,\n",
"):\n",
" A = T.match_buffer(A, (M, K), T.float16, strides=(K, 1))\n",
" B = T.match_buffer(B, (K, N), T.float16, strides=(N, 1))\n",
" C = T.empty((M, N), T.float32)\n",
" gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n",
" return C"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "9e9a4c88",
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, 512, dtype=torch.float16, device='cuda')\n",
"B = torch.randn(512, 256, dtype=torch.float16, device='cuda')\n",
"C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n",
"C_ref = (A @ B).float()\n",
"torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
]
},
{
"cell_type": "markdown",
"id": "39166cb4",
"metadata": {},
"source": [
"## 编译与并行编译"
]
},
{
"cell_type": "markdown",
"id": "8c6fbe08",
"metadata": {},
"source": [
"lazyjit 和原来的 jit 都支持并行编译\n",
"\n",
"为了防止 torch.tensor 白白浪费内存,可以使用 T.Tensor 来创建 placeholder"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "7222e57b",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c6d7f05cdfff412e9a527332438f7aa2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Elaborating: 0%| | 0/8 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "14836065a21b41ae8fc34e8763ae49fc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Parallel Compiling: 0%| | 0/8 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[<tilelang.jit.kernel.JITKernel at 0x7f29c0072ed0>,\n",
" <tilelang.jit.kernel.JITKernel at 0x7f29c00882f0>,\n",
" <tilelang.jit.kernel.JITKernel at 0x7f29c00735f0>,\n",
" <tilelang.jit.kernel.JITKernel at 0x7f29c0088890>,\n",
" <tilelang.jit.kernel.JITKernel at 0x7f29c01f94c0>,\n",
" <tilelang.jit.kernel.JITKernel at 0x7f29c0073fe0>,\n",
" <tilelang.jit.kernel.JITKernel at 0x7f29c0070ce0>,\n",
" <tilelang.jit.kernel.JITKernel at 0x7f29c00732f0>]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from itertools import product\n",
"\n",
"def get_configs():\n",
" return [\n",
" {\n",
" 'A': T.Tensor((1024, 1024), T.float32),\n",
" 'B': T.Tensor((1024, 1024), T.float32),\n",
" 'block_M': block_M,\n",
" 'block_N': block_N,\n",
" 'block_K': block_K,\n",
" }\n",
" for block_M, block_N, block_K in product([32, 64], repeat=3)\n",
" ]\n",
"\n",
"gemm.par_compile(get_configs())"
]
},
{
"cell_type": "markdown",
"id": "5160d2cc",
"metadata": {},
"source": [
"## 更便利的 Macro"
]
},
{
"cell_type": "markdown",
"id": "be44afc4",
"metadata": {},
"source": [
"tilelang 的 macro 现在已经升级:\n",
"\n",
"1. 允许用 `T.Ref` 作为 annotation,这类似与 C++ 的引用传递\n",
"2. 允许返回多个值\n",
"3. 允许嵌套,递归"
]
},
{
"cell_type": "markdown",
"id": "79575972",
"metadata": {},
"source": [
"### T.Ref 传递引用\n",
"\n",
"T.Ref 传递的引用可以 var 也可以是 Buffer 的索引"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "90eaa6e5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"# from tvm.script import tir as T\n",
"\n",
"@T.prim_func\n",
"def foo(x_handle: T.handle):\n",
" x = T.match_buffer(x_handle, (2,), strides=(1,))\n",
" # with T.block(\"root\"):\n",
" bx = T.launch_thread(\"blockIdx.x\", 1)\n",
" tx = T.launch_thread(\"threadIdx.x\", 128)\n",
" ty = T.launch_thread(\"threadIdx.y\", 1)\n",
" tz = T.launch_thread(\"threadIdx.z\", 1)\n",
" with T.block(\"tilelang_root\"):\n",
" T.reads()\n",
" idx = T.Buffer((1,), \"int32\", scope=\"local.var\")\n",
" T.writes(x[T.min(1, idx[0]):T.min(1, idx[0]) + (T.max(1, idx[0]) + 1 - T.min(1, idx[0]))])\n",
" T.block_attr({\"tl.local_var_init\": {idx.data: 0}})\n",
" idx = T.alloc_buffer((1,), \"int32\", data=idx.data, scope=\"local.var\")\n",
" x[1] = T.float32(1.0)\n",
" _tmp: T.int32 = idx[0]\n",
" x[_tmp] = T.float32(1.0)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@T.macro\n",
"def macro_with_ref(x: T.Ref):\n",
" x = 1 # noqa: F841\n",
"\n",
"@T.prim_func\n",
"def foo(x: T.Tensor((2,))):\n",
" with T.Kernel(1) as _:\n",
" # 支持常量 index\n",
" macro_with_ref(x[1])\n",
"\n",
" # 也支持变量 index\n",
" idx = T.alloc_var(T.int32, 0)\n",
" macro_with_ref(x[idx])\n",
"\n",
"foo"
]
},
{
"cell_type": "markdown",
"id": "7bb447a2",
"metadata": {},
"source": [
"### 当作参数传递\n",
"\n",
"你可以把 macro 当做参数传递"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "dc7bb779",
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"def element_wise(\n",
" A: T.Tensor[[T.dyn], Any],\n",
" fn,\n",
"):\n",
" N, = A.shape\n",
" B = T.empty((N,), dtype=A.dtype)\n",
" block_N = 128\n",
" with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n",
" for i in T.Parallel(block_N):\n",
" idx = bx * block_N + i\n",
" B[idx] = fn(A[idx])\n",
" return B\n",
"@T.macro\n",
"def add_one(x):\n",
" return x + 1"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "a89fdb44",
"metadata": {},
"outputs": [],
"source": [
"A = torch.randn(1024, device='cuda')\n",
"B = element_wise(A, add_one)\n",
"B_ref = A + 1\n",
"torch.testing.assert_close(B, B_ref)"
]
},
{
"cell_type": "markdown",
"id": "ef6e403a",
"metadata": {},
"source": [
"### Macro 递归\n",
"\n",
"虽然不知道有没有这种需求,但 macro 是可以递归的,但要求终止条件编译期间确定"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "7703cab5",
"metadata": {},
"outputs": [],
"source": [
"@T.macro\n",
"def n31(x, var: T.Ref):\n",
" if x == 1:\n",
" pass\n",
" elif x % 2 == 0:\n",
" var = var // 2\n",
" n31(x // 2, var)\n",
" else:\n",
" var = var * 3 + 1\n",
" n31(x * 3 + 1, var)\n",
"\n",
"@tilelang.lazy_jit\n",
"def foo(A: T.Tensor[[1], T.int32], n: int):\n",
" with T.Kernel(1) as _:\n",
" n31(n, A[0])\n"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "542ddd4e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([18], device='cuda:0', dtype=torch.int32)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"A = torch.tensor([100], dtype=torch.int32, device='cuda')\n",
"foo(A, 5)\n",
"A"
]
},
{
"cell_type": "markdown",
"id": "dc30c2d2",
"metadata": {},
"source": [
"### Macro 返回多个值"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5a2388f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"# from tvm.script import tir as T\n",
"\n",
"@T.prim_func\n",
"def foo():\n",
" # with T.block(\"root\"):\n",
" x = T.launch_thread(\"blockIdx.x\", 32)\n",
" tx = T.launch_thread(\"threadIdx.x\", 128)\n",
" ty = T.launch_thread(\"threadIdx.y\", 1)\n",
" tz = T.launch_thread(\"threadIdx.z\", 1)\n",
" with T.block(\"tilelang_root\"):\n",
" T.reads()\n",
" T.writes()\n",
" s: T.int32 = T.sin(x)\n",
" c: T.int32 = T.cos(x)\n",
" a: T.int32 = s + c\n",
" b: T.int32 = s - c\n",
" T.evaluate(0)"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@T.macro\n",
"def sincos(x):\n",
" return T.sin(x), T.cos(x)\n",
"\n",
"@T.prim_func\n",
"def foo():\n",
" with T.Kernel(32) as x:\n",
" s, c = sincos(x)\n",
" a = s + c # noqa: F841\n",
" b = s - c # noqa: F841\n",
"foo"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "tilelang-dev_0",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
......@@ -252,9 +252,9 @@ def test_marco_return():
c = macro_return_expr(4.0)
d = macro_apply_func(5.0, lambda x: x * 2.0)
check(a, (int, float, T.PrimExpr))
check(b, T.PrimExpr)
check(c, T.PrimExpr)
check(d, T.PrimExpr)
check(b, (int, float, T.PrimExpr))
check(c, (int, float, T.PrimExpr))
check(d, (int, float, T.PrimExpr))
def test_prim_func_generator():
......
from dataclasses import dataclass, field
import tilelang.testing
import tilelang
import tilelang.language as T
from typing import Any
from itertools import product
import torch
def _gemm_impl():
@T.macro
def gemm_impl(
A: T.Tensor[[int, int], Any],
B: T.Tensor[[int, int], Any],
C: T.Tensor[[int, int], Any],
out_dtype: T.dtype,
block_M: int,
block_N: int,
block_K: int,
):
dtype = A.dtype
M, K = A.shape
K, N = B.shape
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), out_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(A[bx * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, by * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[bx * block_M, by * block_N])
return gemm_impl
def test_jit2_gemm_annot():
@tilelang.lazy_jit
def gemm(
A: T.Tensor[[int, int], Any],
B: T.Tensor[[int, int], Any],
out_dtype: T.dtype = T.float32,
block_M: int = 64,
block_N: int = 64,
block_K: int = 32,
):
M, K = A.shape
K, N = B.shape
C = T.empty(M, N, dtype=out_dtype)
_gemm_impl()(A, B, C, out_dtype, block_M, block_N, block_K)
return C
    prod = list(product([T.float16, T.float32], [T.float32]))
gemm.par_compile([{
'A': T.Tensor((1024, 1024), dtype=in_dtype),
'B': T.Tensor((1024, 1024), dtype=in_dtype),
'out_dtype': out_dtype
} for in_dtype, out_dtype in prod])
for in_dtype, out_dtype in prod:
in_dtype = in_dtype.torch()
out_dtype = out_dtype.torch()
A = torch.randn(1024, 1024, dtype=in_dtype, device='cuda')
B = torch.randn(1024, 1024, dtype=in_dtype, device='cuda')
        C_ref = (A @ B).to(out_dtype)
C = gemm(A, B)
torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)
def test_jit2_gemm_ptr():
@tilelang.lazy_jit
def gemm_ptr(
A: T.ptr,
B: T.ptr,
C: T.ptr,
M: int,
N: int,
K: int,
dtype: T.dtype,
out_dtype: T.dtype,
block_M: int = 64,
block_N: int = 64,
block_K: int = 32,
):
A = T.make_tensor(A, (M, K), dtype)
B = T.make_tensor(B, (K, N), dtype)
C = T.make_tensor(C, (M, N), out_dtype)
_gemm_impl()(A, B, C, out_dtype, block_M, block_N, block_K)
    prod = list(product([T.float16, T.float32], [T.float32]))
gemm_ptr.par_compile([{
'A': T.ptr(),
'B': T.ptr(),
'C': T.ptr(),
'M': 1024,
'N': 1024,
'K': 1024,
'dtype': in_dtype,
'out_dtype': out_dtype
} for in_dtype, out_dtype in prod])
for in_dtype, out_dtype in prod:
in_dtype = in_dtype.torch()
out_dtype = out_dtype.torch()
A = torch.randn(1024, 1024, dtype=in_dtype, device='cuda')
B = torch.randn(1024, 1024, dtype=in_dtype, device='cuda')
        C_ref = (A @ B).to(out_dtype)
C = torch.empty(1024, 1024, dtype=out_dtype, device='cuda')
gemm_ptr(A, B, C, 1024, 1024, 1024, in_dtype, out_dtype)
torch.testing.assert_close(C, C_ref, atol=1e-2, rtol=1e-2)
def test_jit2_annot():
from tilelang.language.v2.annot import Annot, ArgVarTable
from tilelang.language.v2.builder import Builder
import traceback
@dataclass
class AnnotTest:
annot: Annot
promote: Any
match_ok: list[Any] = field(default_factory=list)
match_ng: list[Any] = field(default_factory=list)
tests = [
AnnotTest(
annot=T.Tensor[[int, int], T.float32],
promote=False,
match_ok=[torch.randn(1, 1, dtype=torch.float32),
T.Tensor((1, 1), dtype=T.float32)],
match_ng=[
torch.randn(1, 1, dtype=torch.float16),
T.Tensor(1, dtype=T.float32),
T.Tensor((1, 1), dtype=T.float16),
],
),
AnnotTest(
annot=T.Tensor[[int], Any],
promote=False,
match_ok=[
torch.randn(12, dtype=torch.float32),
torch.randn(12, dtype=torch.float16),
T.Tensor((1,), dtype=T.float32),
T.Tensor((1,), dtype=T.float16),
],
match_ng=[torch.randn((1, 1), dtype=torch.float32),
T.Tensor((1, 1), dtype=T.float16)]),
AnnotTest(
annot=T.Tensor[[int, 1], Any],
promote=False,
match_ok=[
torch.randn(12, 1, dtype=torch.float32),
torch.randn(12, 1, dtype=torch.float16),
T.Tensor((12, 1), T.float32),
T.Tensor((12, 1), T.float16),
],
match_ng=[torch.randn(12, 12, dtype=torch.float32),
T.Tensor((12, 12), T.float32)]),
AnnotTest(
annot=T.Tensor[[T.dyn, 1], Any],
promote=False,
match_ok=[
torch.randn(12, 1, dtype=torch.float32),
torch.randn(12, 1, dtype=torch.float16),
T.Tensor((12, 1), T.float32),
T.Tensor((12, 1), T.float16),
],
match_ng=[torch.randn(12, 12, dtype=torch.float32),
T.Tensor((12, 12), T.float32)]),
AnnotTest(
annot=T.Tensor[[1024, 1024], T.float32],
promote=True,
),
AnnotTest(annot=T.dyn[int, 'X'], promote=False, match_ok=[1, 2, 3, 4]),
AnnotTest(annot=T.dyn, promote=False, match_ok=[1, 2, 3, 4])
]
for test in tests:
promote = test.annot.promote()
promoted = promote is not None
if promoted != test.promote:
raise AssertionError(
f'Promote mismatch for {test.annot}: expected {test.promote}, got {promoted}')
with Builder().prim_func('_test'):
for match_ok in test.match_ok:
try:
vt = ArgVarTable()
test.annot.create_prim_func_arg('arg', match_ok, vt)
except Exception as e:
traceback.print_exc()
raise AssertionError(
f'Match failed for {test.annot} with value {match_ok}: {e}') from e
for match_ng in test.match_ng:
try:
vt = ArgVarTable()
test.annot.create_prim_func_arg('arg', match_ng, vt)
raise AssertionError(
f'Match unexpectedly succeeded for {test.annot} with value {match_ng}')
except Exception:
pass
def test_jit2_many_annot():
@T.macro
def copy_impl(A, B):
M, N = A.shape
M_, N_ = B.shape
assert M == M_, f"M mismatch {M} {M_}"
assert N == N_, f"N mismatch {N} {N_}"
# assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}"
with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by):
T.copy(A[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128], B[bx * 128:bx * 128 + 128,
by * 128:by * 128 + 128])
@tilelang.lazy_jit
def copy1(
A: T.Tensor[[int, int], T.float32],
B: T.Tensor[[int, int], T.float32],
):
copy_impl(A, B)
@tilelang.lazy_jit
def copy2(
A: T.Tensor[[128, 128], T.float32],
B: T.Tensor[[128, 128], T.float32],
):
copy_impl(A, B)
@tilelang.lazy_jit
def copy3(
A: T.Tensor[[int, 128], T.float32],
B: T.Tensor[[int, 128], T.float32],
):
copy_impl(A, B)
@tilelang.lazy_jit
def copy4(
A: T.Tensor[[T.dyn, int], T.float32],
B: T.Tensor[[T.dyn, int], T.float32],
):
copy_impl(A, B)
@tilelang.lazy_jit
def copy5(
A: T.StridedTensor[[int, int], [int, int], T.float32],
B: T.StridedTensor[[int, int], [int, int], T.float32],
):
copy_impl(A, B)
@tilelang.lazy_jit
def copy6(
A: T.StridedTensor[[T.dyn, int], [int, int], T.float32],
B: T.StridedTensor[[T.dyn, int], [int, int], T.float32],
):
copy_impl(A, B)
for copy in [copy1, copy2, copy3, copy4]:
A = torch.randn(128, 128, device='cuda')
B = torch.empty(128, 128, device='cuda')
copy(A, B)
assert torch.equal(B, A)
for copy in [copy5, copy6]:
A = torch.randn(128, 2, 128, 2, device='cuda')
B = torch.randn(128, 2, 128, 2, device='cuda')
copy(A[:, 0, :, 0], B[:, 0, :, 0])
assert torch.equal(A[:, 0, :, 0], B[:, 0, :, 0])
def test_jit2_return():
@T.macro
def copy_impl(A):
M, N = A.shape
B = T.empty(M, N, dtype=A.dtype)
M, N = A.shape
M_, N_ = B.shape
assert M == M_, f"M mismatch {M} {M_}"
assert N == N_, f"N mismatch {N} {N_}"
# assert tuple(A.shape) == tuple(B.shape), f"Invalid tensor shape: {A.shape}, {B.shape}"
with T.Kernel(T.ceildiv(M, 128), T.ceildiv(N, 128), threads=128) as (bx, by):
T.copy(A[bx * 128:bx * 128 + 128, by * 128:by * 128 + 128], B[bx * 128:bx * 128 + 128,
by * 128:by * 128 + 128])
return B
@tilelang.lazy_jit
def copy0(A: T.Tensor[[int, int], Any]):
return copy_impl(A)
@tilelang.lazy_jit
def copy1(A: T.Tensor[[int, int], T.float32],):
return copy_impl(A)
@tilelang.lazy_jit
def copy2(A: T.Tensor[[128, 128], T.float32],):
return copy_impl(A)
@tilelang.lazy_jit
def copy3(A: T.Tensor[[int, 128], T.float32],):
return copy_impl(A)
@tilelang.lazy_jit
def copy4(A: T.Tensor[[T.dyn, int], T.float32],):
return copy_impl(A)
@tilelang.lazy_jit
def copy5(A: T.StridedTensor[[int, int], [int, int], T.float32],):
return copy_impl(A)
@tilelang.lazy_jit
def copy6(A: T.StridedTensor[[T.dyn, int], [int, int], T.float32],):
return copy_impl(A)
for copy in [copy0, copy1, copy2, copy3, copy4]:
A = torch.randn(128, 128, device='cuda')
B = copy(A)
assert torch.equal(B, A)
for copy in [copy5, copy6]:
A = torch.randn(128, 2, 128, 2, device='cuda')
B = copy(A[:, 0, :, 0])
assert torch.equal(A[:, 0, :, 0], B)
def test_jit2_deepseek_deepgemm():
@tilelang.lazy_jit
def deep_gemm(
A: T.Tensor[[int, int], T.float8_e4m3],
B: T.Tensor[[int, int], T.float8_e4m3],
scales_a: T.Tensor[[int, int], T.float32],
scales_b: T.Tensor[[int, int], T.float32],
out_dtype: T.dtype = T.bfloat16,
accum_dtype: T.dtype = T.float32,
block_N: int = 128,
block_M: int = 128,
block_K: int = 128,
):
# A: [M, K]
# B: [N, K]
# scales_a: [M, K // 128]
# scales_b: [N, K // 128]
# C: [M, N]
group_size = 128
in_dtype = A.dtype
M, K = A.shape
N, K = B.shape
C = T.empty(M, N, dtype=out_dtype)
        assert out_dtype in [
            T.bfloat16, T.float32
        ], f"Expect out_dtype to be one of [T.bfloat16, T.float32], got {out_dtype}"
        assert scales_a.shape == [M, T.ceildiv(K, group_size)], \
            f"Expect scales_a shape to be {[M, T.ceildiv(K, group_size)]}"
        assert scales_b.shape == [N, T.ceildiv(K, group_size)], \
            f"Expect scales_b shape to be {[N, T.ceildiv(K, group_size)]}"
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), in_dtype)
B_shared = T.alloc_shared((block_N, block_K), in_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
scale_C_shared = T.alloc_shared((block_M,), T.float32)
C_local = T.alloc_fragment((block_M, block_K), accum_dtype)
C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
T.use_swizzle(panel_size=10)
T.clear(C_local)
T.clear(C_local_accum)
K_iters = T.ceildiv(K, block_K)
for k in T.Pipelined(K_iters, num_stages=4):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
Scale_B = scales_b[bx * block_N // group_size, k]
for i in T.Parallel(block_M):
scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
for i, j in T.Parallel(block_M, block_N):
C_local_accum[i, j] += C_local[i, j] * scale_C_shared[i]
T.clear(C_local)
T.copy(C_local_accum, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return C
# def ceildiv(a, b):
# return (a + b - 1) // b
# def ref_deepgemm_fp8(A_fp8, B_fp8, A_scale, B_scale, out_dtype):
# # A_scale: (M, K//128) ==> (M//128, K//128, 128)
# # B_scale: (N//128, K//128) ==> (N//128, K//128, 128)
# # A_fp8: (M, K)
# # B_fp8: (N, K)
# # out_dtype: float16 or float32
# # return C: (M, N)
# M, N, K = A_fp8.shape[0], B_fp8.shape[0], A_fp8.shape[1]
# A_scales = A_scale.view(M // 128, 128, K // 128).permute(0, 2, 1)
# B_scales = B_scale.repeat_interleave(128, dim=1).view(N // 128, K // 128, 128)
# C = torch.zeros(M, N, device="cuda", dtype=out_dtype)
# c_acc = torch.zeros(128, 128, device="cuda", dtype=torch.float32)
# for i in range(ceildiv(M, 128)):
# for j in range(ceildiv(N, 128)):
# c_acc.zero_()
# for k in range(ceildiv(K, 128)):
# c = torch._scaled_mm(
# A_fp8[i * 128:(i + 1) * 128, k * 128:(k + 1) * 128],
# B_fp8[j * 128:(j + 1) * 128, k * 128:(k + 1) * 128].T,
# scale_a=A_scales[i, k].view(128, 1).contiguous(),
# scale_b=B_scales[j, k].view(1, 128).contiguous(),
# out_dtype=torch.bfloat16)
# c_acc += c.to(torch.float32)
# C[i * 128:(i + 1) * 128, j * 128:(j + 1) * 128] = c_acc.to(out_dtype)
# return C
# M, N, K = 1024, 1024, 8192
# A = torch.randn((M, K), dtype=torch.float8_e4m3fn, )
if __name__ == '__main__':
tilelang.testing.main()
......@@ -120,7 +120,7 @@ def _load_tile_lang_lib():
if env.SKIP_LOADING_TILELANG_SO == "0":
_LIB, _LIB_PATH = _load_tile_lang_lib()
from .jit import jit, JITKernel, compile # noqa: F401
from .jit import jit, lazy_jit, JITKernel, compile, par_compile # noqa: F401
from .profiler import Profiler # noqa: F401
from .cache import clear_cache # noqa: F401
......
......@@ -141,6 +141,10 @@ def extrac_params(func: tir.PrimFunc) -> list[KernelParam]:
if var in func.buffer_map:
tensor_types.append(KernelParam.from_buffer(func.buffer_map[var]))
else:
if var.dtype == 'handle':
raise ValueError(
f'Handle parameter {var} must be mapped to a buffer.\n'
f'Please use T.tensor({var.name}, shape=..., dtype=...) to map it to a buffer.')
tensor_types.append(KernelParam.from_var(var))
return tensor_types
......
......@@ -16,13 +16,15 @@ from typing import (
Literal,
)
from collections.abc import Iterable
# Python 3.9 compatibility for ParamSpec
try:
from typing import ParamSpec
except ImportError: # Python < 3.10
from typing_extensions import ParamSpec
from tilelang import tvm as tvm
from tilelang.language.v2 import PrimFunc
from tilelang.language.v2 import PrimFunc, PrimFuncCreater, prim_func
from tilelang.language.v2.annot import Annot
from tvm.target import Target
from tilelang.jit.kernel import JITKernel
......@@ -40,6 +42,7 @@ logger = getLogger(__name__)
_P = ParamSpec('_P')
_KP = ParamSpec('_KP')
_T = TypeVar('_T')
_Ret = TypeVar('_Ret')
def compile(
......@@ -74,10 +77,19 @@ def compile(
Additional keyword arguments to pass to the Compiler PassContext.
Refer to `tilelang.transform.PassConfigKey` for supported options.
"""
assert isinstance(func, PrimFunc), f"target function must be a PrimFunc but got {type(func)}"
if isinstance(compile_flags, str):
compile_flags = [compile_flags]
if hasattr(func, 'out_idx_override'):
if func.out_idx_override is not None and out_idx is not None:
            raise ValueError(
                "Out index conflict: out_idx is specified and the prim_func has returned `T.empty` tensors"
            )
out_idx = func.out_idx_override or out_idx
# This path is not a performance critical path, so we can afford to convert the target.
target = Target(determine_target(target))
......@@ -176,8 +188,76 @@ def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
@dataclass
class JITImpl(Generic[_P, _KP, _T]):
func: Callable[_P, _T] | PrimFunc[_KP, _T]
class JITImpl(Generic[_P, _KP, _T, _Ret]):
'''
Detailed Just-In-Time wrapper for TileLang programs.
This dataclass encapsulates the configuration and runtime helpers used by the
    top-level `jit` and `lazy_jit` decorators. It represents a configured JIT
"factory" that can (a) elaborate TileLang/PrimFunc creators into concrete
TIR (PrimFunc), (b) compile those TIR functions into runnable kernels via
the TVM bridge, (c) cache compiled kernels keyed by call-site arguments
(and optional tuning parameters), and (d) provide parallel compilation
helpers for batch autotuning workflows.
Attributes
----------
out_idx : list[int] | int | None
Which output tensor(s) of the compiled kernel should be returned to the
caller. Accepts a single index, a list of indices, or None to return all.
execution_backend : Literal["dlpack", "ctypes", "cython"]
Backend used for exchanging arguments and executing the generated kernel.
target : str | tvm.target.Target
TVM compilation target (e.g. "cuda", "llvm", or "auto").
target_host : str | tvm.target.Target | None
Host target used for cross-compilation, or None to infer/default.
verbose : bool
Enable verbose messages during compilation/build.
pass_configs : dict[str, Any] | None
Extra TVM pass configuration options forwarded to the compiler's
PassContext.
debug_root_path : str | None
If provided, compiled kernel source and the elaborated Python program
are written to this directory to ease debugging and inspection.
compile_flags : list[str] | str | None
Additional flags passed to the compiler. A single string will be converted
to a single-element list.
func_source : str
Original Python source string from which the PrimFunc or creator was
derived. Used for diagnostics and debug dumps.
signature : inspect.Signature
Function signature of the original Python function (useful for tooling).
    lazy_jit : bool
        Indicates whether the object wraps a lazy-JIT PrimFunc creator (True) or a
        plain callable / PrimFunc (False). Lazy-JIT mode enables argument conversion
        hooks and a distinct cache keying strategy.
func : Callable | PrimFunc | PrimFuncCreater
The underlying object: either a user function that returns a PrimFunc
(creator), a PrimFuncCreater, or an already-constructed PrimFunc.
For presentation/readability the function is stored last in the dataclass.
Behavioral summary
------------------
- get_tir(*args, **kwargs)
Converts provided call-site arguments into a concrete PrimFunc. If the
wrapped object is a PrimFuncCreater or a user callable, it is invoked
with the given arguments. If the wrapped object is already a PrimFunc,
it is returned as-is.
- compile(...)
A convenience wrapper that elaborates and immediately compiles a single
PrimFunc into a JITKernel using the module-level `compile` function.
When `debug_root_path` is set, the compiled C kernel and the source
Python program are saved for inspection.
- par_compile(configs, ...)
Accepts an iterable of configs (either dicts mapping keyword args or
tuples mapping to positional args). Each config is elaborated to a
PrimFunc and the resulting set is compiled in parallel via the
module-level `par_compile` helper. Returns a list of JITKernel objects
in the same order as the provided configs.
'''
out_idx: list[int] | int | None
execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"]
target: str | Target
......@@ -188,6 +268,14 @@ class JITImpl(Generic[_P, _KP, _T]):
compile_flags: list[str] | str | None
func_source: str
signature: inspect.Signature
lazy_jit: bool
# place func at the last element for better __repr__
func: Callable[_P, _T] | PrimFunc[_KP, _T]
@property
def annot(self) -> dict[str, Annot]:
        assert self.lazy_jit, "annot is only supported in @tilelang.lazy_jit"
return self.func.func_annot.annots
def __post_init__(self):
if self.debug_root_path is not None and not path.isabs(self.debug_root_path):
......@@ -197,21 +285,47 @@ class JITImpl(Generic[_P, _KP, _T]):
except NameError:
self.debug_root_path = path.abspath(self.debug_root_path)
self._kernel_cache: dict[tuple, Kernel] = {}
self._tuner_cache: dict[tuple, Kernel] = {}
def get_tir(self, *args: _P.args, **kwargs: _P.kwargs) -> PrimFunc[_KP, _T]:
program_result_source = self.func
if isinstance(program_result_source, PrimFunc):
program_result = program_result_source
elif callable(program_result_source):
program_result = program_result_source(*args, **kwargs)
"""
Retrieve a TIR (Tensor Intermediate Representation) PrimFunc from the stored callable or object.
"""
if isinstance(self.func, PrimFuncCreater):
tir = self.func(*args, **kwargs)
elif isinstance(self.func, PrimFunc):
tir = self.func
elif callable(self.func):
tir = self.func(*args, **kwargs)
else:
raise ValueError(f"Invalid function type: {type(program_result_source)}")
return program_result
raise ValueError(f"Invalid function type: {type(self.func)}")
assert isinstance(tir, PrimFunc), f"target function must be a PrimFunc but got {type(tir)}"
return tir
def par_compile(self,
configs: Iterable[dict[str, Any] | tuple[str, Any]],
num_workers: int = None,
ignore_error: bool = False) -> list[JITKernel[_KP, _T]]:
"""
Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels.
Parameters
----------
configs : Iterable[Union[dict[str, Any], tuple[Any, ...]]]
The configurations to elaborate and compile. Each config can be either
a dictionary mapping keyword arguments to values, or a tuple of positional
arguments.
num_workers : int, optional
Number of parallel workers to use for compilation. Defaults to None,
which lets the system decide.
ignore_error : bool, optional
If True, compilation errors for individual configs will be logged
as warnings and the corresponding result will be None. If False,
any compilation error will raise an exception. Defaults to False.
Returns
-------
List[JITKernel]
A list of compiled JITKernel objects corresponding to the provided configs.
"""
configs = list(configs)
funcs = []
for cfg in tqdm(configs, desc='Elaborating'):
......@@ -233,7 +347,7 @@ class JITImpl(Generic[_P, _KP, _T]):
num_workers=num_workers,
ignore_error=ignore_error)
def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel[_KP, _T]:
def compile(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret:
func = self.get_tir(*args, **kwargs)
kernel_result = compile(
func,
......@@ -261,12 +375,34 @@ class JITImpl(Generic[_P, _KP, _T]):
return kernel_result
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel[_KP, _T]:
def parse_cache_key(self, *args: _P.args, **kwargs: _P.kwargs):
if isinstance(self.func, PrimFuncCreater):
tune_params = kwargs.pop('__tune_params', {})
return self.func.func_annot.parse_key(*args, **kwargs, **tune_params)
else:
tune_params = kwargs.pop('__tune_params', {})
key_args_tuple = args
key_kwargs_tuple = tuple(sorted(kwargs.items()))
tuned_key_kwargs_tuple = tuple(sorted(tune_params.items()))
key = (key_args_tuple, key_kwargs_tuple, tuned_key_kwargs_tuple)
return key
def convert_kernel_args(self, *args: _P.args, **kwargs: _P.kwargs):
if isinstance(self.func, PrimFuncCreater):
tune_params = kwargs.pop('__tune_params', {})
return self.func.func_annot.convert_to_kernel_args(*args, **kwargs, **tune_params)
else:
raise NotImplementedError(
"convert_arg_to_kernel_args is only implemented for PrimFuncCreater.")
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret:
# Separate out the tuning parameters from the user's kwargs
tune_params = kwargs.pop('__tune_params', {})
# Whether to return the compile arguments (out_idx, target, target_host, etc.) for autotuner cache
return_compile_arguments = kwargs.pop('__return_compile_arguments', False)
if return_compile_arguments:
logger.warning(
"`__return_compile_arguments` is deprecated and will be removed in future versions."
)
compile_args = {
'out_idx': self.out_idx,
'execution_backend': self.execution_backend,
......@@ -278,19 +414,27 @@ class JITImpl(Generic[_P, _KP, _T]):
}
return compile_args
key = self.parse_cache_key(*args, **kwargs)
kernel = self._kernel_cache.get(key, None)
if kernel is None:
kernel = self.compile(*args, **kwargs, **tune_params)
self._kernel_cache[key] = kernel
if self.lazy_jit:
args = self.func.func_annot.convert_to_kernel_args(*args, **kwargs, **tune_params)
return kernel(*args)
else:
return kernel
ExecutionBackend = Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"]
@overload
def jit(func: Callable[_P, PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]:
...
......@@ -300,13 +444,12 @@ def jit(
out_idx: Any = None,
target: str | Target = "auto",
target_host: str | Target = None,
execution_backend: ExecutionBackend = "auto",
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
debug_root_path: str | None = None,
compile_flags: list[str] | str | None = None
) -> Callable[[Callable[_P, PrimFunc[_KP, _T]]], JITImpl[_P, _KP, _T, JITKernel[_KP, _T]]]:
...
......@@ -316,8 +459,7 @@ def jit( # This is the new public interface
out_idx: Any = None,
target: str | Target = "auto",
target_host: str | Target = None,
execution_backend: ExecutionBackend = "auto",
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
debug_root_path: str | None = None,
......@@ -358,12 +500,12 @@ def jit( # This is the new public interface
compile_flags = [compile_flags]
def decorator(func: Callable[_P, _T]) -> JITImpl[_P, _T]:
if isinstance(func, (PrimFunc, PrimFuncCreater)):
orig_func = func.orig_func
else:
orig_func = func
return JITImpl(
func=func,
out_idx=out_idx,
execution_backend=execution_backend,
target=target,
......@@ -374,9 +516,70 @@ def jit( # This is the new public interface
compile_flags=compile_flags,
func_source=inspect.getsource(orig_func),
signature=inspect.signature(orig_func),
lazy_jit=False)
if func is not None:
return decorator(func)
else:
return decorator
@overload
def lazy_jit(func: Callable[_KP, _T]) -> JITImpl[_KP, _KP, _T, _T]:
...
@overload
def lazy_jit(
*,
out_idx: Any = None,
target: str | Target = "auto",
target_host: str | Target = None,
execution_backend: ExecutionBackend = "auto",
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
debug_root_path: str | None = None,
compile_flags: list[str] | str | None = None
) -> Callable[[Callable[_KP, _T]], JITImpl[_KP, _KP, _T, _T]]:
...
def lazy_jit(
func: Callable[_P, _T] | PrimFunc | None = None,
*, # Indicates subsequent arguments are keyword-only
target: str | Target = "auto",
target_host: str | Target = None,
execution_backend: ExecutionBackend = "auto",
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
debug_root_path: str | None = None,
compile_flags: list[str] | str | None = None,
):
if isinstance(compile_flags, str):
compile_flags = [compile_flags]
compile_args = dict(
out_idx=None,
execution_backend=execution_backend,
target=target,
target_host=target_host,
verbose=verbose,
pass_configs=pass_configs,
debug_root_path=debug_root_path,
compile_flags=compile_flags)
def decorator(func: Callable[_P, _T]):
pf: PrimFunc[_P, _T] | PrimFuncCreater[_P, _T] = prim_func(func, generator=True)
# if isinstance(pf, PrimFunc):
# compile_args.pop('debug_root_path', None)
# return compile(pf, **compile_args)
# else:
return JITImpl(
func=pf,
**compile_args,
func_source=inspect.getsource(pf.orig_func),
signature=inspect.signature(pf.orig_func),
lazy_jit=True)
return decorator(func) if func is not None else decorator
......@@ -106,6 +106,9 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
params = func.params
buffer_map = func.buffer_map
dynamic_symbolic_map = {}
for i, param in enumerate(params):
if isinstance(param, tir.Var) and (param not in dynamic_symbolic_map):
dynamic_symbolic_map[param] = (2, i, -1)
for i, param in enumerate(params):
if param in buffer_map:
buffer = buffer_map[param]
......@@ -217,7 +220,14 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
if (str(s) == str(key)):
ref_id, ref_tensor_idx, ref_shape_idx = dynamic_symbolic_map[
key]
shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx])
if ref_id == 2:
shape.append(inputs[ref_tensor_idx])
elif ref_id == 0:
shape.append(
tensor_list[ref_tensor_idx].shape[ref_shape_idx])
elif ref_id == 1:
shape.append(
tensor_list[ref_tensor_idx].stride()[ref_shape_idx])
else: # Already converted to Python int during initialization
shape.append(s)
......
......@@ -13,16 +13,15 @@ from . import overrides as _overrides # noqa: F401
from .v2 import * # noqa: F401
from .tir.ir import * # noqa: F401
from tilelang.layout import Layout, Fragment # noqa: F401
from .proxy import ptr, make_tensor # noqa: F401
from .v2.annot import (
Buffer, # noqa: F401
Tensor, # noqa: F401
StridedTensor, # noqa: F401
FragmentBuffer, # noqa: F401
SharedBuffer, # noqa: F401
LocalBuffer, # noqa: F401
Ref, # noqa: F401
dyn, # noqa: F401
)
from .loop import (
Parallel, # noqa: F401
......@@ -56,6 +55,7 @@ from .allocate import (
alloc_wgmma_desc, # noqa: F401
alloc_tcgen05_smem_desc, # noqa: F401
alloc_tcgen05_instr_desc, # noqa: F401
empty, # noqa: F401
)
from .copy import copy, c2d_im2col # noqa: F401
from .gemm import GemmWarpPolicy, gemm, gemm_v1, gemm_v2 # noqa: F401
......
......@@ -14,8 +14,7 @@ Each function takes shape and dtype parameters and returns a TVM buffer object
with the appropriate memory scope.
"""
from __future__ import annotations
from typing import TypeVarTuple, TypeVar, overload, Literal, Unpack, Callable
from tilelang import tvm as tvm
from tvm.script import tir as T
from tvm.tir import PrimExpr
......@@ -23,9 +22,16 @@ from tvm.script.parser.tir import block_attr
from tvm.tir.buffer import Buffer
from tvm.tir.expr import FloatImm, IntImm
from .v2.dtypes import dtype as tl_dtype
from .v2.builder import OutTensor
from .v2.annot import Tensor, SharedBuffer, LocalBuffer, FragmentBuffer
_Shapes = TypeVarTuple('_Shapes')
_DType = TypeVar('_DType')
def alloc_shared(shape: tuple[Unpack[_Shapes]],
dtype: _DType,
scope="shared.dyn") -> SharedBuffer[Callable[[Unpack[_Shapes]]], _DType]:
"""Allocate a shared memory buffer for inter-thread communication.
Args:
......@@ -43,7 +49,9 @@ def alloc_shared(shape, dtype, scope="shared.dyn"):
return T.alloc_buffer(shape, dtype, scope=scope)
def alloc_local(shape: tuple[Unpack[_Shapes]],
dtype: _DType,
scope="local") -> LocalBuffer[Callable[[Unpack[_Shapes]]], _DType]:
"""Allocate a local memory buffer for thread-private storage.
Args:
......@@ -57,7 +65,9 @@ def alloc_local(shape, dtype, scope="local"):
return T.alloc_buffer(shape, dtype, scope=scope)
def alloc_fragment(shape: tuple[Unpack[_Shapes]],
dtype: _DType,
scope="local.fragment") -> FragmentBuffer[Callable[[Unpack[_Shapes]]], _DType]:
"""Allocate a fragment memory buffer for specialized operations.
Args:
......@@ -256,3 +266,21 @@ def alloc_tcgen05_instruction_desc(dtype: str = "uint32"):
# Alias: short name consistent with imports
def alloc_tcgen05_instr_desc(dtype: str = "uint32"):
return alloc_tcgen05_instruction_desc(dtype)
@overload
def empty(shape: tuple[Unpack[_Shapes]],
dtype: str = 'float32') -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
...
def empty(*shape: Unpack[_Shapes],
dtype: str = 'float32') -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
return OutTensor(shape[0], dtype)
elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and isinstance(shape[1], str):
return OutTensor(shape[0], shape[1])
elif all([isinstance(x, (int, PrimExpr)) for x in shape]):
return OutTensor(shape, dtype)
else:
raise RuntimeError(f'Invalid shape {shape}')
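# Call forms accepted above (illustrative sketch; `m`, `n` stand for ints or
# PrimExprs already in scope inside a prim_func body):
# >>> empty((m, n))                  # shape tuple, default float32
# >>> empty((m, n), 'float16')       # shape tuple plus dtype string
# >>> empty(m, n, dtype='float16')   # unpacked dims with keyword dtype
# Each form returns an OutTensor placeholder that the builder later binds to an
# output buffer argument of the generated PrimFunc.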
from .builder import prim_func, macro, PrimFunc, PrimFuncCreater, Ref # noqa: F401
from .dtypes import *
from __future__ import annotations
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
from tvm import tir
from tvm.ir.expr import PrimExpr
from tvm.script.ir_builder.tir import buffer
from typing import Any, Callable, Literal, TypeVar, ParamSpec, Generic, TypeVarTuple, Unpack, TYPE_CHECKING, _GenericAlias, Self
from collections.abc import Sequence
from .dtypes import AnyDType
from . import dtypes as dt
import tvm.script.ir_builder.tir as tb_tir
from tvm.script.ir_builder import IRBuilder
import torch
import inspect
_Shapes = TypeVarTuple('_Shapes')
_Shape = ParamSpec('_Shape')
_Stride = ParamSpec('_Stride')
_DType = TypeVar('_DType')
Scope = Literal['global', 'shared.dyn', 'local', 'local.fragment']
class Annot(ABC):
'''
Base class for tilelang kernel annotations.
Tilelang kernel annotations specify how to interpret each argument of the JIT kernel.
They provide three main functionalities:
1. determine whether the argument is a kernel argument (i.e., needs to be passed at kernel launch time)
2. parse the argument value into a hash key for jit caching
3. convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation
'''
def is_kernel_arg(self) -> bool:
'''
Determine whether the argument is a kernel argument (i.e., needs to be passed at kernel launch time)
'''
return False
@abstractmethod
def with_name(self: Self, name) -> Self:
pass
@abstractmethod
def get_key_parser(self) -> Callable[[str, Any], tuple[Any, ...]]:
'''
Return a parser function that converts the argument value into a hash key for jit caching
'''
@abstractmethod
def create_prim_func_arg(self, name: str, value: Any, vt: ArgVarTable) -> tir.Var | tir.Buffer:
'''
Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation
'''
def promote(self) -> TIRAnnot | None:
'''
Try to promote the annotation into a TIRAnnot if possible.
Return None if not promotable.
'''
return None
@dataclass
class ArgVarTable:
'''
ArgVarTable is used to manage the mapping from argument names to tir.Var objects
'''
var_tab: dict[str, tir.Var] = field(default_factory=dict)
tmp_name_idx: int = 0
def get_or_create_var(self, name: str, dtype: dt.dtype) -> tir.Var:
if not name:
name = self.create_tmp_name()
if name not in self.var_tab:
self.var_tab[name] = tir.Var(name, dtype)
return self.var_tab[name]
def create_tmp_name(self) -> str:
name = f'varg_{self.tmp_name_idx}'
self.tmp_name_idx += 1
return name
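# Sketch of the caching behaviour above: repeated lookups for the same name return
# the identical tir.Var, so shared symbolic dims stay tied to one variable.
# >>> vt = ArgVarTable()
# >>> vt.get_or_create_var('m', 'int32') is vt.get_or_create_var('m', 'int32')
# True
# >>> vt.get_or_create_var('', 'int32').name   # unnamed args get a fresh temp name
# 'varg_0'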
@dataclass
class Value(Annot):
kind: Literal['static', 'dynamic'] = 'dynamic'
name: str | None = None
dtype: dt.dtype | None = dt.int32
value: int | tir.Var | None = None
creator: Callable[[], Any] | None = None
def is_kernel_arg(self) -> bool:
return self.kind == 'dynamic'
@classmethod
def from_value(cls, value: Any, prefer_name: str = None) -> Value:
if isinstance(value, int):
# handle A: T.Tensor[[1024, 1024], ...]
return Value(kind='static', name=prefer_name, dtype=dt.int32, value=value)
elif isinstance(value, float):
return Value(kind='static', name=prefer_name, dtype=dt.float32, value=value)
elif isinstance(value, tir.Var):
# handle A: T.Tensor[[M, N, K], ...]
return Value(kind='dynamic', name=value.name, dtype=value.dtype, value=value)
elif isinstance(value, dt.dtype):
# handle A: T.float32
return Value(kind='dynamic', name=prefer_name, dtype=value, value=None)
elif isinstance(value, Value):
# handle A: T.dyn
return value
elif isinstance(value, TypeVar):
return Value(kind='static', name=value.__name__, value=None)
elif value is Any or value is None or value is dt.dtype or isinstance(
value, (type, _GenericAlias)):
# A # no annotation
# A: Any
# A: _T
# A: dt.dtype
# A: tuple[...]
return Value(kind='static', name=prefer_name, value=None)
else:
raise TypeError(f"Unsupported Value annotation: {value!r}")
def with_name(self, name: str) -> Value:
return Value(kind=self.kind, name=self.name or name, dtype=self.dtype, value=self.value)
def get_key_parser(self):
if self.kind == 'static':
if self.value is not None:
expected_value = self.value
def key_parser(name: str, target: Any):
assert target == expected_value
return target
return key_parser
else:
return lambda name, target: (target,)
else:
return lambda name, target: (None,)
def parse_key(self, target: Any):
return self.get_key_parser()(target)
def create_prim_func_arg(self, name: str, value: Any, vt: ArgVarTable, create_arg: bool = True):
if self.kind == 'static':
if self.value:
assert self.value == value, f"static value mismatch for {name}: expected {self.value}, got {value}"
return value
else:
name = self.name or name or vt.create_tmp_name()
if self.value is not None:
arg = self.value
elif self.creator is not None:
arg = self.creator()
else:
arg = vt.get_or_create_var(name, self.dtype)
return tb_tir.arg(name, arg) if create_arg else arg
def __repr__(self):
if self.kind == 'static':
if self.value is not None:
return repr(self.value)
else:
return (str(self.name) or '$unnamed') + '$'
else:
if self.value is not None:
return repr(self.value)
elif self.creator is not None:
return repr(self.creator())
else:
return (str(self.name) or '$unnamed') + '$dyn'
def _canonicalize_dtype(val: Any) -> dt.dtype | None:
if val == Any or val is None:
return None
if isinstance(val, TypeVar):
return None
return dt.dtype(val)
def _canonicalize_shape(shape: Sequence[Any]) -> list[Value]:
if shape is None or shape is Any:
return None
return [Value.from_value(dim) for _, dim in enumerate(iterable=shape)]
def _canonicalize_strides(strides: Sequence[Any]) -> list[Value]:
if strides is None or strides is Any:
return None
return [Value.from_value(dim) for _, dim in enumerate(strides)]
def _shape_with_name(shape: Sequence[Value], base_name: str) -> list[Value]:
if shape is None:
return None
res = []
for i, dim in enumerate(shape):
dim = dim.with_name(f'{base_name}_{i}')
res.append(dim)
return res
def _try_convert_static_shape(shape: Sequence[Value]):
if shape is None:
return None
res = []
for s in shape:
if s.kind == 'static' and s.value is not None or s.kind == 'dynamic' and s.value is not None:
res.append(s.value)
if len(res) == len(shape):
return res
@dataclass
class BufferAnnot(Annot):
shape: tuple = None
strides: tuple = None
dtype: dt.dtype = None
def is_kernel_arg(self) -> bool:
return True
@property
def scope(self):
return 'global'
def __call__(
self,
shape: tuple[Unpack[_Shapes]],
dtype: _DType = "float32",
data=None,
strides=None,
elem_offset=None,
scope=None,
align=0,
offset_factor=0,
buffer_type="",
axis_separators=None,
) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
return buffer(
shape,
dtype=dtype,
data=data,
strides=strides,
elem_offset=elem_offset,
scope=scope or self.scope,
align=align,
offset_factor=offset_factor,
buffer_type=buffer_type,
axis_separators=axis_separators,
)
def __getitem__(self, params):
shape, dtype = params
if not isinstance(shape, (tuple, list)):
shape = (shape,)
shape = _canonicalize_shape(shape)
dtype = _canonicalize_dtype(dtype)
return self.__class__(shape, strides=self.strides, dtype=dtype)
def with_name(self, name: str):
shape = _shape_with_name(self.shape, base_name=f'{name}_shape')
strides = _shape_with_name(self.strides, base_name=f'{name}_stride')
return self.__class__(shape, strides, self.dtype)
def get_key_parser(self):
raw_shapes = True
if self.shape is not None:
raw_shapes = False
shape_len = len(self.shape)
static_shape_idx = [i for i, dim in enumerate(self.shape) if dim.kind == 'static']
# static_fixed_shape_idx = [i for i, dim in enumerate(self.shape) if dim.kind == 'static' and dim.value is not None]
# static_fixed_shape_values = [dim.value for dim in self.shape if dim.kind == 'static' and dim.value is not None]
raw_strides = True
if self.strides is not None:
raw_strides = False
strides_len = len(self.strides)
strides_shape_idx = [i for i, dim in enumerate(self.strides) if dim.kind == 'static']
# static_fixed_strides_idx = [i for i, dim in enumerate(self.strides) if dim.kind == 'static' and dim.value is not None]
# static_fixed_strides_values = [dim.value for dim in self.strides if dim.kind == 'static' and dim.value is not None]
raw_dtype = True
if self.dtype is not None:
raw_dtype = False
expected_dtype = self.dtype
def key_parser(name: str, target: Any):
if isinstance(target, torch.Tensor):
shape = tuple(target.shape)
strides = tuple(target.stride())
dtype = dt.dtype(target.dtype)
elif isinstance(target, tir.Buffer):
shape = tuple(target.shape)
strides = tuple(target.strides)
dtype = dt.dtype(target.dtype)
else:
raise TypeError(
f"Unsupported buffer argument type for argument `{name}`: expected a `torch.Tensor` or `tir.Buffer`, got {type(target)}"
)
if not raw_shapes:
assert len(shape) == shape_len
shape = tuple(shape[i] for i in static_shape_idx)
# shape_fixed = tuple(shape[i] for i in static_fixed_shape_idx)
# assert shape_fixed == static_fixed_shape_values, f"shape mismatch"
if not raw_strides:
assert len(strides) == strides_len
strides = tuple(strides[i] for i in strides_shape_idx)
# strides_fixed = tuple(strides[i] for i in static_fixed_strides_idx)
# assert strides_fixed == static_fixed_strides_values
if not raw_dtype:
dtype = dt.dtype(dtype)
if dtype != expected_dtype:
raise TypeError(
f"Tensor dtype mismatch for argument `{name}`, expected {expected_dtype}, got {dtype}"
)
return shape, strides, dtype
return key_parser
def parse_key(self, target: Any):
return self.get_key_parser()(target)
@staticmethod
def match_shape(shape: tuple[Value, ...], target_shape: tuple[int, ...], vt: ArgVarTable):
if shape is None:
return target_shape
args = []
for s, target in zip(shape, target_shape):
args.append(s.create_prim_func_arg(s.name, target, vt, create_arg=False))
return args
def create_prim_func_arg(self, name: str, value: Any, vt: ArgVarTable):
if isinstance(value, tir.Buffer):
shape = value.shape
strides = value.strides
dtype = value.dtype
elif isinstance(value, torch.Tensor):
shape = value.shape
strides = value.stride()
dtype = dt.dtype(value.dtype)
else:
raise TypeError(f"Unsupported buffer argument type: {type(value)}")
shape = self.match_shape(self.shape, shape, vt)
strides = self.match_shape(self.strides, strides, vt)
arg = buffer(shape, dtype=self.dtype or dtype, strides=strides, scope=self.scope)
return tb_tir.arg(name, arg)
def promote(self):
shape = _try_convert_static_shape(self.shape)
strides = _try_convert_static_shape(self.strides)
if shape is not None and strides is not None and self.dtype is not None:
buf = buffer(shape, self.dtype, strides=strides, scope=self.scope)
return TIRAnnot(data=buf)
# def __repr__(self):
# items = []
# if self.shape is not None:
# items.append(f'shape=[{', '.join(map(repr, self.shape))}]')
# if self.strides is not None:
# items.append(f'strides=[{', '.join(map(repr, self.strides))}]')
# if self.dtype is not None:
# items.append(f'dtype={self.dtype}')
# items.append(f'scope={repr(self.scope)}')
# return 'Buffer(' + ', '.join(items) + ')'
class TensorAnnot(BufferAnnot):
@staticmethod
def _construct_strides(shape: tuple[Any]):
s, strides = 1, [1]
for dim in shape[:0:-1]:
s *= dim
strides.append(s)
return tuple(reversed(strides))
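# Compact row-major strides produced by the helper above (illustrative):
# >>> TensorAnnot._construct_strides((2, 3, 4))
# (12, 4, 1)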
def __call__(
self,
shape: tuple[Unpack[_Shapes]],
dtype: _DType = "float32",
data=None,
strides=None,
elem_offset=None,
scope=None,
align=0,
offset_factor=0,
buffer_type="",
axis_separators=None,
):
if isinstance(shape, (int, PrimExpr)):
shape = (shape,)
strides = strides or self._construct_strides(shape)
return super().__call__(
shape=shape,
dtype=dtype,
data=data,
strides=strides,
elem_offset=elem_offset,
scope=scope,
align=align,
offset_factor=offset_factor,
buffer_type=buffer_type,
axis_separators=axis_separators)
def promote(self):
shape = _try_convert_static_shape(self.shape)
if shape is not None and self.dtype is not None:
strides = self._construct_strides(shape)
buf = buffer(shape, self.dtype, strides=strides, scope=self.scope)
return TIRAnnot(data=buf)
class StridedTensorAnnot(BufferAnnot):
def __call__(
self,
shape,
strides,
dtype: _DType = "float32",
data=None,
elem_offset=None,
scope=None,
align=0,
offset_factor=0,
buffer_type="",
axis_separators=None,
):
return super().__call__(
shape=shape,
strides=strides,
dtype=dtype,
data=data,
elem_offset=elem_offset,
scope=scope,
align=align,
offset_factor=offset_factor,
buffer_type=buffer_type,
axis_separators=axis_separators,
)
def __getitem__(self, params):
shape, strides, dtype = params
shape = _canonicalize_shape(shape)
strides = _canonicalize_strides(strides)
dtype = _canonicalize_dtype(dtype)
return StridedTensorAnnot(shape, strides, dtype)
class FragmentBufferAnnot(BufferAnnot):
@property
def scope(self):
return 'local.fragment'
class SharedBufferAnnot(BufferAnnot):
@property
def scope(self):
return 'shared.dyn'
class LocalBufferAnnot(BufferAnnot):
@property
def scope(self):
return 'local'
class DynAnnot(Value):
'''
Dynamic variable annotation represents a tvm tir.Var argument
'''
def __call__(self, dtype: AnyDType = dt.float32, name: str | None = None) -> DynAnnot:
return tir.Var(name, dtype)
def __getitem__(self, params):
if not isinstance(params, tuple):
params = (params,)
dtype = None
if len(params) == 1:
name, = params
if len(params) == 2:
dtype, name = params
dtype = _canonicalize_dtype(dtype) or dt.int32
return DynAnnot(kind='dynamic', dtype=dtype, name=name)
@dataclass
class DTypeAnnot(Annot):
'''
Data type annotation ensures automatic conversion from AnyDType to dtype
>>> def foo(A: T.dtype): print(A)
>>> foo(torch.float32)
dtype('float32')
>>> foo(T.float32)
dtype('float32')
>>> foo('float32')
dtype('float32')
'''
name: str | None = None
def is_kernel_arg(self) -> bool:
return False
def with_name(self, name):
return DTypeAnnot(name=name)
def get_key_parser(self):
return lambda name, value: (dt.dtype(value),)
def create_prim_func_arg(self, name, value, vt):
return dt.dtype(value)
def __repr__(self):
return self.name + '$dtype'
@dataclass
class TIRAnnot(Annot):
'''
TIR annotation is used to directly pass tir.Buffer or tir.Var as kernel arguments
>>> def foo(A: T.Buffer((128,), T.float32)): ...
'''
data: tir.Buffer | tir.Var
def is_kernel_arg(self) -> bool:
return True
def get_key_parser(self):
return lambda name, value: (None,)
def create_prim_func_arg(self, name, value, vt):
return tb_tir.arg(name, self.data)
def with_name(self, name: str):
IRBuilder.name(name, self.data)
return self
def __repr__(self):
return repr(self.data)
if TYPE_CHECKING:
class Buffer(Generic[_Shape, _DType]):
def __init__(
shape: tuple[Unpack[_Shapes]],
dtype: _DType = "float32",
data=None,
strides=None,
elem_offset=None,
scope=None,
align=0,
offset_factor=0,
buffer_type="",
axis_separators=None,
) -> Buffer[Callable[[Unpack[_Shapes]]], _DType]:
...
@property
def shape(self: Buffer[Callable[[Unpack[_Shapes]]], _DType]) -> tuple[Unpack[_Shapes]]:
...
@property
def dtype(self: Buffer[Callable[[Unpack[_Shapes]]], _DType]) -> dt.dtype[_DType]:
...
@property
def strides(self) -> tuple[tir.PrimExpr]:
...
def scope(self) -> Scope:
...
class Tensor(Generic[_Shape, _DType], Buffer[_Shape, _DType]):
def __new__(
shape: tuple[Unpack[_Shapes]],
dtype: _DType = "float32",
data=None,
strides=None,
elem_offset=None,
scope=None,
align=0,
offset_factor=0,
buffer_type="",
axis_separators=None,
) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
...
class StridedTensor(Generic[_Shape, _Stride, _DType], Buffer[_Shape, _DType]):
def __new__(
shape: tuple[Unpack[_Shapes]],
strides=None,
dtype: _DType = "float32",
data=None,
elem_offset=None,
scope=None,
align=0,
offset_factor=0,
buffer_type="",
axis_separators=None,
) -> Tensor[Callable[[Unpack[_Shapes]]], _DType]:
...
class FragmentBuffer(Generic[_Shape, _DType], Buffer[_Shape, _DType]):
pass
class LocalBuffer(Generic[_Shape, _DType], Buffer[_Shape, _DType]):
pass
class SharedBuffer(Generic[_Shape, _DType], Buffer[_Shape, _DType]):
pass
class dyn(tir.Var):
def __new__(cls, dtype: _DType = "float32", name: str | None = None) -> dyn[_DType]:
...
@property
def dtype(self: dyn[_DType]) -> dt.dtype[_DType]:
...
else:
Buffer = BufferAnnot()
Tensor = TensorAnnot()
StridedTensor = StridedTensorAnnot()
FragmentBuffer = FragmentBufferAnnot()
SharedBuffer = SharedBufferAnnot()
LocalBuffer = LocalBufferAnnot()
dyn = DynAnnot()
@dataclass
class FuncAnnot:
sig: inspect.Signature
arg_names: list[str]
annots: dict[str, Annot]
arg_parser: dict[str, Callable[[Any], tuple[Any, ...]]]
ker_arg_names: list[str]
@classmethod
def from_sig_annots(cls, sig: inspect.Signature, func_annots: dict[str, Any]) -> FuncAnnot:
annots = {}
arg_parser = {}
ker_arg_names = []
for param in sig.parameters.values():
name = param.name
annot = func_annots.get(name, Value('static', name))
if not isinstance(annot, Annot):
if not isinstance(annot, type) and callable(annot):
annot = annot()
if annot is dt.dtype:
annot = DTypeAnnot(name=name)
elif isinstance(annot, (tir.Buffer, tir.Var)):
annot = TIRAnnot(data=annot)
else:
annot = Value(kind='static', name=name)
annot = annot.promote() or annot
annots[name] = annot.with_name(name)
if annot.is_kernel_arg():
ker_arg_names.append(name)
arg_parser[name] = annot.get_key_parser()
arg_names = list(sig.parameters.keys())
return FuncAnnot(sig, arg_names, annots, arg_parser, ker_arg_names)
def parse_key(self, *args, **kws):
'''
Parse arguments and generate the cache key for jit caching
'''
args = {name: arg for name, arg in zip(self.arg_names, args)}
arg_dict = dict(**args, **kws)
parsed = []
for name, value in arg_dict.items():
key = self.arg_parser[name](name, value)
parsed.append((name, key))
return tuple(sorted(parsed))
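# Illustrative property of the key above (sketch; `func_annot` and `A` are
# hypothetical): the key is a name-sorted tuple of (name, sub-key) pairs, so the
# same arguments always hit the same cache entry while a changed static dim,
# stride layout or dtype misses it.
# >>> k1 = func_annot.parse_key(A, block_M=128)
# >>> k2 = func_annot.parse_key(A, block_M=128)
# >>> k1 == k2 and isinstance(k1, tuple)
# True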
def convert_to_kernel_args(self, *args, **kws):
args = {name: arg for name, arg in zip(self.arg_names, args)}
arg_dict = dict(**args, **kws)
return [arg_dict[name] for name in self.ker_arg_names]
def create_argument(self, name: str, value: Any, vt: ArgVarTable):
'''
Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation
'''
return self.annots[name].create_prim_func_arg(name, value, vt)
def is_all_static(self):
'''
Check if all arguments are static (i.e., can be fully determined at compile time)
'''
return all(isinstance(annot, TIRAnnot) for annot in self.annots.values())
def get_all_static_args(self):
res = {}
for name, annot in self.annots.items():
if isinstance(annot, TIRAnnot):
res[name] = annot.data
return res
def get_compile_time_unknown_args(self):
res = []
for name, annot in self.annots.items():
if not isinstance(annot, TIRAnnot):
res.append(name)
return res
......@@ -6,12 +6,18 @@ import inspect
from tilelang.language.kernel import KernelLaunchFrame
from tvm_ffi.container import Map
from tvm.ir.base import Span
from tvm.ir.expr import Range
from tvm.tir.stmt import BufferRegion
from .ast import BaseBuilder, IRGenerator, eval_op, mutate
from .utils import construct_strides
import tvm
from tvm.tir import Buffer
from tvm.script.ir_builder import tir, IRBuilder
from tvm.tir.expr import BufferLoad, EqualOp, FloatImm, IntImm, NotEqualOp, PrimExpr, StringImm, Var
from typing import TYPE_CHECKING, Callable, Any, Generic, TypeVar, ForwardRef, Union
from collections.abc import Sequence
from .annot import FuncAnnot, ArgVarTable, Annot
import pprint
# Python 3.9 compatibility for ParamSpec and Self
try:
from typing import ParamSpec, Self
......@@ -31,7 +37,9 @@ def unwrap_expr(expr) -> PrimExpr | int | float:
'''
if isinstance(expr, tir.meta_var):
expr = expr.value
elif isinstance(expr, Ref):
return expr.load()
elif is_var(expr):
expr = tir.BufferLoad(expr, indices=[0])
elif isinstance(expr, (EqualOp, NotEqualOp)):
expr = expr.asobject()
......@@ -113,6 +121,30 @@ class SerialForWithStep:
@dataclass
class OutTensor:
shape: Sequence[PrimExpr]
dtype: dt.dtype
@property
def strides(self):
return construct_strides(tuple(self.shape))
@dataclass
class Ref:
bufload: BufferLoad
@property
def buffer(self):
return self.bufload.buffer
def store(self, value):
tir.buffer_store(self.bufload.buffer, value, self.bufload.indices)
def load(self):
return self.bufload
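# Sketch of what `Ref` enables for macro arguments annotated with `T.Ref`
# (hypothetical macro; the builder rewrites the augmented assignment via
# aug_assign into `x.store(x.load() + 1)`, so the write lands in the caller's
# buffer element):
# >>> @T.macro
# ... def add_one(x: T.Ref):
# ...     x += 1
# >>> # add_one(buf[i]) then updates buf[i] in place rather than a local copy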
class UnrollForWithStep(SerialForWithStep):
...
......@@ -145,11 +177,15 @@ def is_var(v: Any) -> bool:
class Builder(BaseBuilder):
def __init__(self, func_annot: FuncAnnot = None):
self.frames: list[AnyFrame] = []
self.ir_builder = IRBuilder()
self.name_inside_frame: dict[str, AnyFrame] = {}
self.arg_annotations = {}
self.macro_arg_annot = {}
self.func_annot = func_annot
self.out_idx = []
self.out_tensor_cnt = 0
self.arg_vt = ArgVarTable()
@classmethod
def current(cls) -> Self:
......@@ -162,6 +198,8 @@ class Builder(BaseBuilder):
with self.ir_builder, self.with_frame(tir.prim_func()):
tir.func_name(name)
yield
if len(self.out_idx) != self.out_tensor_cnt:
raise RuntimeError('Not all tensors allocated from `T.empty` are returned')
@contextmanager
def macro(self, name=None, annotations=None):
......@@ -169,9 +207,9 @@ class Builder(BaseBuilder):
raise RuntimeError(
f"Macro `{name}` is used inside boolean expressions, "
"please use `if` to replace `M and M`, `M or M`, `M if xxx else M` constructs")
save = self.name_inside_frame, self.macro_arg_annot
self.name_inside_frame = {}
self.arg_annotations = annotations or {}
self.macro_arg_annot = annotations or {}
pos = len(self.frames)
# here we add a ExitedMacroFrame to preserve the frame stack inside macro
# because macro may bind some variable, and return it
......@@ -188,7 +226,7 @@ class Builder(BaseBuilder):
self.frames.append(MacroFrame())
yield
self.frames[pos] = ExitedMacroFrame()
self.name_inside_frame, self.macro_arg_annot = save
def get(self):
return self.ir_builder.get()
......@@ -269,8 +307,11 @@ class Builder(BaseBuilder):
pass
elif isinstance(val, tvm.tir.stmt.BufferStore):
tir.buffer_store(val.buffer, val.value, val.indices, val.predicate)
elif isinstance(val, (Buffer, Var)):
pass
else:
logger.warning(
f"Unused return value: {val}({type(val)})", stack_info=True, stacklevel=2)
def ctx_for(self, it):
self.check_continue_break()
......@@ -355,10 +396,26 @@ class Builder(BaseBuilder):
# c = tl.alloc_var('float32') # bind var `c`
# c = a # get and assign `c[0] = a_1[0]`
# ```
if isinstance(orig_value, Ref) and isinstance(value, (int, float, PrimExpr)):
orig_value.store(value)
return orig_value
if is_var(orig_value) and isinstance(value, (int, float, PrimExpr)):
tir.buffer_store(orig_value, value, 0)
return orig_value
# 2. Quick return for trivial types
if isinstance(value, (tuple, list, tvm.ffi.Array, int, float, str)):
return value
if isinstance(value, tir.IntImm) and value.dtype == 'int32':
return value.value
if isinstance(value, (Var, Buffer)):
IRBuilder.name(name, value)
return value
# 3. Bind immutable tilelang objects
res = self.bind_immutable(name, value)
# 4. Check variable scope and shadowing
if name != '_':
frame = self.find_frame_idx(TIR_VAR_SCOPE_FRAME)
assert frame is not None, f"Variable `{name}` is not defined inside any control flow."
......@@ -372,6 +429,9 @@ class Builder(BaseBuilder):
return res
def unwrap_value(self, value):
'''
Unwrap some tilelang objects to get their inner value
'''
value = unwrap_expr(value)
# handle bx, by = tl.Kernel(128, 128), rval is frame
if isinstance(value, tir.frame.IRBuilderFrame):
......@@ -380,6 +440,10 @@ class Builder(BaseBuilder):
return value
def bind_immutable(self, name, value):
'''
Bind an immutable tilelang object.
Immutability here means the result is usually not changed or re-assigned within a Python block.
'''
if name == '_':
# use _tmp to make the generated tir more readable
name = "_tmp"
......@@ -393,11 +457,19 @@ class Builder(BaseBuilder):
stacklevel=2,
)
return self.enter_frame(value)
elif isinstance(value, OutTensor):
arg = tir.arg(name,
tir.buffer(
shape=value.shape,
dtype=value.dtype,
strides=value.strides,
))
arg._out_idx = self.out_tensor_cnt
self.out_tensor_cnt += 1
return arg
elif isinstance(value, (Buffer, tir.IterVar, tir.Var)):
IRBuilder.name(name, value)
return value
elif isinstance(value, (tuple, list, tvm.ffi.Array)):
return value
else:
try:
value = tvm.runtime.convert(value)
......@@ -420,7 +492,10 @@ class Builder(BaseBuilder):
def aug_assign(self, op, target, aug_value):
self.check_continue_break()
if is_var(target):
if isinstance(target, Ref):
target.store(eval_op(op, target.bufload, aug_value))
return target
elif is_var(target):
tir.buffer_store(target, eval_op(op, target[0], aug_value), 0)
return target
elif isinstance(target, Buffer):
......@@ -457,10 +532,15 @@ class Builder(BaseBuilder):
else:
return super().ifexp(cond, then, otherwise)
def ret(self, value=None):
self.check_continue_break()
# handle return T.alloc_var()
if value is None:
value = tuple()
elif isinstance(value, tuple):
value = tuple(self.unwrap_value(v) for v in value)
else:
value = self.unwrap_value(value)
last_macro = self.find_frame_idx(MacroFrame)
if last_macro is not None:
frame = self.find_frame_idx(TIR_CONTROL_FRAME, start=last_macro)
......@@ -478,7 +558,20 @@ class Builder(BaseBuilder):
" return a\n"
"```"
)
return value
else:
if not isinstance(value, tuple):
value = (value,)
for v in value:
if not isinstance(v, Buffer) or not hasattr(v, '_out_idx'):
raise RuntimeError(
f'Only tensors allocated from `T.empty` can be returned from a prim_func, got {v}({type(v)})'
)
# convert 0, 1, 2 => -3, -2, -1 as the out tensor index
self.out_idx.append(v._out_idx - self.out_tensor_cnt)
if len(self.out_idx) != self.out_tensor_cnt:
raise RuntimeError(f'Not all tensors from `T.empty` are returned, only got {value}')
return NotImplemented
def ctx_with(self, ctx):
self.check_continue_break()
......@@ -487,9 +580,11 @@ class Builder(BaseBuilder):
else:
return super().ctx_with(ctx)
def assert_expr(self, cond, msg=None):
self.check_continue_break()
cond = unwrap_cond(cond)
if msg is None:
msg = 'Assertion failed'
if isinstance(cond, PrimExpr):
self.enter_frame(tir.Assert(cond, msg))
elif not cond:
......@@ -506,30 +601,41 @@ class Builder(BaseBuilder):
return self.unwrap_value(value)
def macro_arg(self, name, value):
annot_value = self.macro_arg_annot.get(name, None)
if annot_value is Var or annot_value is Ref:
if annot_value is Var:
logger.warning('Using `T.Var` as a macro annotation is deprecated, please use `T.Ref`')
if isinstance(value, BufferLoad):
if is_var(value.buffer):
return value.buffer
idx = [self.bind('_', idx) for idx in value.indices]
# indices = self.bind(f'_', value.indices)
return Ref(BufferLoad(value.buffer, indices=idx))
if isinstance(value, BufferRegion):
region = [
Range(
self.bind('_', x.begin),
end=self.bind('_', x.end) if x.end is not None else None)
for x in value.region
]
return BufferRegion(value.buffer, region=region)
raise ValueError(
f'To pass as reference, argument `{name}` is expected to be a variable or a buffer region, but got {value}({type(value)})'
)
elif isinstance(value, (PrimExpr, int, float)):
return self.bind(name, value)
else:
return value
def prim_func_arg(self, name, value):
return self.func_annot.create_argument(name, value, self.arg_vt)
# if isinstance(value, (Buffer, Var)):
# return tir.arg(name, value)
# elif value is self.empty:
# raise ValueError(f'Argument `{name}` is not annotated')
# else:
# raise TypeError(
# f"Unsupported argument type: {value}({type(value)}) for argument `{name}`.")
def arg(self, name, value):
if self.find_frame_idx(MacroFrame) is not None:
......@@ -547,6 +653,39 @@ class Builder(BaseBuilder):
_P = ParamSpec('_P')
_T = TypeVar('_T')
@dataclass
class PrimFuncCreater(Generic[_P, _T]):
func_annot: FuncAnnot
ir_gen: IRGenerator[_P, _T]
orig_func: Callable[_P, _T]
@property
def annot(self) -> dict[str, Annot]:
return self.func_annot.annots
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> PrimFunc[_P, _T]:
builder = Builder(self.func_annot)
with builder.prim_func(self.orig_func.__name__):
self.ir_gen.gen(builder)(*args, **kwargs)
res: PrimFunc = builder.get()
res.ir_gen = self.ir_gen
res.orig_func = self.orig_func
res.func_annot = self.func_annot
res.out_idx_override = builder.out_idx or None
return res
def __repr__(self):
fmt = pprint.pformat(
{
'annot': self.func_annot.annots,
'ir_gen': self.ir_gen,
'orig_func': self.orig_func
},
indent=2)
return f'{self.__class__.__name__}(\n{fmt}\n)'
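# Illustrative elaboration flow (sketch; `my_kernel` and `A` are hypothetical, and
# `my_kernel` is assumed not to be fully static, so a creater is returned): the
# creater is called with the same arguments as the original Python function and
# produces a concrete PrimFunc for exactly those shapes and dtypes.
# >>> creater = prim_func(my_kernel, generator=True)
# >>> pf = creater(A, block_M=128)
# >>> isinstance(pf, PrimFunc)
# True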
if TYPE_CHECKING:
class PrimFunc(Generic[_P, _T], tvm.tir.PrimFunc):
......@@ -557,8 +696,10 @@ if TYPE_CHECKING:
attrs: tvm.Attrs | None
span: Span | None
ir_gen: IRGenerator[_P, _T] | None
source: str | None
orig_func: Callable[_P, _T] | None
func_annot: FuncAnnot | None
out_idx_override: list[int] | None
else:
PrimFunc = tvm.tir.PrimFunc
......@@ -580,6 +721,12 @@ class Macro(Generic[_P, _T]):
res = self.ir_gen.gen(builder)(*args, **kwargs)
return res
def __hash__(self):
return id(self)
def __eq__(self, other):
return id(self) == id(other)
def macro(func: Callable[_P, _T] = None) -> Macro[_P, _T]:
"""
......@@ -683,13 +830,9 @@ def get_type_hints(func):
return hints
def prim_func(func: Callable[_P, _T] = None,
*,
generator: bool = False) -> PrimFunc[_P, _T] | PrimFuncCreater[_P, _T]:
"""
Decorator to create a primitive function (PrimFunc) for TileLang IR generation.
This decorator transforms a Python function into a TileLang primitive function by analyzing
......@@ -739,45 +882,21 @@ def prim_func(func: Callable[_P, _T] = None,
sig = inspect.signature(func)
annot = get_type_hints(func)
func_annot = FuncAnnot.from_sig_annots(sig, annot)
ir_gen = mutate(func)
prim_func_generator = PrimFuncCreater(func_annot, ir_gen, orig_func=func)
if func_annot.is_all_static():
args = func_annot.get_all_static_args()
return prim_func_generator(**args)
if generator is False:
unknown_args = func_annot.get_compile_time_unknown_args()
raise ValueError(
f"Cannot create PrimFunc for `{func.__name__}`, some arguments are not compile-time known, \n"
f"Annotations:\n{func_annot.annots}"
f"Unknown Args: {unknown_args}")
return prim_func_generator
return impl(func) if func is not None else impl
from tilelang import tvm
from tvm import ir
import torch
from typing import Generic, TypeVar, Union, TYPE_CHECKING
from tvm import tir
import tvm.script.ir_builder.tir._ffi_api as tb_ffi
import numpy as np
_T = TypeVar('_T')
if TYPE_CHECKING:
class dtype(Generic[_T]):
def torch(self) -> torch.dtype:
...
else:
dtype = tvm.DataType
# Python 3.9 compatibility: avoid PEP 604 unions at runtime
AnyDType = Union[ir.Type, str, type, torch.dtype, dtype]
......
......@@ -4,6 +4,7 @@ import inspect
from typing import Any, Callable, Literal
from tilelang import env
from hashlib import sha256
from tvm import tir
import linecache
......@@ -84,3 +85,17 @@ def get_compiled_object(source: str | ast.AST,
locs = {}
exec(compiled, globals, locs)
return locs[name]
def construct_strides(shape: tuple[Any, ...], allow_prim_expr: bool = True) -> tuple[Any, ...]:
"""Construct row-major strides from shape."""
strides = []
stride = 1
for s in shape[::-1]:
strides.append(stride)
stride *= s
if not allow_prim_expr and isinstance(stride, tir.PrimExpr):
raise ValueError(
"Cannot construct strides with PrimExpr when allow_prim_expr is False.")
strides = tuple(reversed(strides))
return strides
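# Illustration of the row-major layout produced above:
# >>> construct_strides((2, 3, 4))
# (12, 4, 1)
# >>> construct_strides((8,))
# (1,)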