__init__.py 8.07 KB
Newer Older
1
"""The auto-tune module for tilelang programs."""
2

3
import tilelang
4
5
from tilelang import tvm as tvm
import inspect
yyttt6's avatar
yyttt6 committed
6
7
from functools import wraps, partial
from typing import Callable, List, Literal, Any
8
9
10
11
from tqdm import tqdm
import logging
from dataclasses import dataclass
import concurrent.futures
12
import os
13

14
15
# Module-level logger used by the tuning loop below.
logger = logging.getLogger(__name__)

# NOTE(review): basicConfig at import time configures the ROOT logger and
# truncates autotuner.log (filemode='w') as a side effect of merely importing
# this module — consider moving this into an explicit setup function.
logging.basicConfig(
    filename='autotuner.log',
    filemode='w',
    level=logging.DEBUG,
    format='%(asctime)s %(levelname)s:%(message)s')


@dataclass(frozen=True)
class JITContext:
    """Immutable bundle of compile/validation settings plus the compiled
    kernel's profiler.

    Produced once per candidate configuration (see
    ``AutoTuner.set_compile_args`` and ``jit``) and consumed by the tuning
    loop in ``AutoTuner.run``.
    """
    out_idx: List[int]  # indices of output tensors among the kernel arguments
    supply_type: tilelang.TensorSupplyType  # how benchmark input tensors are generated
    ref_prog: Callable  # reference implementation for correctness checks (may be None)
    rtol: float  # relative tolerance passed to profiler.assert_allclose
    atol: float  # absolute tolerance passed to profiler.assert_allclose
    max_mismatched_ratio: float  # fraction of mismatching elements tolerated
    skip_check: bool  # when True, the correctness check is skipped entirely
    profiler: tilelang.Profiler  # profiler wrapping the compiled kernel
    # Fix: 'auto' is passed through unresolved by set_compile_args/jit, so it
    # is a legal runtime value here; the original annotation omitted it.
    target: Literal['auto', 'cuda', 'hip']


yyttt6's avatar
yyttt6 committed
36
37
38
39
40
41
42
43
44
45
46
@dataclass(frozen=True)
class AutotuneResult:
    """Outcome of an autotuning run: best measured latency, the configuration
    that achieved it, and the corresponding compiled artifacts."""
    latency: float  # best kernel latency observed by AutoTuner.run
    # NOTE(review): AutoTuner.run actually stores the positional-argument
    # tuple built from the winning config here, not a dict — confirm intent.
    config: dict
    ref_latency: float  # latency of the reference program (None when not measured)
    libcode: str  # generated source code of the best kernel
    func: Callable  # tilelang program built from the best config
    kernel: Callable  # compiled kernel callable for the best config


