{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e0deecc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n",
    "import tilelang\n",
    "import torch\n",
    "import tilelang.language as T"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ca2c56d",
   "metadata": {},
   "source": [
    "# Tilelang Lazy JIT"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "156e7370",
   "metadata": {},
   "source": [
    "## Tensor Annotation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b070c109",
   "metadata": {},
   "source": [
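    "Tilelang Lazy JIT merges JIT kernel generation and kernel invocation into a single step.\n",
    "\n",
    "The function signature is written much like in Triton, but with many enhancements. The most important one is that Tensors can be annotated.\n",
    "\n",
    "For example, the code below uses `T.Tensor[[int, int], T.float16]` to annotate a 2-D Tensor:\n",
    "1. Each of its dimensions is a compile-time constant; if a dimension changes, the kernel is recompiled.\n",
    "2. Its dtype must be `T.float16`.\n",
    "\n",
    "Besides a concrete dtype, the DType slot can also be `Any` or `None`."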
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "60bf8954",
   "metadata": {},
   "outputs": [],
   "source": [
    "@tilelang.lazy_jit\n",
    "def gemm(\n",
    "    A: T.Tensor[[int, int], T.float16],\n",
    "    B: T.Tensor[[int, int], T.float16],\n",
    "    out_dtype: T.dtype = T.float32,\n",
    "    block_M: int = 128,\n",
    "    block_N: int = 128,\n",
    "    block_K: int = 32,\n",
    "):\n",
    "    M, K = A.shape\n",
    "    K, N = B.shape\n",
    "    C = T.empty((M, N), out_dtype)\n",
    "    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
    "        A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n",
    "        B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n",
    "        C_local = T.alloc_fragment((block_M, block_N), out_dtype)\n",
    "        T.clear(C_local)\n",
    "        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n",
    "            T.copy(A[bx * block_M, k * block_K], A_shared)\n",
    "            T.copy(B[k * block_K, by * block_N], B_shared)\n",
    "            T.gemm(A_shared, B_shared, C_local)\n",
    "        T.copy(C_local, C[bx * block_M, by * block_N])\n",
    "    return C"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28f868fe",
   "metadata": {},
   "source": [
    "Calling the function directly with Tensors as arguments triggers the full JIT compile-and-run flow:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ee13394a",
   "metadata": {},
   "outputs": [],
   "source": [
    "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n",
    "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n",
    "C = gemm(A, B)\n",
    "\n",
    "# check output is correct\n",
    "C_ref = (A @ B).float()\n",
    "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
   ]
  },
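  {
   "cell_type": "markdown",
   "id": "b1a0c001",
   "metadata": {},
   "source": [
    "The launch grid of this first call follows from `T.ceildiv`. As a minimal plain-Python sketch, assuming `T.ceildiv(a, b)` is ordinary integer ceiling division (the helpers `ceildiv` and `launch_grid` below are hypothetical, not tilelang APIs), `A` of shape (1024, 512) and `B` of shape (512, 256) with the default 128 x 128 tiles give an 8 x 2 grid of thread blocks and 16 pipelined K-steps:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1a0c002",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ceildiv(a: int, b: int) -> int:\n",
    "    # Integer ceiling division, assumed to match T.ceildiv for positive sizes\n",
    "    return -(-a // b)\n",
    "\n",
    "\n",
    "def launch_grid(M: int, N: int, block_M: int = 128, block_N: int = 128):\n",
    "    # One thread block per (block_M, block_N) tile of the output C\n",
    "    return ceildiv(M, block_M), ceildiv(N, block_N)\n",
    "\n",
    "\n",
    "M, K = 1024, 512  # A.shape\n",
    "K, N = 512, 256  # B.shape\n",
    "print(launch_grid(M, N))  # (8, 2)\n",
    "print(ceildiv(K, 32))  # 16 iterations of T.Pipelined with block_K = 32"
   ]
  },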
  {
   "cell_type": "markdown",
   "id": "c6705091",
   "metadata": {},
   "source": [
    "Change the call arguments: if any compile-time parameter changes, the kernel is recompiled:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d8aab5b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n",
    "B = torch.randn(512, 1024, dtype=torch.float16, device=\"cuda\")\n",
    "C = gemm(A, B, block_M=64, block_N=64)"
   ]
  },
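  {
   "cell_type": "markdown",
   "id": "b1a0c003",
   "metadata": {},
   "source": [
    "To make the recompilation rule concrete, here is a toy, pure-Python model of a lazy-JIT cache. It is a sketch under an assumption, not tilelang's implementation: the cache key is taken to be the concrete shape and dtype of every Tensor argument (all dimensions here are `int`, i.e. compile-time constants) plus every compile-time keyword argument, with defaults filled in. `toy_lazy_jit`, `FakeTensor`, `toy_gemm` and `fake_A` are hypothetical names chosen so they do not overwrite the real `gemm`, `A` and `B` used by later cells."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1a0c004",
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspect\n",
    "from dataclasses import dataclass\n",
    "\n",
    "\n",
    "@dataclass(frozen=True)\n",
    "class FakeTensor:\n",
    "    # Stand-in for a tensor: only shape and dtype matter for the cache key\n",
    "    shape: tuple\n",
    "    dtype: str\n",
    "\n",
    "\n",
    "def toy_lazy_jit(fn):\n",
    "    sig = inspect.signature(fn)\n",
    "    cache = {}\n",
    "\n",
    "    def wrapper(*args, **kwargs):\n",
    "        bound = sig.bind(*args, **kwargs)\n",
    "        bound.apply_defaults()  # explicit and default arguments give the same key\n",
    "        # Shapes, dtypes and compile-time keyword arguments all become part of the key\n",
    "        key = tuple(\n",
    "            (name, (v.shape, v.dtype) if isinstance(v, FakeTensor) else v)\n",
    "            for name, v in bound.arguments.items()\n",
    "        )\n",
    "        if key not in cache:\n",
    "            wrapper.compile_count += 1  # a real JIT would build a kernel here\n",
    "            cache[key] = fn\n",
    "        return cache[key](*args, **kwargs)\n",
    "\n",
    "    wrapper.compile_count = 0\n",
    "    return wrapper\n",
    "\n",
    "\n",
    "@toy_lazy_jit\n",
    "def toy_gemm(A, B, block_M=128, block_N=128):\n",
    "    return (A.shape[0], B.shape[1])  # output shape only; no real compute\n",
    "\n",
    "\n",
    "fake_A = FakeTensor((1024, 512), \"float16\")\n",
    "toy_gemm(fake_A, FakeTensor((512, 256), \"float16\"))  # first call: compiles\n",
    "toy_gemm(fake_A, FakeTensor((512, 256), \"float16\"), block_M=128)  # cache hit\n",
    "toy_gemm(fake_A, FakeTensor((512, 1024), \"float16\"), block_M=64, block_N=64)  # recompiles\n",
    "print(toy_gemm.compile_count)  # 2"
   ]
  },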
  {
   "cell_type": "markdown",
   "id": "ce6b7391",
   "metadata": {},
   "source": [
    "You can also compile the kernel manually:\n",
    "\n",
    "1. `ker.compile` compiles the kernel\n",
    "2. `ker.get_tir` returns the TIR\n",
    "3. `ker.par_compile` compiles in parallel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f3cf3a2d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2025-11-25 17:29:46  [TileLang:tilelang.cache.kernel_cache:WARNING]: Found kernel in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching.\n"
     ]
    }
   ],
   "source": [
    "kernel = gemm.compile(A, B, block_M=64, block_N=64)\n",
    "C = kernel(A, B)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "921761b5",
   "metadata": {},
   "source": [
    "## More Tensor Annotation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4539e54e",
   "metadata": {},
   "source": [
    "### Separating the implementation with a macro"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad96ba65",
   "metadata": {},
   "source": [
    "Next, we implement a simple gemm in several different ways. For convenience, we first write a macro that contains the core gemm logic:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "171d4fe6",
   "metadata": {},
   "outputs": [],
   "source": [
    "@T.macro\n",
    "def gemm_impl(A, B, C, M, N, K, block_M, block_N, block_K):\n",
    "    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
    "        A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n",
    "        B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n",
    "        C_local = T.alloc_fragment((block_M, block_N), C.dtype)\n",
    "        T.clear(C_local)\n",
    "        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n",
    "            T.copy(A[bx * block_M, k * block_K], A_shared)\n",
    "            T.copy(B[k * block_K, by * block_N], B_shared)\n",
    "            T.gemm(A_shared, B_shared, C_local)\n",
    "        T.copy(C_local, C[bx * block_M, by * block_N])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "446a1acd",
   "metadata": {},
   "source": [
    "### Marking dynamic shapes with T.dyn\n",
    "\n",
    "When some dimensions are dynamic, mark them with T.dyn. T.dyn accepts a string argument that gives the variable's name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a38aa95",
   "metadata": {},
   "outputs": [],
   "source": [
    "@tilelang.lazy_jit\n",
    "def gemm_dyn_K(\n",
    "    A: T.Tensor[[int, T.dyn[\"K\"]], T.float16],  # noqa: F821\n",
    "    B: T.Tensor[[T.dyn[\"K\"], int], T.float16],  # noqa: F821\n",
    "):\n",
    "    M, K = A.shape\n",
    "    K, N = B.shape\n",
    "    C = T.empty((M, N), T.float32)\n",
    "    gemm_impl(A, B, C, M, N, K, 128, 128, 32)\n",
    "    return C"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c60fd346",
   "metadata": {},
   "source": [
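    "Inspect the lazy_jit function annotations: names with a `$` suffix are compile-time constants whose values are not yet determined, and names with a `$dyn` suffix are runtime variables."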
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c6992eb4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'A': TensorAnnot(shape=[A_shape_0$, K$dyn], strides=None, dtype=dtype('float16')),\n",
       " 'B': TensorAnnot(shape=[K$dyn, B_shape_1$], strides=None, dtype=dtype('float16'))}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gemm_dyn_K.func.annot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fe6cfdc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n",
    "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n",
    "C = gemm_dyn_K(A, B)\n",
    "C_ref = (A @ B).float()\n",
    "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
   ]
  },
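  {
   "cell_type": "markdown",
   "id": "b1a0c005",
   "metadata": {},
   "source": [
    "A `T.dyn` dimension is passed at run time, so changing it should not force a recompile. Below is a minimal toy model of that rule in plain Python (an illustration under that assumption, not tilelang's implementation; `DYN` and `cache_key` are hypothetical names): static dimensions contribute their concrete size to the cache key, while dynamic ones contribute only a placeholder, so `gemm_dyn_K` would reuse its kernel when only `K` changes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1a0c006",
   "metadata": {},
   "outputs": [],
   "source": [
    "DYN = None  # hypothetical marker for a T.dyn dimension in this toy model\n",
    "\n",
    "\n",
    "def cache_key(annot_shapes, actual_shapes):\n",
    "    # Keep concrete sizes only for compile-time (static) dimensions;\n",
    "    # dynamic dimensions are runtime arguments and are left out of the key\n",
    "    key = []\n",
    "    for annot, actual in zip(annot_shapes, actual_shapes):\n",
    "        key.append(tuple(size if a is not DYN else \"dyn\" for a, size in zip(annot, actual)))\n",
    "    return tuple(key)\n",
    "\n",
    "\n",
    "# gemm_dyn_K: A is [int, T.dyn[\"K\"]], B is [T.dyn[\"K\"], int]\n",
    "annot = [(int, DYN), (DYN, int)]\n",
    "k1 = cache_key(annot, [(1024, 512), (512, 256)])\n",
    "k2 = cache_key(annot, [(1024, 4096), (4096, 256)])  # only K changed\n",
    "k3 = cache_key(annot, [(2048, 512), (512, 256)])  # static M changed\n",
    "print(k1 == k2, k1 == k3)  # True False"
   ]
  },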
  {
   "cell_type": "markdown",
   "id": "2ee97bf7",
   "metadata": {},
   "source": [
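    "### Marking strided Tensors with T.StridedTensor\n",
    "\n",
    "Annotation syntax: `T.StridedTensor[Shape, Stride, DType]`. Each Shape or Stride entry can be:\n",
    "* int: a compile-time constant\n",
    "* T.dyn: a runtime value, passed to the kernel at call time rather than compiled in\n",
    "\n",
    "DType can be None or Any."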
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9dde1dae",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any\n",
    "\n",
    "\n",
    "@tilelang.lazy_jit\n",
    "def as_contiguous(A: T.StridedTensor[[T.dyn, T.dyn], [T.dyn, T.dyn], Any]):\n",
    "    M, N = A.shape\n",
    "    B = T.empty((M, N), A.dtype)\n",
    "    block_M = 128\n",
    "    block_N = 128\n",
    "    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
    "        T.copy(\n",
    "            A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n",
    "            B[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n",
    "        )\n",
    "    return B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "dec2c0a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "A = torch.randn(1024, 1024, device=\"cuda\")\n",
    "B = as_contiguous(A[::2, ::2])\n",
    "B_ref = A[::2, ::2].contiguous()\n",
    "torch.testing.assert_close(B, B_ref)"
   ]
  },
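  {
   "cell_type": "markdown",
   "id": "b1a0c007",
   "metadata": {},
   "source": [
    "Why does `A[::2, ::2]` need `T.StridedTensor`? A row-major 1024 x 1024 tensor has element strides (1024, 1); slicing with step 2 doubles each stride, giving a view of shape (512, 512) with strides (2048, 2), which is not contiguous. The pure-Python sketch below mirrors (but does not call) PyTorch's view semantics; `row_major_strides`, `step_slice` and `offset` are hypothetical helpers. It computes those strides and the flat offset the strided copy has to read:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1a0c008",
   "metadata": {},
   "outputs": [],
   "source": [
    "def row_major_strides(shape):\n",
    "    # Element strides of a contiguous (C-order) tensor\n",
    "    strides = []\n",
    "    acc = 1\n",
    "    for size in reversed(shape):\n",
    "        strides.append(acc)\n",
    "        acc *= size\n",
    "    return tuple(reversed(strides))\n",
    "\n",
    "\n",
    "def step_slice(shape, strides, steps):\n",
    "    # View produced by x[::s0, ::s1]: sizes shrink, strides grow by the step\n",
    "    new_shape = tuple(-(-n // s) for n, s in zip(shape, steps))\n",
    "    new_strides = tuple(st * s for st, s in zip(strides, steps))\n",
    "    return new_shape, new_strides\n",
    "\n",
    "\n",
    "def offset(index, strides):\n",
    "    # Flat storage offset of a multi-dimensional index\n",
    "    return sum(i * st for i, st in zip(index, strides))\n",
    "\n",
    "\n",
    "shape = (1024, 1024)\n",
    "strides = row_major_strides(shape)\n",
    "view_shape, view_strides = step_slice(shape, strides, (2, 2))\n",
    "print(strides, view_shape, view_strides)  # (1024, 1) (512, 512) (2048, 2)\n",
    "print(offset((1, 1), view_strides))  # 2050, i.e. element A[2, 2]\n",
    "print(view_strides == row_major_strides(view_shape))  # False: not contiguous"
   ]
  },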
  {
   "cell_type": "markdown",
   "id": "f5fb20d6",
   "metadata": {},
   "source": [
    "## More Annotation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "890df0a2",
   "metadata": {},
   "source": [
    "### Annotating Tensors with T.ptr\n",
    "lazy_jit lets you declare a handle with T.ptr, but you must then give it a shape inside the function with T.match_buffer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0fc17af6",
   "metadata": {},
   "outputs": [],
   "source": [
    "@tilelang.lazy_jit\n",
    "def gemm_ptr(\n",
    "    A: T.ptr,\n",
    "    B: T.ptr,\n",
    "    M: int,\n",
    "    N: int,\n",
    "    K: int,\n",
    "):\n",
    "    A = T.match_buffer(A, (M, K), T.float16)\n",
    "    B = T.match_buffer(B, (K, N), T.float16)\n",
    "    C = T.empty((M, N), T.float32)\n",
    "    gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n",
    "    return C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8e52a554",
   "metadata": {},
   "outputs": [],
   "source": [
    "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n",
    "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n",
    "C = gemm_ptr(A, B, 1024, 256, 512)\n",
    "C_ref = (A @ B).float()\n",
    "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b19ef90",
   "metadata": {},
   "source": [
    "### Annotating runtime variables with T.int32\n",
    "\n",
    "lazy_jit lets you declare runtime variables with T.int32 or other scalar types. This lets you write a fully dynamic gemm, much like in Triton."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "c1e7598a",
   "metadata": {},
   "outputs": [],
   "source": [
    "@tilelang.lazy_jit\n",
    "def gemm_ptr_dyn(\n",
    "    A: T.ptr,\n",
    "    B: T.ptr,\n",
    "    M: T.int32,\n",
    "    N: T.int32,\n",
    "    K: T.int32,\n",
    "):\n",
    "    A = T.match_buffer(A, (M, K), T.float16, strides=(K, 1))\n",
    "    B = T.match_buffer(B, (K, N), T.float16, strides=(N, 1))\n",
    "    C = T.empty((M, N), T.float32)\n",
    "    gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n",
    "    return C"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "9e9a4c88",
   "metadata": {},
   "outputs": [],
   "source": [
    "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n",
    "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n",
    "C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n",
    "C_ref = (A @ B).float()\n",
    "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
   ]
  },
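  {
   "cell_type": "markdown",
   "id": "b1a0c009",
   "metadata": {},
   "source": [
    "In `gemm_ptr_dyn`, `strides=(K, 1)` and `strides=(N, 1)` describe a row-major layout: element `(i, j)` of an `M x K` buffer sits at flat offset `i * K + j` from the handle. A minimal pure-Python check of that addressing rule uses a flat list as a hypothetical stand-in for raw device memory (`load` is an illustrative helper, not a tilelang API). It also shows that the same memory reads differently under other strides, which is what the `strides` argument controls:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1a0c010",
   "metadata": {},
   "outputs": [],
   "source": [
    "M, K = 3, 4\n",
    "\n",
    "# Flat list standing in for the memory behind a T.ptr handle\n",
    "flat = list(range(M * K))\n",
    "\n",
    "\n",
    "def load(buf, i, j, strides):\n",
    "    # Address rule implied by match_buffer(..., strides=strides)\n",
    "    return buf[i * strides[0] + j * strides[1]]\n",
    "\n",
    "\n",
    "# Row-major strides (K, 1), as used for A in gemm_ptr_dyn\n",
    "rows = [[load(flat, i, j, (K, 1)) for j in range(K)] for i in range(M)]\n",
    "print(rows)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]\n",
    "\n",
    "# The same memory read with column-major strides (1, M) is laid out differently\n",
    "cols = [[load(flat, i, j, (1, M)) for j in range(K)] for i in range(M)]\n",
    "print(cols)  # [[0, 3, 6, 9], [1, 4, 7, 10], [2, 5, 8, 11]]"
   ]
  },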
  {
   "cell_type": "markdown",
   "id": "39166cb4",
   "metadata": {},
   "source": [
    "## Compilation and Parallel Compilation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c6fbe08",
   "metadata": {},
   "source": [
    "Both lazy_jit and the original jit support parallel compilation.\n",
    "\n",
    "To avoid wasting memory on real torch tensors, use T.Tensor to create placeholders."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "7222e57b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c6d7f05cdfff412e9a527332438f7aa2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Elaborating:   0%|          | 0/8 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "14836065a21b41ae8fc34e8763ae49fc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Parallel Compiling:   0%|          | 0/8 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "[<tilelang.jit.kernel.JITKernel at 0x7f29c0072ed0>,\n",
       " <tilelang.jit.kernel.JITKernel at 0x7f29c00882f0>,\n",
       " <tilelang.jit.kernel.JITKernel at 0x7f29c00735f0>,\n",
       " <tilelang.jit.kernel.JITKernel at 0x7f29c0088890>,\n",
       " <tilelang.jit.kernel.JITKernel at 0x7f29c01f94c0>,\n",
       " <tilelang.jit.kernel.JITKernel at 0x7f29c0073fe0>,\n",
       " <tilelang.jit.kernel.JITKernel at 0x7f29c0070ce0>,\n",
       " <tilelang.jit.kernel.JITKernel at 0x7f29c00732f0>]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from itertools import product\n",
    "\n",
    "\n",
    "def get_configs():\n",
    "    return [\n",
    "        {\n",
    "            # T.Tensor placeholders carry only shape and dtype, so no torch tensor is allocated;\n",
    "            # the dtype matches gemm's T.float16 annotation\n",
    "            \"A\": T.Tensor((1024, 1024), T.float16),\n",
    "            \"B\": T.Tensor((1024, 1024), T.float16),\n",
    "            \"block_M\": block_M,\n",
    "            \"block_N\": block_N,\n",
    "            \"block_K\": block_K,\n",
    "        }\n",
    "        for block_M, block_N, block_K in product([32, 64], repeat=3)\n",
    "    ]\n",
    "\n",
    "\n",
    "gemm.par_compile(get_configs())"
   ]
  },
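  {
   "cell_type": "markdown",
   "id": "b1a0c011",
   "metadata": {},
   "source": [
    "`product([32, 64], repeat=3)` enumerates every `(block_M, block_N, block_K)` combination, which is why the progress bars above count 8 configurations. The sweep can be listed without any tilelang dependency; the order is the one `itertools.product` yields, with `block_K` varying fastest:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1a0c012",
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import product\n",
    "\n",
    "# Same block-size sweep as get_configs, without the tensor placeholders\n",
    "configs = [\n",
    "    {\"block_M\": m, \"block_N\": n, \"block_K\": k}\n",
    "    for m, n, k in product([32, 64], repeat=3)\n",
    "]\n",
    "print(len(configs))  # 8\n",
    "print(configs[0])  # {'block_M': 32, 'block_N': 32, 'block_K': 32}\n",
    "print(configs[1])  # {'block_M': 32, 'block_N': 32, 'block_K': 64}"
   ]
  },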
  {
   "cell_type": "markdown",
   "id": "5160d2cc",
   "metadata": {},
   "source": [
    "## More Convenient Macros"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be44afc4",
   "metadata": {},
   "source": [
    "Tilelang macros have been upgraded:\n",
    "\n",
    "1. `T.Ref` can be used as an annotation, similar to pass-by-reference in C++\n",
    "2. Macros can return multiple values\n",
    "3. Macros can be nested and can recurse"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79575972",
   "metadata": {},
   "source": [
    "### Passing references with T.Ref\n",
    "\n",
    "A reference passed through T.Ref can be either a var or an indexed Buffer element."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90eaa6e5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "# from tvm.script import tir as T\n",
       "\n",
       "@T.prim_func\n",
       "def foo(x_handle: T.handle):\n",
       "    x = T.match_buffer(x_handle, (2,), strides=(1,))\n",
       "    # with T.block(\"root\"):\n",
       "    bx = T.launch_thread(\"blockIdx.x\", 1)\n",
       "    tx = T.launch_thread(\"threadIdx.x\", 128)\n",
       "    ty = T.launch_thread(\"threadIdx.y\", 1)\n",
       "    tz = T.launch_thread(\"threadIdx.z\", 1)\n",
       "    with T.block(\"tilelang_root\"):\n",
       "        T.reads()\n",
       "        idx = T.Buffer((1,), \"int32\", scope=\"local.var\")\n",
       "        T.writes(x[T.min(1, idx[0]):T.min(1, idx[0]) + (T.max(1, idx[0]) + 1 - T.min(1, idx[0]))])\n",
       "        T.block_attr({\"tl.local_var_init\": {idx.data: 0}})\n",
       "        idx = T.alloc_buffer((1,), \"int32\", data=idx.data, scope=\"local.var\")\n",
       "        x[1] = T.float32(1.0)\n",
       "        _tmp: T.int32 = idx[0]\n",
       "        x[_tmp] = T.float32(1.0)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@T.macro\n",
    "def macro_with_ref(x: T.Ref):\n",
    "    # Assigning to a T.Ref parameter writes through to the referenced location\n",
    "    x = 1  # noqa: F841\n",
    "\n",
    "\n",
    "@T.prim_func\n",
    "def foo(x: T.Tensor((2,))):\n",
    "    with T.Kernel(1) as _:\n",
    "        # A constant index is supported\n",
    "        macro_with_ref(x[1])\n",
    "\n",
    "        # A variable index is supported too\n",
    "        idx = T.alloc_var(T.int32, 0)\n",
    "        macro_with_ref(x[idx])\n",
    "\n",
    "\n",
    "foo"
   ]
  },
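  {
   "cell_type": "markdown",
   "id": "b1a0c013",
   "metadata": {},
   "source": [
    "In plain Python, rebinding a parameter (`x = 1`) never affects the caller; `T.Ref` is what makes such an assignment inside a macro write through, as the generated `x[1] = T.float32(1.0)` above shows. The analogy below models a reference to a buffer element with a small hypothetical `Ref` class (an illustration only, not how tilelang implements `T.Ref`) and contrasts it with ordinary pass-by-value:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1a0c014",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Ref:\n",
    "    # Hypothetical stand-in for T.Ref: remembers a container and an index\n",
    "    def __init__(self, buf, idx):\n",
    "        self.buf, self.idx = buf, idx\n",
    "\n",
    "    def set(self, value):\n",
    "        self.buf[self.idx] = value\n",
    "\n",
    "\n",
    "def by_value(x):\n",
    "    # Rebinds the local name only; the caller sees no change\n",
    "    x = 1  # noqa: F841\n",
    "\n",
    "\n",
    "def by_ref(x: Ref):\n",
    "    x.set(1)  # writes through to the referenced element\n",
    "\n",
    "\n",
    "x = [0.0, 0.0]\n",
    "by_value(x[1])\n",
    "print(x)  # [0.0, 0.0]\n",
    "\n",
    "idx = 0\n",
    "by_ref(Ref(x, 1))  # constant index\n",
    "by_ref(Ref(x, idx))  # variable index\n",
    "print(x)  # [1, 1]"
   ]
  },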
  {
   "cell_type": "markdown",
   "id": "7bb447a2",
   "metadata": {},
   "source": [
    "### Passing macros as arguments\n",
    "\n",
    "You can pass a macro as an argument."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "dc7bb779",
   "metadata": {},
   "outputs": [],
   "source": [
    "@tilelang.lazy_jit\n",
    "def element_wise(\n",
    "    A: T.Tensor[[T.dyn], Any],\n",
    "    fn,\n",
    "):\n",
    "    (N,) = A.shape\n",
    "    B = T.empty((N,), dtype=A.dtype)\n",
    "    block_N = 128\n",
    "    with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n",
    "        for i in T.Parallel(block_N):\n",
    "            idx = bx * block_N + i\n",
    "            B[idx] = fn(A[idx])\n",
    "    return B\n",
    "\n",
    "\n",
    "@T.macro\n",
    "def add_one(x):\n",
    "    return x + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "a89fdb44",
   "metadata": {},
   "outputs": [],
   "source": [
    "A = torch.randn(1024, device=\"cuda\")\n",
    "B = element_wise(A, add_one)\n",
    "B_ref = A + 1\n",
    "torch.testing.assert_close(B, B_ref)"
   ]
  },
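  {
   "cell_type": "markdown",
   "id": "b1a0c015",
   "metadata": {},
   "source": [
    "Passing `add_one` into `element_wise` follows the same pattern as an ordinary Python higher-order function. A minimal pure-Python reference with the same grid/block decomposition can be handy for checking results on the CPU; `element_wise_ref` and `add_one_py` are hypothetical names (chosen so they do not overwrite the macro), and the explicit tail guard is a choice of this sketch only:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1a0c016",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ceildiv(a, b):\n",
    "    return -(-a // b)\n",
    "\n",
    "\n",
    "def element_wise_ref(A, fn, block_N=128):\n",
    "    # Same bx / i decomposition as the kernel, executed sequentially\n",
    "    N = len(A)\n",
    "    B = [None] * N\n",
    "    for bx in range(ceildiv(N, block_N)):\n",
    "        for i in range(block_N):\n",
    "            idx = bx * block_N + i\n",
    "            if idx < N:  # guard the last, possibly partial, block\n",
    "                B[idx] = fn(A[idx])\n",
    "    return B\n",
    "\n",
    "\n",
    "def add_one_py(x):\n",
    "    return x + 1\n",
    "\n",
    "\n",
    "print(element_wise_ref([0, 1, 2], add_one_py))  # [1, 2, 3]\n",
    "print(element_wise_ref(list(range(300)), lambda v: v * 2)[-1])  # 598"
   ]
  },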
  {
   "cell_type": "markdown",
   "id": "ef6e403a",
   "metadata": {},
   "source": [
    "### Macro recursion\n",
    "\n",
    "We are not sure anyone needs this, but macros can recurse, provided the termination condition is determined at compile time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "7703cab5",
   "metadata": {},
   "outputs": [],
   "source": [
    "@T.macro\n",
    "def n31(x, var: T.Ref):\n",
    "    # One Collatz (3n+1) step per call: `x` is a compile-time int that drives the\n",
    "    # recursion, and `var` receives the same update through the reference\n",
    "    if x == 1:\n",
    "        pass\n",
    "    elif x % 2 == 0:\n",
    "        var = var // 2\n",
    "        n31(x // 2, var)\n",
    "    else:\n",
    "        var = var * 3 + 1\n",
    "        n31(x * 3 + 1, var)\n",
    "\n",
    "\n",
    "@tilelang.lazy_jit\n",
    "def foo(A: T.Tensor[[1], T.int32], n: int):\n",
    "    with T.Kernel(1) as _:\n",
    "        # n is a compile-time int, so the recursion in n31 terminates during compilation\n",
    "        n31(n, A[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "542ddd4e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([18], device='cuda:0', dtype=torch.int32)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "A = torch.tensor([100], dtype=torch.int32, device=\"cuda\")\n",
    "foo(A, 5)\n",
    "A"
   ]
  },
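  {
   "cell_type": "markdown",
   "id": "b1a0c017",
   "metadata": {},
   "source": [
    "The result 18 can be checked by hand. `n31(5, A[0])` unrolls at compile time along the Collatz sequence of `n` (5, 16, 8, 4, 2, 1), and each step applies the matching update to `A[0]`. The pure-Python mirror below (`n31_py` is a hypothetical name, not part of tilelang) replays those updates starting from 100:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1a0c018",
   "metadata": {},
   "outputs": [],
   "source": [
    "def n31_py(x: int, var: int, trace=None):\n",
    "    # Mirrors the n31 macro: x decides the branch, var receives the update\n",
    "    if trace is not None:\n",
    "        trace.append((x, var))\n",
    "    if x == 1:\n",
    "        return var\n",
    "    if x % 2 == 0:\n",
    "        return n31_py(x // 2, var // 2, trace)\n",
    "    return n31_py(x * 3 + 1, var * 3 + 1, trace)\n",
    "\n",
    "\n",
    "steps = []\n",
    "print(n31_py(5, 100, steps))  # 18\n",
    "print(steps)  # [(5, 100), (16, 301), (8, 150), (4, 75), (2, 37), (1, 18)]"
   ]
  },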
  {
   "cell_type": "markdown",
   "id": "dc30c2d2",
   "metadata": {},
   "source": [
    "### Returning multiple values from a macro"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5a2388f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "# from tvm.script import tir as T\n",
       "\n",
       "@T.prim_func\n",
       "def foo():\n",
       "    # with T.block(\"root\"):\n",
       "    x = T.launch_thread(\"blockIdx.x\", 32)\n",
       "    tx = T.launch_thread(\"threadIdx.x\", 128)\n",
       "    ty = T.launch_thread(\"threadIdx.y\", 1)\n",
       "    tz = T.launch_thread(\"threadIdx.z\", 1)\n",
       "    with T.block(\"tilelang_root\"):\n",
       "        T.reads()\n",
       "        T.writes()\n",
       "        s: T.int32 = T.sin(x)\n",
       "        c: T.int32 = T.cos(x)\n",
       "        a: T.int32 = s + c\n",
       "        b: T.int32 = s - c\n",
       "        T.evaluate(0)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@T.macro\n",
    "def sincos(x):\n",
    "    return T.sin(x), T.cos(x)\n",
    "\n",
    "\n",
    "@T.prim_func\n",
    "def foo():\n",
    "    with T.Kernel(32) as x:\n",
    "        s, c = sincos(x)\n",
    "        a = s + c  # noqa: F841\n",
    "        b = s - c  # noqa: F841\n",
    "\n",
    "\n",
    "foo"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tilelang-dev_0",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}