__init__.py 28.2 KB
Newer Older
1
2
3
4
5
"""The auto-tune module for tilelang programs.

This module provides functionality for auto-tuning tilelang programs, including JIT compilation
and performance optimization through configuration search.
"""
6

7
import tilelang
8
from tilelang import tvm as tvm
9
from tvm.tir import PrimFunc, Var
10
from tvm.target import Target
11
import inspect
12
13
from functools import partial
from typing import (Callable, List, Literal, Any, Optional, Union, Dict, overload, Tuple)
14
15
from tqdm import tqdm
import logging
16
import functools
17
import concurrent.futures
18
import torch
19
import os
20
import sys
21
import signal
22
23
24
import json
import hashlib
import threading
25
import traceback
26
27
28
29
30
31
from pathlib import Path

from tilelang.env import TILELANG_CACHE_DIR, is_cache_enabled
from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult
from tilelang.jit.param import _P, _RProg
from tilelang.version import __version__
32
33
34
35
36
37
38


class TimeoutException(Exception):
    """Raised by ``timeout_handler`` when a SIGALRM-based time limit expires."""


def timeout_handler(signum, frame):
    """Signal handler that aborts the interrupted operation.

    Installed for ``SIGALRM`` by ``run_with_timeout``.

    Args:
        signum: The delivered signal number (unused).
        frame: The interrupted stack frame (unused).

    Raises:
        TimeoutException: Always.
    """
    raise TimeoutException("Operation timed out")