class AutoTuner:
    """Drives autotuning of a tilelang program.

    Compiles every candidate configuration concurrently in a thread pool,
    benchmarks each successfully compiled kernel, and returns the
    best-performing configuration as an ``AutotuneResult``.
    """

    def __init__(self, fn: Callable, configs):
        """
        Args:
            fn: Tilelang program generator; called with one positional
                argument per tunable parameter to produce a program.
            configs: Iterable of dicts mapping parameter names to candidate
                values.
        """
        self.fn = fn
        self.configs = configs
        # Caches shared across configs so benchmark inputs and the reference
        # latency are only produced once per run.
        self.ref_latency_cache = None
        self.jit_input_tensors = None
        self.ref_input_tensors = None

    @classmethod
    def from_kernel(cls, kernel: Callable, configs):
        """Alternate constructor; ``kernel`` plays the role of ``fn``."""
        return cls(kernel, configs)

    def set_compile_args(self,
                         out_idx: List[int],
                         supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Normal,
                         ref_prog: Callable = None,
                         rtol: float = 1e-2,
                         atol: float = 1e-2,
                         max_mismatched_ratio: float = 0.01,
                         skip_check: bool = False,
                         target: Literal['auto', 'cuda', 'hip'] = 'auto'):
        """Install the per-config compile callback.

        Args:
            out_idx: Indices of output tensors among the kernel arguments.
            supply_type: How benchmark input tensors are generated.
            ref_prog: Optional reference implementation for correctness checks.
            rtol/atol/max_mismatched_ratio: Tolerances for the check.
            skip_check: Skip the correctness check when True.
            target: Compilation target ('auto' lets tilelang decide).

        Returns:
            self, for chaining.
        """

        def _compile(*config_arg):
            # Build and compile the program for this particular config, then
            # bundle the profiler with the validation settings.
            kernel = tilelang.compile(self.fn(*config_arg), out_idx=out_idx, target=target)
            profiler = kernel.get_profiler()
            jit_context = JITContext(
                out_idx=out_idx,
                supply_type=supply_type,
                ref_prog=ref_prog,
                rtol=rtol,
                atol=atol,
                max_mismatched_ratio=max_mismatched_ratio,
                skip_check=skip_check,
                profiler=profiler,
                target=target)
            return jit_context

        self.jit_compile = _compile
        return self

    def run(self, warmup: int = 25, rep: int = 100, timeout: int = 100):
        """Compile all configs concurrently, benchmark them, return the best.

        Args:
            warmup: Number of warmup iterations per benchmark.
            rep: Number of measured repetitions per benchmark.
            timeout: Currently unused — per-config timeouts are not enforced
                (see the note at the benchmark loop below).

        Returns:
            AutotuneResult for the fastest configuration.

        Raises:
            RuntimeError: If no configuration compiled and benchmarked
                successfully.
        """
        sig = inspect.signature(self.fn)
        keys = list(sig.parameters.keys())
        bound_args = sig.bind()
        bound_args.apply_defaults()
        best_latency = 1e8
        best_config = None
        best_jit_context = None

        def target_fn(jit_context):
            # Unpack the context
            profiler = jit_context.profiler
            skip_check = jit_context.skip_check
            ref_prog = jit_context.ref_prog
            rtol = jit_context.rtol
            atol = jit_context.atol
            max_mismatched_ratio = jit_context.max_mismatched_ratio

            # NOTE(review): `profiler == "tvm"` compares the Profiler object
            # to a string and is therefore always False — this looks like it
            # was meant to inspect an execution-backend attribute. Kept as-is
            # to preserve behavior; confirm the intended condition.
            self.jit_input_tensors = profiler._get_inputs(
                with_output=profiler ==
                "tvm") if self.jit_input_tensors is None else self.jit_input_tensors

            if (not skip_check) and (ref_prog is not None):
                profiler.assert_allclose(
                    ref_prog, rtol=rtol, atol=atol, max_mismatched_ratio=max_mismatched_ratio)

            latency = profiler.do_bench(
                profiler.func, n_warmup=warmup, n_repeat=rep, input_tensors=self.jit_input_tensors)
            # The reference program is benchmarked at most once per run; the
            # result is cached on the instance.
            if self.ref_latency_cache is None and ref_prog is not None:
                self.ref_input_tensors = profiler._get_inputs(
                    with_output=False) if self.ref_input_tensors is None else self.ref_input_tensors
                self.ref_latency_cache = profiler.do_bench(
                    ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors)

            return latency, self.ref_latency_cache

        # Turn each config dict into the positional-argument tuple expected
        # by self.fn, falling back to the bound default for parameters the
        # config does not tune.
        config_args = []
        for config in self.configs:
            new_args = []
            for name, value in bound_args.arguments.items():
                if name not in keys:
                    new_args.append(value)
                else:
                    new_args.append(config[name])
            new_args = tuple(new_args)
            config_args.append(new_args)

        # Compile candidates concurrently; leave ~10% of cores free.
        num_workers = max(1, int(os.cpu_count() * 0.9))
        pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
        futures = []
        future_to_index = {}
        for i, config_arg in enumerate(config_args):
            future = pool.submit(
                self.jit_compile,
                *config_arg,
            )
            futures.append(future)
            future_to_index[future] = i

        # Collect compile results as they finish; failed configs are logged
        # and skipped rather than aborting the whole run.
        results_with_configs = []
        for future in tqdm(
                concurrent.futures.as_completed(futures),
                total=len(futures),
                desc="Compiling configurations"):
            idx = future_to_index[future]
            config = config_args[idx]
            try:
                result = future.result()
                results_with_configs.append((result, config))
            except Exception:
                logger.debug(f"Compilation failed for config {config} at index {idx}")
                continue

        ref_latency = None
        progress_bar = tqdm(range(len(results_with_configs)), desc="Bench configurations")
        for i in progress_bar:
            jit_context, config = results_with_configs[i]
            try:
                # Cannot ThreadPoolExecutor to enforce timeout on target_fn execution
                # Because tma init may behave strangely with one thread
                latency, ref_latency = target_fn(jit_context)
            except Exception as e:
                logger.info(
                    f"An error occurred while testing config {config}, checkout autotuner.log for more details"
                )
                logger.debug(f"Error: {e}")
                continue

            # Fix: was `logging.debug(...)`, which bypassed the module logger
            # and hit the root logger instead.
            logger.debug(f"Config {config} latency: {latency} at index {i}")

            if latency < best_latency:
                best_latency = latency
                best_config = config
                best_jit_context = jit_context

            progress_bar.set_postfix({"best_latency": best_latency})
            tqdm.write(f"Tuned Latency {latency} with config {config} at index {i}")

        pool.shutdown()

        # Fix: previously, if every config failed to compile or bench, the
        # return below crashed with an opaque AttributeError on None.
        if best_jit_context is None:
            raise RuntimeError(
                "Autotuning failed: no configuration compiled and benchmarked "
                "successfully; see autotuner.log for details")

        return AutotuneResult(
            latency=best_latency,
            config=best_config,
            ref_latency=ref_latency,
            libcode=best_jit_context.profiler.func.lib_code,
            func=self.fn(*best_config),
            kernel=best_jit_context.profiler.func)

    def __call__(self) -> Any:
        """Run tuning with the default benchmark settings."""
        return self.run()
