"""The auto-tune module for tilelang programs.

This module provides functionality for auto-tuning tilelang programs, including JIT compilation
and performance optimization through configuration search.
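
Example:
    A minimal usage sketch; the kernel body, parameter name, and config values below are
    illustrative placeholders (the only requirement is that every parameter of the
    decorated function appears as a key in each config dict):

        @autotune(configs=[{"block_M": 64}, {"block_M": 128}], warmup=10, rep=50)
        @jit(target="auto")
        def my_kernel(block_M=64):
            ...  # build and return a tilelang program using `block_M`

        result = my_kernel()  # runs the search and returns an AutotuneResult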
"""

import tilelang
from tilelang import tvm as tvm
import inspect
from functools import wraps, partial
from typing import Callable, List, Literal, Any, Optional, Union
from tqdm import tqdm
import logging
from dataclasses import dataclass
import concurrent.futures
import torch
import os
import sys

# Configure logging for the autotuner module
# TODO: Consider creating a common logger in utils
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.propagate = False

formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s')

file_handler = logging.FileHandler('autotuner.log', mode='w')
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)

console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)

logger.addHandler(file_handler)
logger.addHandler(console_handler)


@dataclass(frozen=True)
class JITContext:
    """Context object for Just-In-Time compilation settings.

    Attributes:
        out_idx: List of output tensor indices.
        ref_prog: Reference program for correctness validation.
        supply_prog: Supply program for input tensors.
        rtol: Relative tolerance for output validation.
        atol: Absolute tolerance for output validation.
        max_mismatched_ratio: Maximum allowed ratio of mismatched elements.
        skip_check: Whether to skip validation checks.
        cache_input_tensors: Whether to cache input tensors for each compilation.
        kernel: JITKernel instance for performance measurement.
        supply_type: Type of tensor supply mechanism.
        target: Target platform ('cuda' or 'hip').
    """
    out_idx: List[int]
    ref_prog: Callable
    supply_prog: Callable
    rtol: float
    atol: float
    max_mismatched_ratio: float
    skip_check: bool
    cache_input_tensors: bool
    kernel: tilelang.JITKernel
    supply_type: tilelang.TensorSupplyType
    target: Literal['cuda', 'hip']


@dataclass(frozen=True)
class AutotuneResult:
    """Results from auto-tuning process.

    Attributes:
        latency: Best achieved execution latency.
        config: Configuration that produced the best result.
        ref_latency: Reference implementation latency.
        libcode: Generated library code.
        func: Optimized function.
        kernel: Compiled kernel function.
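
    Example:
        A brief sketch of consuming a result (assumes `result` was returned by a tuning
        run and `inputs` are tensors matching the kernel's signature):

            print(result.config, result.latency)
            outputs = result.kernel(*inputs)  # invoke the best compiled kernel directly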
    """
    latency: float
    config: dict
    ref_latency: float
    libcode: str
    func: Callable
    kernel: Callable


@dataclass(frozen=True)
class CompileArgs:
    """Compile arguments for the auto-tuner.

    Attributes:
        out_idx: List of output tensor indices.
        supply_type: Type of tensor supply mechanism.
        ref_prog: Reference program for correctness validation.
        supply_prog: Supply program for input tensors.
        out_idx: Union[List[int], int] = -1
        supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto
        ref_prog: Callable = None
        supply_prog: Callable = None
        rtol: float = 1e-2
        atol: float = 1e-2
        max_mismatched_ratio: float = 0.01
        skip_check: bool = False
        cache_input_tensors: bool = True
        target: Literal['auto', 'cuda', 'hip'] = 'auto'
    """

    out_idx: Union[List[int], int] = -1
    supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto
    ref_prog: Callable = None
    supply_prog: Callable = None
    rtol: float = 1e-2
    atol: float = 1e-2
    max_mismatched_ratio: float = 0.01
    skip_check: bool = False
    cache_input_tensors: bool = True
    target: Literal['auto', 'cuda', 'hip'] = 'auto'


class AutoTuner:
    """Auto-tuner for tilelang programs.

    This class handles the auto-tuning process by testing different configurations
    and finding the optimal parameters for program execution.

    Args:
        fn: The function to be auto-tuned.
        configs: List of configurations to try during auto-tuning.
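
    Example:
        A programmatic sketch (here `my_kernel` is a placeholder for a tilelang kernel
        factory and `configs` a list of dicts keyed by `my_kernel`'s parameters):

            tuner = AutoTuner.from_kernel(my_kernel, configs)
            tuner.set_compile_args(out_idx=-1, target="auto")
            result = tuner.run(warmup=25, rep=100, timeout=30)
            print(result.latency, result.config)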
    """

    def __init__(self, fn: Callable, configs):
        self.fn = fn
        self.configs = configs
        self.ref_latency_cache = None
        self.jit_input_tensors = None
        self.ref_input_tensors = None
        self.jit_compile = None
        self.compile_args = CompileArgs()

    @classmethod
    def from_kernel(cls, kernel: Callable, configs):
        """Create an AutoTuner instance from a kernel function.

        Args:
            kernel: The kernel function to auto-tune.
            configs: List of configurations to try.

        Returns:
            AutoTuner: A new AutoTuner instance.
        """
        return cls(kernel, configs)

    def set_compile_args(self,
                         out_idx: Union[List[int], int] = -1,
                         supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
                         ref_prog: Callable = None,
                         supply_prog: Callable = None,
                         rtol: float = 1e-2,
                         atol: float = 1e-2,
                         max_mismatched_ratio: float = 0.01,
                         skip_check: bool = False,
                         cache_input_tensors: bool = True,
                         target: Literal['auto', 'cuda', 'hip'] = 'auto'):
        """Set compilation arguments for the auto-tuner.

        Args:
            out_idx: List of output tensor indices.
171
            supply_type: Type of tensor supply mechanism. Ignored if `supply_prog` is provided.
172
            ref_prog: Reference program for validation.
173
            supply_prog: Supply program for input tensors.
174
175
176
177
            rtol: Relative tolerance for validation.
            atol: Absolute tolerance for validation.
            max_mismatched_ratio: Maximum allowed mismatch ratio.
            skip_check: Whether to skip validation.
178
            cache_input_tensors: Whether to cache input tensors.
179
180
181
182
183
            target: Target platform.

        Returns:
            AutoTuner: Self for method chaining.
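
        Example:
            A sketch of a custom input supplier; mirroring how `supply_prog` is invoked
            during tuning, it receives the kernel's parameter list and returns the input
            tensors (shapes, dtype and the names `tuner` / `my_ref` are placeholders):

                def my_supply(params):
                    return [torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
                            for _ in params]

                tuner.set_compile_args(ref_prog=my_ref, supply_prog=my_supply)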
        """
        self.compile_args = CompileArgs(
            out_idx=out_idx,
            supply_type=supply_type,
            ref_prog=ref_prog,
            supply_prog=supply_prog,
            rtol=rtol,
            atol=atol,
            max_mismatched_ratio=max_mismatched_ratio,
            skip_check=skip_check,
            cache_input_tensors=cache_input_tensors,
            target=target)

        # If a custom `supply_prog` is provided, the profiler's `supply_type` setting
        # becomes ineffective. The custom supply program will be used instead.
        if supply_prog is not None and supply_type != tilelang.TensorSupplyType.Auto:
            logger.warning("Ignoring `supply_type` passed to `set_compile_args` because "
                           "`supply_prog` is not None.")

        return self

    def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30):
        """Run the auto-tuning process.

        Args:
            warmup: Number of warmup iterations.
            rep: Number of repetitions for timing.
            timeout: Maximum time per configuration.

        Returns:
            AutotuneResult: Results of the auto-tuning process.
        """
        sig = inspect.signature(self.fn)
        keys = list(sig.parameters.keys())
        bound_args = sig.bind()
        bound_args.apply_defaults()
        best_latency = 1e8
        best_config = None
        best_jit_context = None

        def _compile(*config_arg):
            compile_args = self.compile_args
            kernel = tilelang.compile(
                self.fn(*config_arg), out_idx=compile_args.out_idx, target=compile_args.target)
            jit_context = JITContext(
                out_idx=compile_args.out_idx,
                ref_prog=compile_args.ref_prog,
                supply_prog=compile_args.supply_prog,
                rtol=compile_args.rtol,
                atol=compile_args.atol,
                max_mismatched_ratio=compile_args.max_mismatched_ratio,
                skip_check=compile_args.skip_check,
                cache_input_tensors=compile_args.cache_input_tensors,
                kernel=kernel,
                supply_type=compile_args.supply_type,
                target=compile_args.target)
            return jit_context

        if self.jit_compile is None:
            self.jit_compile = _compile

        def target_fn(jit_context: JITContext):
            # Unpack the context
            kernel = jit_context.kernel
            supply_type = jit_context.supply_type
            skip_check = jit_context.skip_check
            cache_input_tensors = jit_context.cache_input_tensors
            ref_prog = jit_context.ref_prog
            supply_prog = jit_context.supply_prog
            rtol = jit_context.rtol
            atol = jit_context.atol
            max_mismatched_ratio = jit_context.max_mismatched_ratio

            profiler = kernel.get_profiler(tensor_supply_type=supply_type)

            # Factory functions for generating input tensors.
            # This encapsulates the logic of using either a custom supply program (`supply_prog`)
            # or the default profiler input generation (`profiler._get_inputs`).
            def get_input_tensors_supply(with_output: bool):

                def func():
                    if supply_prog is not None:
                        return supply_prog(profiler._get_params(with_output=with_output))
                    else:
                        return profiler._get_inputs(with_output=with_output)

                return func

            jit_input_tensors_supply = get_input_tensors_supply(with_output=False)
            ref_input_tensors_supply = get_input_tensors_supply(with_output=False)

            if cache_input_tensors:
                jit_input_tensors = jit_input_tensors_supply()
                if self.jit_input_tensors is not None:
                    if not check_tensor_list_compatibility(self.jit_input_tensors,
                                                           jit_input_tensors):
                        logger.warning(
                            "Incompatible input tensor properties detected between cached tensors and "
                            "tensors regenerated for the current configuration trial. "
                            "This can happen if different tuning configurations require different input shapes/dtypes "
                            "and input tensor caching is enabled.\n"
                            "To ensure fresh, compatible inputs are generated for every trial "
                            "you can disable caching by setting:\n"
                            "  `cache_input_tensors=False`\n"
                            "within your `.set_compile_args(...)` call.\n")
                self.jit_input_tensors = jit_input_tensors
            else:
                self.jit_input_tensors = jit_input_tensors_supply()

            if (not skip_check) and (ref_prog is not None):
                profiler.assert_allclose(
                    ref_prog,
                    input_tensors=self.jit_input_tensors,
                    rtol=rtol,
                    atol=atol,
                    max_mismatched_ratio=max_mismatched_ratio)
            latency = profiler.do_bench(
                warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors)
            if self.ref_latency_cache is None and ref_prog is not None:
                self.ref_input_tensors = ref_input_tensors_supply()
                self.ref_latency_cache = profiler.do_bench(
                    ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors)

            return latency, self.ref_latency_cache

        config_args = []
        for config in self.configs:
            new_args = []
            for name, value in bound_args.arguments.items():
                if name not in keys:
                    new_args.append(value)
                else:
                    if name not in config:
                        raise ValueError(f"Configuration {config} does not contain key {name}")
                    new_args.append(config[name])
            new_args = tuple(new_args)
            config_args.append(new_args)

        num_workers = max(1, int(get_available_cpu_count() * 0.9))
        pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
        futures = []
        future_to_index = {}
        for i, config_arg in enumerate(config_args):
            future = pool.submit(
                self.jit_compile,
                *config_arg,
            )
            futures.append(future)
            future_to_index[future] = i

        results_with_configs = []
        for future in tqdm(
                concurrent.futures.as_completed(futures),
                total=len(futures),
                desc="Compiling configurations"):
            idx = future_to_index[future]
            config = config_args[idx]
            try:
                result = future.result()
                results_with_configs.append((result, config))
            except Exception as e:
                logger.debug(
                    f"Compilation failed for config {config} at index {idx} with error: {e}")
                continue

        ref_latency = None
        progress_bar = tqdm(range(len(results_with_configs)), desc="Bench configurations")
        for i in progress_bar:
            jit_context, config = results_with_configs[i]
            try:
                # Run target_fn via a single-worker ThreadPoolExecutor so that a timeout
                # can be enforced on the benchmark; TMA init may behave strangely otherwise.
                # Direct call (no timeout): latency, ref_latency = target_fn(jit_context)
                benchmark_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
                future = benchmark_executor.submit(target_fn, jit_context)
                latency, ref_latency = future.result(timeout=timeout)
            except concurrent.futures.TimeoutError:
                logger.info(
                    f"A timeout occurred while testing config {config}, check autotuner.log for more details"
                )
                continue
            except Exception as e:
                logger.info(
                    f"An error occurred while testing config {config}, check autotuner.log for more details"
                )
                logger.debug(f"Error: {e}")
                continue

            logger.debug(f"Config {config} latency: {latency} at index {i}")

            if latency < best_latency:
                best_latency = latency
                best_config = config
                best_jit_context = jit_context

            progress_bar.set_postfix({"best_latency": best_latency})
            tqdm.write(f"Tuned Latency {latency} with config {config} at index {i}")

        pool.shutdown()

        if best_jit_context is None:
            error_msg = ("Auto-tuning failed: No configuration successfully "
                         "compiled and passed benchmarking/validation.")
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        return AutotuneResult(
            latency=best_latency,
            config=best_config,
            ref_latency=ref_latency,
            libcode=best_jit_context.kernel.get_kernel_source(),
            func=self.fn(*best_config),
            kernel=best_jit_context.kernel)

    def __call__(self) -> Any:
        """Make the AutoTuner callable, running the auto-tuning process.

        Returns:
            AutotuneResult: Results of the auto-tuning process.
        """
        return self.run()


def autotune(configs: Any, warmup: int = 25, rep: int = 100, timeout: int = 100) -> Callable:
    """Decorator for auto-tuning tilelang programs.

    Args:
        configs: Configuration space to explore during auto-tuning.
        warmup: Number of warmup iterations before timing.
        rep: Number of repetitions for timing measurements.
        timeout: Maximum time (in seconds) allowed for each configuration.

    Returns:
        Callable: Decorated function that performs auto-tuning.
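
    Example:
        A sketch of building a configuration space; each dict must provide a value for
        every parameter of the decorated kernel function (parameter names below are
        placeholders), and the decorator is typically stacked on top of `jit`:

            configs = [{"block_M": m, "block_N": n} for m in (64, 128) for n in (32, 64)]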
    """

    def decorator(fn: Callable) -> AutoTuner:
        autotuner = AutoTuner(fn, configs=configs)
        autotuner.jit_compile = fn
        autotuner.run = partial(autotuner.run, warmup, rep, timeout)
        return autotuner

    return decorator


def jit(out_idx: Optional[List[int]] = None,
        supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
        ref_prog: Callable = None,
        supply_prog: Callable = None,
        rtol: float = 1e-2,
        atol: float = 1e-2,
        max_mismatched_ratio: float = 0.01,
        skip_check: bool = False,
        cache_input_tensors: bool = True,
        target: Literal['auto', 'cuda', 'hip'] = 'auto') -> Callable:
    """Just-In-Time compilation decorator for tilelang programs.

    Args:
        out_idx: List of output tensor indices.
        supply_type: Type of tensor supply mechanism. Ignored if `supply_prog` is provided.
        ref_prog: Reference program for correctness validation.
        supply_prog: Supply program for input tensors.
        rtol: Relative tolerance for output validation.
        atol: Absolute tolerance for output validation.
        max_mismatched_ratio: Maximum allowed ratio of mismatched elements.
        skip_check: Whether to skip validation checks.
        cache_input_tensors: Whether to cache input tensors for each compilation.
        target: Target platform ('auto', 'cuda', or 'hip').

    Returns:
        Callable: Decorated function that performs JIT compilation.
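
    Example:
        Calling the decorated function compiles the program and returns a `JITContext`
        holding the compiled kernel (the kernel factory below is a placeholder):

            @jit(target="auto")
            def my_kernel(block_M=64):
                ...  # build and return a tilelang program

            ctx = my_kernel(block_M=128)
            kernel = ctx.kernel  # the compiled tilelang.JITKernel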
    """

    # If a custom `supply_prog` is provided, the profiler's `supply_type` setting
    # becomes ineffective. The custom supply program will be used instead.
    if supply_prog is not None and supply_type != tilelang.TensorSupplyType.Auto:
        logger.warning("Ignoring `supply_type` passed to `autotune.jit` because "
                       "`supply_prog` is not None.")

    def wrapper(fn: Callable):

        @wraps(fn)
        def decorator(*args, **kwargs) -> JITContext:

            kernel = tilelang.compile(fn(*args, **kwargs), out_idx=out_idx, target=target)

            return JITContext(
                out_idx=out_idx,
                ref_prog=ref_prog,
                supply_prog=supply_prog,
                rtol=rtol,
                atol=atol,
                max_mismatched_ratio=max_mismatched_ratio,
                skip_check=skip_check,
                cache_input_tensors=cache_input_tensors,
                kernel=kernel,
                supply_type=supply_type,
                target=target)

        return decorator

    return wrapper


def check_tensor_list_compatibility(
    list1: List[torch.Tensor],
    list2: List[torch.Tensor],
) -> bool:
    """Checks if two lists of tensors are compatible.
    
    Compatibility checks performed include:
    1. Lists have the same length.
    2. Corresponding tensors have the same shape.

    Args:
        list1: First list of tensors.
        list2: Second list of tensors.
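
    Example:
        A quick illustration with CPU tensors:

        >>> import torch
        >>> check_tensor_list_compatibility([torch.empty(2, 3)], [torch.empty(2, 3)])
        True
        >>> check_tensor_list_compatibility([torch.empty(2, 3)], [torch.empty(4, 3)])
        False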
    """
    if len(list1) != len(list2):
        return False

    return all(tensor1.shape == tensor2.shape for tensor1, tensor2 in zip(list1, list2))


def get_available_cpu_count():
    """Gets the number of CPU cores available to the current process.
    """
    try:
        cpu_count = len(os.sched_getaffinity(0))
    except AttributeError:
        cpu_count = os.cpu_count()

    return cpu_count