"""Profiler utilities: benchmarking kernels and converting results for torch."""

from contextlib import suppress
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, List, Literal, Optional

import torch
import tvm

from tilelang.engine.param import KernelParam
from tilelang.jit.adapter import BaseKernelAdapter
from tilelang.profiler.bench import do_bench
from tilelang.utils.tensor import (
    TensorSupplyType,
    adapt_torch2tvm,
    get_tensor_supply,
    torch_assert_close,
)

@dataclass
class Profiler:
    """A profiler class for benchmarking and validating kernel implementations.

    Attributes:
        params: List of kernel parameters defining the input/output specifications
        result_idx: Indices indicating which parameters are output tensors
        supply_type: Type of tensor supply to use (e.g., random, zeros, etc.)
        adapter: Optional kernel adapter for interfacing with different backends
    """

    params: List[KernelParam]
    result_idx: List[int]
    supply_type: TensorSupplyType
    adapter: Optional[BaseKernelAdapter] = None

    def __post_init__(self):
        """Initialize tensor supply after dataclass initialization."""
        self.result_idx = self._legalize_result_idx(self.result_idx)
        self.supply = get_tensor_supply(self.supply_type)

    def _legalize_result_idx(self, result_idx: Optional[List[int]] = None) -> List[int]:
        """Normalize ``result_idx`` into a list of non-negative output indices.

        Accepts ``None`` (no outputs), a single int (negative values are
        interpreted Python-style, counting from the end), or a list of ints.

        Raises:
            ValueError: If an int index is out of range, or the value is
                neither ``None``, an int, nor a list.
        """
        params = self.params
        # result_idx is a list of indices of the output tensors
        if result_idx is None:
            result_idx = []
        elif isinstance(result_idx, int):
            # Valid range is [-len(params), len(params) - 1]. Use ">=" so that
            # result_idx == len(params) is rejected (the previous ">" check
            # silently accepted this out-of-range index).
            if result_idx >= len(params) or result_idx < -len(params):
                raise ValueError(
                    f"result_idx should be an integer between {-len(params)} and {len(params) - 1}")
            if result_idx < 0:
                result_idx = len(params) + result_idx
            result_idx = [result_idx]
        elif not isinstance(result_idx, list):
            raise ValueError("result_idx should be a list of integers")

        return result_idx

    def with_default_adapter(self, adapter: BaseKernelAdapter) -> "Profiler":
        """Attach a default kernel adapter and return ``self`` (fluent style)."""
        self.adapter = adapter
        return self

    def _get_inputs(self, with_output=False):
        """Generate tensors for all parameters via the configured supply.

        Args:
            with_output: When True, also generate tensors for output slots
                (those listed in ``result_idx``); otherwise inputs only.
        """
        ins = []
        for i in range(len(self.params)):
            if with_output or i not in self.result_idx:
                ins.append(self.supply(self.params[i]))
        return ins

    def assert_allclose(
        self,
        reference_program: Callable,
        atol: float = 1e-2,
        rtol: float = 1e-2,
        max_mismatched_ratio=0.01,
    ):
        """Validates kernel output against a reference implementation.

        Args:
            reference_program: Reference implementation to compare against
            atol: Absolute tolerance for comparison
            rtol: Relative tolerance for comparison
            max_mismatched_ratio: Maximum allowed ratio of mismatched elements
        """
        ins = self._get_inputs()
        ref_outs = reference_program(*ins)
        torch.cuda.synchronize()
        lib_outs = self.func(*ins)
        torch.cuda.synchronize()

        # Normalize single-tensor results to one-element lists so both sides
        # zip pairwise regardless of how many outputs the kernel produces.
        if isinstance(lib_outs, torch.Tensor):
            lib_outs = [lib_outs]
        if isinstance(ref_outs, torch.Tensor):
            ref_outs = [ref_outs]
        assert len(lib_outs) == len(ref_outs)
        for lhs, rhs in zip(lib_outs, ref_outs):
            torch_assert_close(
                lhs,
                rhs,
                rtol=rtol,
                atol=atol,
                max_mismatched_ratio=max_mismatched_ratio,
            )

    def assert_consistent(self, repeat=10):
        """Checks for kernel consistency across multiple runs.

        Args:
            repeat: Number of times to repeat the consistency check
        """
        # Used to check no race condition inside the kernel
        ins = self._get_inputs()
        ref_outs = self.func(*ins)
        # Normalize single-tensor results to lists (mirrors assert_allclose);
        # otherwise zip() iterates over tensor rows and raises on 0-d tensors.
        if isinstance(ref_outs, torch.Tensor):
            ref_outs = [ref_outs]

        for _ in range(repeat):
            lib_outs = self.func(*ins)
            if isinstance(lib_outs, torch.Tensor):
                lib_outs = [lib_outs]
            for lhs, rhs in zip(lib_outs, ref_outs):
                assert torch.allclose(lhs, rhs), [
                    "result is not consistent",
                    lhs,
                    rhs,
                ]

    def run_once(self, func: Optional[Callable] = None):
        """Run the kernel (or *func*) once on freshly supplied inputs."""
        ins = self._get_inputs()
        if not func:
            func = self.__call__
        return func(*ins)

    def determine_profiler(self,
                           func: Optional[Callable] = None,
                           profiler: Literal["torch", "tvm", "auto"] = "auto"):
        """Determines which profiler backend to use based on function type.

        Args:
            func: Function to be profiled
            profiler: Explicitly specified profiler type or "auto" for automatic detection

        Returns:
            str: The determined profiler type ("torch" or "tvm")
        """
        if profiler == "auto":
            return "tvm" if isinstance(func, tvm.runtime.Module) else "torch"
        return profiler

    def do_bench(
        self,
        func: Optional[Callable] = None,
        warmup: int = 25,
        rep: int = 100,
        n_warmup: int = 1,
        n_repeat: int = 1,
        profiler: Literal["torch", "tvm", "auto"] = "auto",
        input_tensors: Optional[List[torch.Tensor]] = None,
    ) -> float:
        """Benchmarks the execution time of a given function.

        Args:
            func: Function to benchmark (uses adapter if None)
            warmup: Warmup time in milliseconds
            rep: Number of repetitions for timing
            n_warmup: Number of warmup iterations
            n_repeat: Number of timing iterations
            profiler: Which profiling backend to use
            input_tensors: Optional pre-generated input tensors

        Returns:
            float: Average execution time in milliseconds

        Raises:
            ValueError: If *profiler* resolves to an unknown backend name.
        """
        profiler = self.determine_profiler(func, profiler)
        if profiler == "torch":
            if func is None:
                assert self.adapter is not None, "benchmarking function should be provided"
                func = self.adapter
            ins = self._get_inputs() if input_tensors is None else input_tensors
            bench_func = partial(func, *ins)
            return do_bench(
                bench_func,
                warmup=warmup,
                rep=rep,
                _n_warmup=n_warmup,
                _n_repeat=n_repeat,
            )
        elif profiler == "tvm":
            assert func is not None, "func should not be None"
            assert isinstance(
                func, tvm.runtime.Module), f"func should be a TVM module, but got {type(func)}"

            # TVM runs the kernel directly, so output tensors must be supplied too.
            ins = (self._get_inputs(with_output=True) if input_tensors is None else input_tensors)
            target = "cuda"

            # Best-effort target detection; fall back to "cuda" on any failure.
            with suppress(Exception):
                target = self.mod.imported_modules[0].type_key

            assert target in ["cuda", "hip"], f"Unknown target: {target}"

            device = tvm.cuda(0) if target == "cuda" else tvm.rocm(0)
            # NOTE(review): this branch validates `func` but then times `self.mod`,
            # an attribute not set by this dataclass — presumably assigned by the
            # caller before benchmarking. Confirm whether `func` was intended here.
            time_evaluator = self.mod.time_evaluator(
                self.mod.entry_name, device, number=rep, repeat=n_repeat)
            tvm_inputs = [adapt_torch2tvm(inp) for inp in ins]
            # Transform latency to ms
            return time_evaluator(*tvm_inputs).mean * 1e3
        else:
            raise ValueError(f"Unknown profiler: {profiler}")

    @property
    def func(self):
        """The callable used to execute the kernel (requires an adapter)."""
        assert self.adapter is not None, "adapter should be provided"
        return self.adapter

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        """Invoke the underlying kernel adapter directly."""
        return self.func(*args, **kwds)