196
197


yyttt6's avatar
yyttt6 committed
198
def autotune(configs: Any, warmup: int = 25, rep: int = 100, timeout: int = 100) -> Callable:
    """Decorator for tilelang program.

    Wraps a program generator in an ``AutoTuner`` whose ``run`` method is
    pre-bound to the given ``warmup``/``rep``/``timeout`` settings.

    NOTE(review): ``jit_compile`` is set to the raw generator here, so the
    default pipeline relies on a later ``set_compile_args`` call (or an
    equivalent override) to produce JITContext objects — confirm with callers.
    """

    def decorator(fn: Callable) -> AutoTuner:
        tuner = AutoTuner.from_kernel(fn, configs=configs)
        tuner.jit_compile = fn
        tuner.run = partial(tuner.run, warmup, rep, timeout)
        return tuner

    return decorator


def jit(out_idx: List[int],
        supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Normal,
        ref_prog: Callable = None,
        rtol: float = 1e-2,
        atol: float = 1e-2,
        max_mismatched_ratio: float = 0.01,
        skip_check: bool = False,
        target: Literal['auto', 'cuda', 'hip'] = 'auto') -> Callable:
    """Decorator factory for tilelang programs.

    The decorated generator, when called, compiles the produced program and
    returns a ``JITContext`` bundling the kernel's profiler with the
    validation settings captured here.

    Args:
        out_idx: Indices of output tensors among the kernel arguments.
        supply_type: How benchmark input tensors are generated.
        ref_prog: Optional reference implementation for correctness checks.
        rtol/atol/max_mismatched_ratio: Tolerances for the check.
        skip_check: Skip the correctness check when True.
        target: Compilation target ('auto' lets tilelang decide).
    """

    def wrapper(fn: Callable):

        @wraps(fn)
        # Fix: the return annotation was `-> float`, but this function
        # returns a JITContext, never a float.
        def decorator(*args, **kwargs) -> JITContext:
            # Generate and compile the program for these arguments, then
            # attach a profiler for later benchmarking.
            kernel = tilelang.compile(fn(*args, **kwargs), out_idx=out_idx, target=target)

            profiler = kernel.get_profiler()

            return JITContext(
                out_idx=out_idx,
                supply_type=supply_type,
                ref_prog=ref_prog,
                rtol=rtol,
                atol=atol,
                max_mismatched_ratio=max_mismatched_ratio,
                skip_check=skip_check,
                profiler=profiler,
                target=target)

        return decorator

    return wrapper