def run_with_timeout(func, timeout, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` and abort it after ``timeout`` seconds.

    The limit is enforced with ``SIGALRM``, so this only works on platforms
    that provide that signal and only from the main thread (``signal.signal``
    raises ``ValueError`` anywhere else).

    Args:
        func: Callable to run.
        timeout: Time limit in whole seconds, as accepted by ``signal.alarm``
            (0 schedules no alarm).
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func`` returns.

    Raises:
        TimeoutException: If ``func`` has not returned after ``timeout`` seconds.
        Exception: Anything raised by ``func`` propagates unchanged.
    """
    previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(timeout)
    try:
        return func(*args, **kwargs)
    finally:
        # Cancel the pending alarm first so a late SIGALRM cannot hit the
        # restored handler, then put back whatever handler was installed before.
        signal.alarm(0)
        # signal.signal() returns None when the previous handler was not
        # installed from Python; it cannot be re-installed in that case.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)


# Module logger for the autotuner.
# TODO: Consider creating a common logger in utils
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)  # let everything through; attached handlers filter by level
logger.propagate = False  # do not forward records to ancestor loggers

# Becomes True once _init_logger_handlers() has attached its handlers, which
# happens lazily at the start of AutoTuner.run().
_logger_handlers_initialized = False


def _init_logger_handlers():
    """Attach a file handler and a console handler to ``logger`` (only once).

    - ``autotuner.log`` (relative to the current working directory, opened
      with mode ``'w'`` so it is truncated) receives DEBUG and above.
    - ``sys.stdout`` receives INFO and above.

    Both use the same ``'%(asctime)s %(levelname)s:%(message)s'`` format.
    Calls after the first one do nothing.
    """
    global _logger_handlers_initialized
    if not _logger_handlers_initialized:
        shared_formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s')
        handler_specs = (
            (logging.FileHandler('autotuner.log', mode='w'), logging.DEBUG),
            (logging.StreamHandler(sys.stdout), logging.INFO),
        )
        for handler, min_level in handler_specs:
            handler.setLevel(min_level)
            handler.setFormatter(shared_formatter)
            logger.addHandler(handler)
        _logger_handlers_initialized = True


def get_available_cpu_count() -> int:
    """Return the number of CPU cores usable by the current process (at least 1).

    Prefers the size of the process's CPU affinity set
    (``os.sched_getaffinity``), which reflects restrictions such as
    ``taskset``. Falls back to ``os.cpu_count()`` on platforms without
    ``sched_getaffinity``. ``os.cpu_count()`` may return ``None`` when the
    count is undetermined, so the result is clamped to at least 1 and callers
    can always do arithmetic on it.
    """
    try:
        cpu_count = len(os.sched_getaffinity(0))
    except AttributeError:
        # os.sched_getaffinity is not available on every platform.
        cpu_count = os.cpu_count()
    return max(1, cpu_count or 0)


class AutoTuner:
    """Auto-tuner for tilelang programs.

    Compiles ``fn`` for every configuration in ``configs`` (in parallel on a
    thread pool), benchmarks each kernel that compiled, optionally validates
    it against a reference program, and returns the fastest one as an
    :class:`AutotuneResult`. Results are cached in a class-wide in-memory
    dictionary. When caching is enabled and the execution backend is not
    ``"dlpack"``, they are also cached on disk under ``cache_dir``.

    Args:
        fn: The function to be auto-tuned. By default ``fn(**config)`` must
            return the program to compile (see ``jit_compile``).
        configs: List of configurations to try during auto-tuning. Each is a
            dict whose keys must be parameter names of ``fn``.
    """

    compile_args = CompileArgs()
    profile_args = ProfileArgs()

    # Extra values folded into the cache key; see set_kernel_parameters().
    _kernel_parameters: Optional[Tuple[str, ...]] = None

    # Class-wide state shared by every AutoTuner instance.
    _lock = threading.Lock()  # guards cache lookups and cache writes
    _memory_cache: Dict[str, AutotuneResult] = {}  # cache key -> AutotuneResult
    cache_dir: Path = Path(TILELANG_CACHE_DIR) / "autotuner"

    def __init__(self, fn: Callable, configs):
        self.fn = fn
        self.configs = configs
        # Reference-program latency; measured once, then reused for every config.
        self.ref_latency_cache = None
        # Inputs used by the most recent kernel / reference benchmark.
        self.jit_input_tensors = None
        self.ref_input_tensors = None
        # Optional compile hook ``jit_compile(**config) -> JITKernel``. When it
        # is None, run() installs the default hook (self._compile).
        self.jit_compile = None

    @classmethod
    def from_kernel(cls, kernel: Callable, configs):
        """Create an AutoTuner instance from a kernel function.

        Args:
            kernel: The kernel function to auto-tune.
            configs: List of configurations to try.

        Returns:
            AutoTuner: A new AutoTuner instance.
        """
        return cls(kernel, configs)

    def set_compile_args(self,
                         out_idx: Union[List[int], int, None] = None,
                         target: Literal['auto', 'cuda', 'hip'] = 'auto',
                         execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
                         target_host: Union[str, Target] = None,
                         verbose: bool = False,
                         pass_configs: Optional[Dict[str, Any]] = None):
        """Set compilation arguments for the auto-tuner.

        Args:
            out_idx: List of output tensor indices.
            target: Target platform.
            execution_backend: Execution backend to use for kernel execution.
                ``"dlpack"`` results are never saved to the disk cache.
            target_host: Target host for cross-compilation.
            verbose: Whether to enable verbose output.
            pass_configs: Additional keyword arguments to pass to the Compiler PassContext.

        Returns:
            AutoTuner: Self for method chaining.
        """
        self.compile_args = CompileArgs(
            out_idx=out_idx,
            target=target,
            execution_backend=execution_backend,
            target_host=target_host,
            verbose=verbose,
            pass_configs=pass_configs)

        return self

    def set_profile_args(self,
                         warmup: int = 25,
                         rep: int = 100,
                         timeout: int = 30,
                         supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
                         ref_prog: Callable = None,
                         supply_prog: Callable = None,
                         rtol: float = 1e-2,
                         atol: float = 1e-2,
                         max_mismatched_ratio: float = 0.01,
                         skip_check: bool = False,
                         manual_check_prog: Callable = None,
                         cache_input_tensors: bool = False):
        """Set profiling arguments for the auto-tuner.

        Args:
            warmup: Number of warmup iterations (stored only; ``run`` takes its own).
            rep: Number of repetitions for timing (stored only; ``run`` takes its own).
            timeout: Maximum time per configuration (stored only; ``run`` takes its own).
            supply_type: Type of tensor supply mechanism. Ignored if `supply_prog` is provided.
            ref_prog: Reference program for validation and reference timing.
            supply_prog: Supply program for input tensors. It is called with the
                profiler's parameter list.
            rtol: Relative tolerance for validation.
            atol: Absolute tolerance for validation.
            max_mismatched_ratio: Maximum allowed mismatch ratio.
            skip_check: Whether to skip validation.
            manual_check_prog: Manual check program for validation. When set,
                it is used instead of the rtol/atol comparison.
            cache_input_tensors: Whether to reuse input tensors across configurations.

        Returns:
            AutoTuner: Self for method chaining.
        """
        self.profile_args = ProfileArgs(
            supply_type=supply_type,
            ref_prog=ref_prog,
            supply_prog=supply_prog,
            rtol=rtol,
            atol=atol,
            max_mismatched_ratio=max_mismatched_ratio,
            skip_check=skip_check,
            manual_check_prog=manual_check_prog,
            cache_input_tensors=cache_input_tensors,
            warmup=warmup,
            rep=rep,
            timeout=timeout)

        # If a custom `supply_prog` is provided, the profiler's `supply_type` setting
        # becomes ineffective. The custom supply program will be used instead.
        if supply_prog is not None and supply_type != tilelang.TensorSupplyType.Auto:
            logger.warning("Ignoring `supply_type` passed to `set_profile_args` because "
                           "`supply_prog` is not None.")

        return self

    def set_kernel_parameters(self, parameters: Tuple[str, ...]):
        """Register extra values to fold into the cache key.

        For example, the ``@autotune`` wrapper passes the call's
        ``(args, sorted kwargs)``, so different call arguments get different
        cache entries.
        """
        self._kernel_parameters = parameters

    def generate_cache_key(self, parameters: Dict[str, Any]) -> str:
        """Generate a cache key for the auto-tuning process.

        The key is the SHA-256 hex digest of a JSON document built from:
        - the tilelang version;
        - the default values of ``fn``'s parameters plus any values registered
          with :meth:`set_kernel_parameters`;
        - ``fn``'s source code;
        - the config list;
        - hashes of the compile and profile arguments.

        Args:
            parameters: ``fn``'s signature parameters
                (``inspect.signature(fn).parameters``).

        Returns:
            str: A 64-character hexadecimal digest.
        """
        # Parameter defaults, in signature order, identify the op being tuned.
        op_parameters = [
            param.default for param in parameters.values()
            if param.default is not inspect.Parameter.empty
        ]
        if self._kernel_parameters is not None:
            op_parameters += self._kernel_parameters

        key_data = {
            "version": __version__,
            "op_parameters": tuple(op_parameters),
            "func_source": inspect.getsource(self.fn),
            "configs": self.configs,
            # NOTE(review): built-in hash() of str-containing objects is salted per
            # process (PYTHONHASHSEED) unless CompileArgs/ProfileArgs define a
            # deterministic __hash__. Confirm, or disk-cache keys differ between runs.
            "compile_args": hash(self.compile_args),
            "profile_args": hash(self.profile_args),
        }
        # Sort keys to ensure consistency
        key_string = json.dumps(key_data, sort_keys=True)
        return hashlib.sha256(key_string.encode()).hexdigest()

    def _save_result_to_disk(self, key, result: AutotuneResult):
        """Persist ``result`` under ``cache_dir / key``."""
        result.save_to_disk(self.cache_dir / key)

    def _load_result_from_disk(self, key) -> Optional[AutotuneResult]:
        """Load the result stored under ``cache_dir / key``; None means no entry."""
        result = AutotuneResult.load_from_disk(self.cache_dir / key, self.compile_args)
        return result

    def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30):
        """Run the auto-tuning process.

        The cache is consulted first. On a miss:
        1. Every config is compiled in parallel.
        2. Each compiled kernel is validated and benchmarked one at a time.
        3. The fastest kernel is kept.

        The final result is stored in the memory cache and, when enabled and
        supported by the backend, on disk.

        Args:
            warmup: Number of warmup iterations.
            rep: Number of repetitions for timing.
            timeout: Maximum time per configuration, in whole seconds. It is
                enforced with ``SIGALRM`` via ``run_with_timeout``, so it only
                works from the main thread on platforms that have ``SIGALRM``.

        Note:
            The ``warmup``, ``rep`` and ``timeout`` values stored by
            :meth:`set_profile_args` are not read here; these arguments are
            used instead.

        Returns:
            AutotuneResult: Results of the auto-tuning process.

        Raises:
            ValueError: If a config contains keys that are not parameters of ``fn``.
            RuntimeError: If no configuration compiled and passed benchmarking/validation.
        """
        _init_logger_handlers()

        parameters = inspect.signature(self.fn).parameters
        key = self.generate_cache_key(parameters)

        cached_result = self._lookup_cache(key)
        if cached_result is not None:
            return cached_result

        if self.jit_compile is None:
            self.jit_compile = self._compile

        config_args = self._build_config_args(parameters)
        compiled = self._compile_configs(config_args)

        best_latency = float("inf")
        best_config: Optional[Dict[str, Any]] = None
        best_kernel: Optional[tilelang.JITKernel] = None
        ref_latency = None

        progress_bar = tqdm(compiled, desc="Bench configurations")
        for i, (jit_kernel, config) in enumerate(progress_bar):
            try:
                # The timeout uses SIGALRM instead of a ThreadPoolExecutor,
                # because TMA initialization may behave strangely when the
                # benchmark runs on a worker thread.
                latency, ref_latency = run_with_timeout(
                    self._benchmark, timeout, jit_kernel, warmup, rep)
            except TimeoutException:
                logger.info(
                    f"A timeout occurred while testing config {config}, checkout autotuner.log for more details"
                )
                continue
            except Exception:
                logger.info(
                    f"An error occurred while testing config {config}, checkout autotuner.log for more details"
                )
                logger.debug(f"Error: {traceback.format_exc()}")
                continue

            if latency < best_latency:
                best_latency = latency
                best_config = config
                best_kernel = jit_kernel

            progress_bar.set_postfix({"best_latency": best_latency})
            tqdm.write(f"Tuned Latency {latency} with config {config} at index {i}")

        if best_kernel is None:
            error_msg = ("Auto-tuning failed: No configuration successfully "
                         "compiled and passed benchmarking/validation.")
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        best_kernel = best_kernel.update_tuner_result(
            latency=best_latency,
            config=best_config,
            ref_latency=ref_latency,
        )

        autotuner_result = AutotuneResult(
            latency=best_latency,
            config=best_config,
            ref_latency=ref_latency,
            libcode=best_kernel.get_kernel_source(),
            func=best_kernel.prim_func,
            kernel=best_kernel)

        self._store_result(key, autotuner_result)
        return autotuner_result

    def _lookup_cache(self, key: str) -> Optional[AutotuneResult]:
        """Return the cached result for ``key`` (memory first, then disk), or None.

        Always returns None when caching is disabled.
        """
        with self._lock:
            if not is_cache_enabled():
                return None

            if key in self._memory_cache:
                logger.warning("Found kernel in memory cache. For better performance,"
                               " consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel.")
                return self._memory_cache[key]

            result = self._load_result_from_disk(key)
            if result is not None:
                # Promote the disk hit into the memory cache.
                self._memory_cache[key] = result
            return result

    def _store_result(self, key: str, result: AutotuneResult) -> None:
        """Record ``result`` in the memory cache and, when possible, on disk."""
        use_disk = self.compile_args.execution_backend != "dlpack"
        if not use_disk:
            logger.warning("DLPack backend does not support cache saving to disk.")
        with self._lock:
            if use_disk and is_cache_enabled():
                self._save_result_to_disk(key, result)
            self._memory_cache[key] = result

    def _compile(self, /, **config_arg) -> tilelang.JITKernel:
        """Default compile hook: compile ``fn(**config_arg)`` with ``compile_args``."""
        # `self` is positional-only so a config key named "self" cannot collide.
        return self.compile_args.compile_program(self.fn(**config_arg))

    def _build_config_args(self, parameters) -> List[Dict[str, Any]]:
        """Turn ``self.configs`` into keyword-argument dicts for ``fn``.

        Each resulting dict keeps the config's entries in ``fn``'s parameter order.

        Raises:
            ValueError: If a config has keys that are not parameters of ``fn``.
        """
        config_args = []
        for config in self.configs:
            new_kwargs = {name: config[name] for name in parameters if name in config}
            unused_keys = set(config.keys()) - set(new_kwargs.keys())
            if unused_keys:
                raise ValueError(f"Unused keys in config: {unused_keys}")
            config_args.append(new_kwargs)
        return config_args

    def _compile_configs(self, config_args: List[Dict[str, Any]]) -> List[Tuple[Any, Dict[str, Any]]]:
        """Compile every config in parallel with ``self.jit_compile``.

        Each worker first selects the caller's current CUDA device. Configs
        that fail to compile are logged at DEBUG level and skipped.

        Returns:
            ``(jit_kernel, config)`` pairs in completion order.
        """

        def run_on_device(func, device, /, **config_arg):
            # Positional-only so config keys named "func"/"device" are still
            # forwarded to `func` instead of colliding with these parameters.
            torch.cuda.set_device(device)
            return func(**config_arg)

        num_workers = max(1, int(get_available_cpu_count() * 0.9))
        device = torch.cuda.current_device()
        results_with_configs = []

        # The context manager guarantees the pool is shut down even on error.
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as pool:
            future_to_index = {
                pool.submit(run_on_device, self.jit_compile, device, **config_arg): i
                for i, config_arg in enumerate(config_args)
            }
            for future in tqdm(
                    concurrent.futures.as_completed(future_to_index),
                    total=len(future_to_index),
                    desc="Compiling configurations"):
                idx = future_to_index[future]
                config = config_args[idx]
                try:
                    results_with_configs.append((future.result(), config))
                except Exception as e:
                    logger.debug(
                        f"Compilation failed for config {config} at index {idx} with error: {e}")

        return results_with_configs

    def _benchmark(self, jit_kernel: tilelang.JITKernel, warmup: int, rep: int):
        """Validate (unless skipped) and time one compiled kernel.

        Input tensors come from ``profile_args.supply_prog`` when it is set,
        otherwise from the kernel's profiler. With ``cache_input_tensors``,
        the previous trial's inputs are reused as long as they still match
        this kernel's parameters; otherwise they are regenerated. The
        reference program, when given, is benchmarked only once and its
        latency is reused.

        Returns:
            Tuple ``(latency, ref_latency)``. ``ref_latency`` is None when no
            ``ref_prog`` is configured.
        """
        profile_args = self.profile_args
        ref_prog = profile_args.ref_prog
        supply_prog = profile_args.supply_prog

        profiler = jit_kernel.get_profiler(tensor_supply_type=profile_args.supply_type)

        def supply_inputs():
            # A custom supply program takes precedence over the profiler's
            # default input generation (and therefore over `supply_type`).
            if supply_prog is not None:
                return supply_prog(profiler._get_params(with_output=False))
            return profiler._get_inputs(with_output=False)

        if not profile_args.cache_input_tensors or self.jit_input_tensors is None:
            self.jit_input_tensors = supply_inputs()
        elif not self._inputs_compatible(
                profiler._get_params(with_output=False), self.jit_input_tensors):
            logger.warning(
                "\nIncompatible input tensor properties detected between cached tensors and "
                "tensors regenerated for the current configuration trial. "
                "This can happen if different tuning configurations require different input shapes/dtypes "
                "and input tensor caching is enabled.\n"
                "To ensure fresh, compatible inputs are generated for every trial "
                "you can disable caching by setting:\n"
                "  `cache_input_tensors=False`\n"
                "within your `.set_profile_args(...)` or `@tilelang.autotune(...)` call.\n")
            # Regenerate for safety instead of feeding mismatched inputs.
            self.jit_input_tensors = supply_inputs()

        if not profile_args.skip_check and ref_prog is not None:
            if profile_args.manual_check_prog is not None:
                profiler.manual_assert_close(
                    ref_prog,
                    input_tensors=self.jit_input_tensors,
                    manual_check_prog=profile_args.manual_check_prog)
            else:
                profiler.assert_allclose(
                    ref_prog,
                    input_tensors=self.jit_input_tensors,
                    rtol=profile_args.rtol,
                    atol=profile_args.atol,
                    max_mismatched_ratio=profile_args.max_mismatched_ratio)

        latency = profiler.do_bench(
            warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors)

        if self.ref_latency_cache is None and ref_prog is not None:
            self.ref_input_tensors = supply_inputs()
            # NOTE(review): the kernel timing above passes warmup/rep, whereas this
            # call passes them as n_warmup/n_repeat. Kept as-is; confirm intent.
            self.ref_latency_cache = profiler.do_bench(
                ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors)

        return latency, self.ref_latency_cache

    @staticmethod
    def _inputs_compatible(params, tensors) -> bool:
        """Check cached input tensors against the current kernel's parameters.

        Rules:
        - The two sequences must have the same length.
        - Non-tensor entries are not checked.
        - A tensor must have the same dtype and rank as its parameter.
        - Each dimension must be equal, unless either side is a symbolic
          ``tvm.tir.Var``, which matches any size.
        """
        if len(params) != len(tensors):
            return False

        def shapes_match(a, b):
            if len(a.shape) != len(b.shape):
                return False
            # Test for symbolic dims first so no TVM expression is built for them.
            return all(
                isinstance(a_dim, Var) or isinstance(b_dim, Var) or a_dim == b_dim
                for a_dim, b_dim in zip(a.shape, b.shape))

        for param, tensor in zip(params, tensors):
            if not isinstance(tensor, torch.Tensor):
                # skip non-tensor inputs checking
                continue
            if param.dtype != tensor.dtype or not shapes_match(param, tensor):
                return False
        return True

    def __call__(self) -> Any:
        """Make the AutoTuner callable, running the auto-tuning process.

        Returns:
            AutotuneResult: Results of the auto-tuning process.
        """
        return self.run()


class _AutoTunerImplementation:
    """Configured decorator returned by ``autotune(...)``.

    Holds the tuning and validation options. Applying it to a function
    returns a wrapper that, for each distinct ``(args, kwargs)`` combination,
    runs an :class:`AutoTuner` once and caches the best kernel it finds.
    """

    # Class-level defaults; __init__ sets the same attributes on each instance.
    warmup: int = 25
    rep: int = 100
    timeout: int = 100
    configs: Any = None

    supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto
    ref_prog: Callable = None
    supply_prog: Callable = None
    rtol: float = 1e-2
    atol: float = 1e-2
    max_mismatched_ratio: float = 0.01
    skip_check: bool = False
    manual_check_prog: Callable = None
    cache_input_tensors: bool = False

    def __init__(self,
                 configs: Any,
                 warmup: int = 25,
                 rep: int = 100,
                 timeout: int = 100,
                 supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
                 ref_prog: Callable = None,
                 supply_prog: Callable = None,
                 rtol: float = 1e-2,
                 atol: float = 1e-2,
                 max_mismatched_ratio: float = 0.01,
                 skip_check: bool = False,
                 manual_check_prog: Callable = None,
                 cache_input_tensors: bool = False) -> None:
        """Initialize the AutoTunerImplementation.

        Args:
            configs: Configuration space to explore during auto-tuning.
            warmup: Number of warmup iterations before timing.
            rep: Number of repetitions for timing measurements.
            timeout: Maximum time (in seconds) allowed for each configuration.
            supply_type: Strategy for generating input tensors (random/zeros/etc).
            ref_prog: Reference implementation for validation.
            supply_prog: Custom function to provide input tensors.
            rtol: Relative tolerance for numerical validation.
            atol: Absolute tolerance for numerical validation.
            max_mismatched_ratio: Allowed fraction of mismatched values.
            skip_check: Bypass validation against the reference implementation.
            manual_check_prog: Custom validation function.
            cache_input_tensors: Reuse input tensors across trials.
        """
        # Configuration and benchmarking parameters
        self.configs = configs  # Search space of tuning configurations
        self.warmup = warmup  # Warmup iterations for stable measurements
        self.rep = rep  # Measurement repetitions for statistics
        self.timeout = timeout  # Per-configuration timeout threshold

        # Tensor handling and validation setup
        self.supply_type = supply_type  # Input tensor generation strategy
        self.ref_prog = ref_prog  # Ground truth implementation
        self.supply_prog = supply_prog  # Custom input data provider
        self.rtol = rtol  # Relative error tolerance
        self.atol = atol  # Absolute error tolerance
        self.max_mismatched_ratio = max_mismatched_ratio  # Allowed mismatch

        # Validation control flags
        self.skip_check = skip_check  # Bypass accuracy verification
        self.manual_check_prog = manual_check_prog  # Custom validation
        self.cache_input_tensors = cache_input_tensors  # Reuse inputs

        # Cache of tuned kernels: (args, sorted kwargs) -> best compiled kernel
        self._tuner_cache: Dict[tuple, tilelang.JITKernel] = {}

    # This tells the type checker what the *wrapper* function will return.
    # this is for linting, please do not remove it.
    @overload
    def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, AutotuneResult]]:
        ...

    @overload
    def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, AutotuneResult]:
        ...

    # Actual implementation of __call__
    def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Any]:
        """Wrap ``fn`` so that calling it tunes once per argument set and returns the best kernel.

        ``fn`` must accept the keyword arguments ``__return_compile_arguments``
        (returning a dict of compile options) and ``__tune_params`` (a config
        dict). Presumably ``fn`` is a ``@tilelang.jit``-decorated function —
        TODO confirm.
        """
        warmup, rep, timeout, configs = self.warmup, self.rep, self.timeout, self.configs

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Tune once per distinct call signature. kwargs are sorted so their
            # order does not matter; all arguments must be hashable.
            key = (args, tuple(sorted(kwargs.items())))

            if key not in self._tuner_cache:

                def jit_compile(**config_arg):
                    return fn(*args, **kwargs, __tune_params=config_arg)

                compile_arguments = fn(__return_compile_arguments=True)

                autotuner = AutoTuner(
                    fn, configs=configs).set_profile_args(
                        supply_type=self.supply_type,
                        ref_prog=self.ref_prog,
                        supply_prog=self.supply_prog,
                        rtol=self.rtol,
                        atol=self.atol,
                        max_mismatched_ratio=self.max_mismatched_ratio,
                        skip_check=self.skip_check,
                        manual_check_prog=self.manual_check_prog,
                        cache_input_tensors=self.cache_input_tensors,
                    ).set_compile_args(
                        out_idx=compile_arguments['out_idx'],
                        execution_backend=compile_arguments['execution_backend'],
                        target=compile_arguments['target'],
                        target_host=compile_arguments['target_host'],
                        verbose=compile_arguments['verbose'],
                        pass_configs=compile_arguments['pass_configs'],
                    )

                autotuner.jit_compile = jit_compile
                # Make the call arguments part of the persistent cache key.
                autotuner.set_kernel_parameters(key)

                artifact = autotuner.run(warmup, rep, timeout)
                self._tuner_cache[key] = artifact.kernel

            return self._tuner_cache[key]

        return wrapper


def autotune(  # This is the new public interface
    func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
    *,  # Indicates subsequent arguments are keyword-only
    configs: Any,
    # profile arguments
    warmup: int = 25,
    rep: int = 100,
    timeout: int = 100,
    # input supply and validation arguments
    supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
    ref_prog: Callable = None,
    supply_prog: Callable = None,
    rtol: float = 1e-2,
    atol: float = 1e-2,
    max_mismatched_ratio: float = 0.01,
    skip_check: bool = False,
    manual_check_prog: Callable = None,
    cache_input_tensors: bool = False,
):
    """Auto-tuning decorator for TileLang kernels.

    Must be used with arguments, e.g. ``@tilelang.autotune(configs=...)``.
    The returned decorator is then applied to the function to tune. Each call
    of the decorated function with a new combination of arguments:
    1. compiles it for every config;
    2. validates each kernel against ``ref_prog`` (unless skipped) and
       benchmarks it;
    3. caches the fastest kernel and returns it.

    Parameters
    ----------
    func : Callable or PrimFunc, optional
        Must be left as None. A function (bare ``@autotune`` usage) or a
        ``PrimFunc`` is rejected with ``ValueError``.
    configs : Any
        Configuration space to explore. Each config is a dict whose keys
        must be parameter names of the decorated function.
    warmup : int, optional
        Number of warmup iterations before timing. Defaults to 25.
    rep : int, optional
        Number of repetitions for timing measurements. Defaults to 100.
    timeout : int, optional
        Maximum time in seconds allowed for benchmarking each configuration.
        Defaults to 100.
    supply_type : tilelang.TensorSupplyType, optional
        Strategy for generating input tensors. Ignored if ``supply_prog`` is given.
    ref_prog : Callable, optional
        Reference implementation used for validation and reference timing.
    supply_prog : Callable, optional
        Custom function that provides input tensors.
    rtol, atol : float, optional
        Relative / absolute tolerances for numerical validation. Default 1e-2.
    max_mismatched_ratio : float, optional
        Allowed fraction of mismatched values. Defaults to 0.01.
    skip_check : bool, optional
        Skip validation against ``ref_prog``. Defaults to False.
    manual_check_prog : Callable, optional
        Custom validation function, used instead of the tolerance check.
    cache_input_tensors : bool, optional
        Reuse input tensors across tuning trials. Defaults to False.

    Returns
    -------
    _AutoTunerImplementation
        A configured decorator to apply to the function being tuned.

    Raises
    ------
    ValueError
        If ``func`` is a ``PrimFunc`` or a callable (bare ``@autotune``).
    """
    # Check PrimFunc first so its dedicated message is used even if PrimFunc
    # objects happen to be callable.
    if isinstance(func, PrimFunc):
        raise ValueError("Use tilelang.autotune to decorate prim_func is not supported yet.")
    if callable(func):
        # Bare `@autotune` (without arguments) passes the function in here.
        raise ValueError(
            "Use tilelang.autotune to decorate func without arguments is not supported yet.")

    # Used as `@autotune(...)`: return a decorator configured with these
    # options. It is applied to the function afterwards.
    return _AutoTunerImplementation(
        configs=configs,
        warmup=warmup,
        rep=rep,
        timeout=timeout,
        supply_type=supply_type,
        ref_prog=ref_prog,
        supply_prog=supply_prog,
        rtol=rtol,
        atol=atol,
        max_mismatched_ratio=max_mismatched_ratio,
        skip_check=skip_check,
        manual_check_prog=manual_check_prog,
        cache_input_tensors=cache_input_tensors,
    )