Auto-Tuning Techniques for Performance Optimization
===================================================
<div style="text-align: left;">
<em>Author:</em> <a href="https://github.com/yyttt6">yyttt6</a>
</div>

## Overview

Auto-tuning a Tile Language program involves three main steps:

1. Implement the target program using Tile Language with reserved optimization parameters
2. Provide candidate configurations through manual search or [auto-generation using Carver](#using-carver-to-auto-generate-candidate-configurations)
3. Compile and benchmark the candidate configurations in parallel to identify the best-performing one

## Matrix Multiplication Example

The following example demonstrates auto-tuning matrix multiplication. The code has been simplified for readability; see `examples/gemm/example_gemm.py` for the complete implementation.

### Step 1: Implement with Reserved Parameters
Users can implement matrix multiplication in Tile Language while reserving parameters for optimization:
```python
import tilelang.language as T

# Reserved tiling/scheduling parameters for optimization.
# M, N, K are assumed to be defined in the enclosing scope
# (see examples/gemm/example_gemm.py for the full version).
def kernel(
    block_M=None,
    block_N=None,
    block_K=None,
    num_stages=None,
    thread_num=None,
    enable_rasteration=None,
):
    dtype = "float16"
    accum_dtype = "float"

    # Matrix multiplication implementation
    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        # ...existing code...

    return main
```
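The elided body is where the reserved parameters take effect. Below is a minimal sketch of a tiled GEMM body adapted from `examples/gemm/example_gemm.py`; exact API calls may differ from the released version:

```python
    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        # Launch one block per (block_M x block_N) output tile
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Optional block rasterization for better L2 locality
            T.use_swizzle(panel_size=10, enable=enable_rasteration)
            T.clear(C_local)

            # Software-pipelined reduction over K tiles
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)

            T.copy(C_local, C[by * block_M, bx * block_N])
```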
### Step 2: Generate Candidate Configurations
Manually define configurations or use combinatorial generation:
```python
configs = [
    {
        "block_M": 128,
        "block_N": 128,
        "block_K": 128,
        "num_stages": 3,
        "thread_num": 128,
        "enable_rasteration": True
    },
    {
        "block_M": 32,
        "block_N": 32,
        "block_K": 32,
        "num_stages": 0,
        "thread_num": 32,
        "enable_rasteration": False
    },
    # ...additional configurations...
]
```
Configurations can also be generated by a combinatorial traversal of the parameter space:
```python
import itertools

block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [32, 64]
num_stages = [0, 1, 2, 3]
thread_num = [128, 256]
enable_rasterization = [True, False]
_configs = list(
    itertools.product(
        block_M,
        block_N,
        block_K,
        num_stages,
        thread_num,
        enable_rasterization,
    ))

configs = [
    {
        "block_M": c[0],
        "block_N": c[1],
        "block_K": c[2],
        "num_stages": c[3],
        "thread_num": c[4],
        "enable_rasteration": c[5]
    } for c in _configs
]
```
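Step 3 below passes a `get_configs` helper to the autotuner. A minimal sketch of how it might wrap the list built above (the `with_roller` flag, which switches to Carver-generated hints, is covered in the last section):

```python
def get_configs(M, N, K, with_roller=False):
    # Hypothetical wrapper: returns the manually / combinatorially built list above.
    # With with_roller=True, the list would instead be derived from Carver hints
    # (see the last section for how those hints are converted into configs).
    return configs
```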
### Step 3: Compile and Benchmark
Configure JIT compilation and benchmarking settings:
```python
autotuner = AutoTuner.from_kernel(
    kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args(
        out_idx=[-1],
        supply_type=tl.TensorSupplyType.Integer,
        ref_prog=ref_program,
        skip_check=False,
        target="auto",
    )
result = autotuner.run(warmup=3, rep=20)
out_c = result.kernel(a, b)
```
The result object contains the optimized kernel implementation, which can be invoked directly.
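A hedged usage sketch, assuming PyTorch half-precision inputs shaped as in the kernel signature and the `ref_program` reference passed to the compile arguments:

```python
import torch

# Allocate inputs matching the kernel signature (A: MxK, B: NxK, both float16).
a = torch.randn(M, K, dtype=torch.float16, device="cuda")
b = torch.randn(N, K, dtype=torch.float16, device="cuda")

# Run the tuned kernel and check it against the reference program.
out_c = result.kernel(a, b)
torch.testing.assert_close(out_c, ref_program(a, b), rtol=1e-2, atol=1e-2)
```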

## Using Carver to Auto-Generate Candidate Configurations

Carver is a lightweight framework for generating and ranking tile configurations (also known as tiling strategies, blocking schemes, or scheduling hints) for common GPU, CPU, and accelerator backends. It helps you explore efficient mappings of loops for operations such as matrix multiplication, elementwise transforms, and other reduction-oriented kernels.

For common operators, Carver provides pre-built templates (e.g., `MatmulTemplate`):

```python
# Configure Matmul template
arch = CUDA("cuda")
carve_template = MatmulTemplate(
    M=M,
    N=N,
    K=K,
    in_dtype="float16",
    out_dtype="float16",
    accum_dtype="float",
).with_arch(arch)

# Generate top-k optimization hints (topk=10 recommended)
roller_hints = carve_template.recommend_hints(topk=10)

# Configure candidate parameters
for hint in roller_hints:

    # ...existing code...

    config["block_M"] = block_m
    config["block_N"] = block_n
    config["block_K"] = hint.rstep[0]
    config["num_stages"] = hint.pipeline_stage
    config["thread_num"] = block_rows * block_cols * 32
    config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization

```
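The resulting dictionaries have the same shape as the configurations from Step 2, so they can be handed straight back to the autotuner. A minimal sketch, assuming the loop above collects each `config` into a `carver_configs` list:

```python
# Hypothetical glue code: reuse the Step 3 flow with Carver-generated configs.
autotuner = AutoTuner.from_kernel(kernel=kernel, configs=carver_configs).set_compile_args(
    out_idx=[-1],
    target="auto",
)
result = autotuner.run(warmup=3, rep=20)
```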