"""Kernel cache utilities: the KernelCache class with in-memory and on-disk persistence."""

import json
4
5
import logging
import os
6
import shutil
7
import threading
8
import uuid
9
from hashlib import sha256
10
11
12
from typing import Callable, List, Literal, Optional, Union

import cloudpickle
13
14
15
from tvm.target import Target
from tvm.tir import PrimFunc

16
from tilelang.engine.param import KernelParam
17
from tilelang.env import TILELANG_CACHE_DIR, TILELANG_TMP_DIR, is_cache_enabled
18
from tilelang.jit import JITKernel
19
from tilelang.version import __version__
20
21

# File names used inside each per-kernel cache directory.

# Source files
KERNEL_PATH = "kernel.cu"
WRAPPED_KERNEL_PATH = "wrapped_kernel.cu"

# Compiled artifacts: a shared library by default, or a CUBIN plus a
# companion Python file when the "nvrtc" execution backend is used.
KERNEL_LIB_PATH = "kernel_lib.so"
KERNEL_CUBIN_PATH = "kernel.cubin"
KERNEL_PY_PATH = "kernel.py"

# Pickled kernel parameters
PARAMS_PATH = "params.pkl"
27
28
29
30
31


class KernelCache:
    """
    Singleton cache for compiled kernels, backed by memory and by disk.

    A kernel is identified by a SHA256 key derived from the function script,
    the output indices, the call arguments, the targets, the execution backend
    and the pass configuration (see ``_generate_key``). Lookups go to the
    in-memory dictionary first, then to ``TILELANG_CACHE_DIR/<key>/``.

    Cache files (per kernel directory):
        kernel.cu: The compiled kernel source code
        wrapped_kernel.cu: The compiled wrapped kernel source code
        kernel_lib.so: The compiled kernel library (all backends but "nvrtc")
        kernel.cubin: The compiled kernel binary ("nvrtc" backend only)
        kernel.py: The companion Python file ("nvrtc" backend only)
        params.pkl: The compiled kernel parameters
    """

    _instance = None  # The singleton instance, created lazily in __new__
    _lock = threading.Lock()  # Guards singleton creation and cache access (not reentrant)
    _memory_cache = {}  # In-memory cache dictionary: key hash -> JITKernel
    # Backend of the most recent _generate_key() call. Kept for backward
    # compatibility; the save/load paths prefer the backend passed per call.
    execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython"

    def __new__(cls):
        """
        Implements singleton pattern for KernelCache class.

        The first call creates the cache directories, configures the logger
        and initializes the memory cache; later calls return the same object.

        Returns:
            KernelCache: The singleton instance of KernelCache.
        """
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:  # Double-checked locking
                    instance = super().__new__(cls)
                    KernelCache._create_dirs()
                    instance.logger = logging.getLogger(__name__)
                    instance.logger.setLevel(logging.DEBUG)
                    instance._memory_cache = {}  # Initialize memory cache
                    cls._instance = instance
        return cls._instance

    @staticmethod
    def _create_dirs():
        """Creates the cache and temporary directories if they do not exist yet."""
        os.makedirs(TILELANG_CACHE_DIR, exist_ok=True)
        os.makedirs(TILELANG_TMP_DIR, exist_ok=True)

    def _generate_key(
        self,
        func: Callable,
        out_idx: List[int],
        execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
        args=None,
        target: Union[str, Target] = "auto",
        target_host: Union[str, Target] = None,
        pass_configs: dict = None,
        compile_flags: Optional[List[str]] = None,
    ) -> str:
        """
        Generates a unique hash key for caching compiled kernels.

        Args:
            func (Callable): The function to be compiled; must provide ``script()``.
            out_idx (List[int]): Indices specifying which outputs to return.
            execution_backend (Literal): Backend type for execution. Defaults to "cython".
            args: Arguments passed to the function. ``None`` is treated as no arguments.
            target (Union[str, Target]): Compilation target platform. Defaults to "auto".
            target_host (Union[str, Target], optional): Host target platform.
            pass_configs (dict, optional): Configuration for compiler passes.
            compile_flags (List[str], optional): Extra compiler flags. They are
                added to the key only when given, so keys produced without
                flags are unchanged.

        Returns:
            str: SHA256 hash key for the kernel configuration.
        """
        # Remembered on the instance for backward compatibility only.
        self.execution_backend = execution_backend
        func_binary = cloudpickle.dumps(func.script(show_meta=True))
        key_data = {
            "version": __version__,
            "func": sha256(func_binary).hexdigest(),  # Use SHA256 to generate hash key
            "out_idx": (tuple(out_idx) if isinstance(out_idx, (list, tuple)) else [out_idx]),
            # Use repr to serialize arguments, may need more robust serialization.
            # `args or ()` keeps the default args=None from raising TypeError.
            "args_repr": tuple(repr(arg) for arg in (args or ())),
            "target": str(target),
            "target_host": str(target_host) if target_host else None,
            "execution_backend": execution_backend,
            "pass_configs": pass_configs,
        }
        if compile_flags:
            # Different flags yield different binaries, so they must not share an entry.
            key_data["compile_flags"] = list(compile_flags)

        # Sort keys to ensure consistency
        key_string = json.dumps(key_data, sort_keys=True)
        # Use SHA256 to generate hash key
        return sha256(key_string.encode()).hexdigest()

    def cached(
        self,
        func: PrimFunc = None,
        out_idx: List[int] = None,
        *args,
        target: Union[str, Target] = "auto",
        target_host: Union[str, Target] = None,
        execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
        verbose: bool = False,
        pass_configs: dict = None,
        compile_flags: Optional[List[str]] = None,
    ) -> JITKernel:
        """
        Caches and reuses compiled kernels to avoid redundant compilation.

        Lookup order: in-memory cache, then disk cache, then a fresh
        compilation whose result is stored in both caches (the "dlpack"
        backend is only stored in memory). When caching is disabled the
        kernel is compiled directly and nothing is stored.

        Args:
            func: Function to be compiled or a prepared PrimFunc
            out_idx: Indices specifying which outputs to return
            *args: Arguments passed to func; their ``repr`` is part of the key
            target: Compilation target platform
            target_host: Host target platform
            execution_backend: Backend type for execution
            verbose: Whether to log disk cache lookups and forward verbosity
                to the compiler
            pass_configs: Configuration for compiler passes
            compile_flags: Extra compiler flags forwarded to the compiler

        Returns:
            JITKernel: The compiled kernel, either freshly compiled or from cache
        """
        if not is_cache_enabled():
            return JITKernel(
                func,
                out_idx=out_idx,
                execution_backend=execution_backend,
                target=target,
                target_host=target_host,
                verbose=verbose,
                pass_configs=pass_configs,
                compile_flags=compile_flags,
            )

        key = self._generate_key(
            func=func,
            out_idx=out_idx,
            execution_backend=execution_backend,
            args=args,
            target=target,
            target_host=target_host,
            pass_configs=pass_configs,
            compile_flags=compile_flags,
        )

        with self._lock:
            # First check in-memory cache
            if key in self._memory_cache:
                self.logger.warning(
                    "Found kernel in memory cache. For better performance,"
                    " consider using `@tilelang.jit` instead of direct kernel caching.")
                return self._memory_cache[key]

            if verbose:
                self.logger.debug(f"Checking disk cache for kernel {func.attrs['global_symbol']}")

            # Then check disk cache
            kernel = self._load_kernel_from_disk(key, target, target_host, out_idx,
                                                 execution_backend, pass_configs, func)
            if kernel is not None:
                if verbose:
                    self.logger.debug(
                        f"Found kernel in disk cache for {func.attrs['global_symbol']}")
                # Populate memory cache with disk result
                self._memory_cache[key] = kernel
                return kernel

        # Compile kernel if cache miss; done outside the critical section so
        # that other threads are not blocked during compilation.
        kernel = JITKernel(
            func,
            out_idx=out_idx,
            execution_backend=execution_backend,
            target=target,
            target_host=target_host,
            verbose=verbose,
            pass_configs=pass_configs,
            compile_flags=compile_flags,
        )
        if execution_backend == "dlpack":
            self.logger.warning("DLPack backend does not support cache saving to disk.")

        with self._lock:
            if execution_backend != "dlpack" and is_cache_enabled():
                self._save_kernel_to_disk(key, kernel, func, execution_backend)
            # Store in memory cache after compilation (under the same lock as reads)
            self._memory_cache[key] = kernel
        return kernel

    def clear_cache(self):
        """
        Clears the entire kernel cache, including both in-memory and disk cache.
        """
        with self._lock:
            self._memory_cache.clear()  # Clear in-memory cache
            self._clear_disk_cache()  # Clear disk cache

    def _get_cache_path(self, key: str) -> str:
        """
        Gets the filesystem path for a cached kernel.

        Args:
            key (str): The hash key identifying the kernel.

        Returns:
            str: Path of the cache directory for this kernel, i.e.
                ``TILELANG_CACHE_DIR`` joined with ``key``.
        """
        return os.path.join(TILELANG_CACHE_DIR, key)

    @staticmethod
    def _load_binary(path: str):
        """Reads and returns the full content of ``path`` as bytes."""
        with open(path, "rb") as file:
            binary = file.read()
        return binary

    @staticmethod
    def _safe_write_file(path: str, mode: str, operation: Callable):
        """
        Writes a file through a temporary file and an atomic rename.

        Args:
            path (str): Final destination path.
            mode (str): Mode used to open the temporary file (e.g. "w", "wb").
            operation (Callable): Called with the open temporary file object;
                it performs the actual write.

        Raises:
            Exception: Whatever ``open``, ``operation`` or ``os.replace`` raises.
                The temporary file is removed before the exception propagates.
        """
        # Random temporary file name; os.replace is only atomic when the
        # temporary directory is on the same filesystem as the destination
        # (assumed for TILELANG_TMP_DIR - confirm in tilelang.env).
        temp_path = os.path.join(TILELANG_TMP_DIR, f"{os.getpid()}_{uuid.uuid4()}")
        try:
            with open(temp_path, mode) as temp_file:
                operation(temp_file)

            # Use atomic POSIX replace, so other processes cannot see a partial write
            os.replace(temp_path, path)
        except BaseException:
            # Best-effort cleanup so failed writes do not accumulate temp files
            try:
                os.remove(temp_path)
            except OSError:
                pass
            raise

    def _save_kernel_to_disk(
        self,
        key: str,
        kernel: JITKernel,
        func: Callable = None,
        execution_backend: Optional[Literal["dlpack", "ctypes", "cython", "nvrtc"]] = None,
    ):
        """
        Persists a compiled kernel to disk cache.

        Each file is written independently; a failure is logged and does not
        prevent the remaining files from being written.

        Args:
            key (str): The hash key identifying the kernel.
            kernel (JITKernel): The compiled kernel to be saved.
            func (Callable, optional): The original function (currently unused).
            execution_backend (Literal, optional): Backend the kernel was
                compiled for. Defaults to ``self.execution_backend`` when omitted.

        Note:
            Saves the following files:
            - kernel.cu: The compiled kernel source code
            - wrapped_kernel.cu: The wrapped kernel source code
            - kernel_lib.so: The compiled kernel library
              (kernel.cubin and kernel.py instead for the "nvrtc" backend)
            - params.pkl: The serialized kernel parameters
        """
        if execution_backend is None:
            execution_backend = self.execution_backend
        is_nvrtc = execution_backend == "nvrtc"

        cache_path = self._get_cache_path(key)
        os.makedirs(cache_path, exist_ok=True)  # Ensure directory exists

        # Save kernel source code
        try:
            kernel_path = os.path.join(cache_path, KERNEL_PATH)
            if kernel.artifact.kernel_source is not None:
                KernelCache._safe_write_file(kernel_path, "w",
                                             lambda file: file.write(kernel.artifact.kernel_source))
        except Exception as e:
            self.logger.error(f"Error saving kernel source code to disk: {e}")

        # Save wrapped kernel source code
        try:
            wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
            KernelCache._safe_write_file(
                wrapped_kernel_path, "w",
                lambda file: file.write(kernel.adapter.get_kernel_source()))
        except Exception as e:
            self.logger.error(f"Error saving wrapped kernel source code to disk: {e}")

        # Save the kernel library
        try:
            # Save CUBIN or SO file
            lib_name = KERNEL_CUBIN_PATH if is_nvrtc else KERNEL_LIB_PATH
            kernel_lib_path = os.path.join(cache_path, lib_name)
            src_lib_path = kernel.adapter.libpath
            KernelCache._safe_write_file(
                kernel_lib_path, "wb",
                lambda file: file.write(KernelCache._load_binary(src_lib_path)))

            # Save an extra Python file for NVRTC
            if is_nvrtc:
                kernel_py_path = os.path.join(cache_path, KERNEL_PY_PATH)
                # The Python file sits next to the cubin with a ".py" suffix
                src_py_path = src_lib_path.replace(".cubin", ".py")
                KernelCache._safe_write_file(
                    kernel_py_path, "wb",
                    lambda file: file.write(KernelCache._load_binary(src_py_path)))
        except Exception as e:
            self.logger.error(f"Error saving kernel library to disk: {e}")

        # Save kernel parameters
        try:
            params_path = os.path.join(cache_path, PARAMS_PATH)
            KernelCache._safe_write_file(params_path, "wb",
                                         lambda file: cloudpickle.dump(kernel.params, file))
        except Exception as e:
            self.logger.error(f"Error saving kernel parameters to disk: {e}")

    def _load_kernel_from_disk(
        self,
        key: str,
        target: Union[str, Target] = "auto",
        target_host: Union[str, Target] = None,
        out_idx: List[int] = None,
        execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
        pass_configs: dict = None,
        func: Callable = None,
    ) -> Optional[JITKernel]:
        """
        Loads a previously compiled kernel from disk cache.

        Args:
            key (str): The hash key identifying the kernel.
            target (Union[str, Target]): Compilation target platform. Defaults to "auto".
            target_host (Union[str, Target], optional): Host target platform.
            out_idx (List[int], optional): Indices specifying which outputs to return.
            execution_backend (Literal): Backend type for execution. Defaults to "cython".
                Also selects which library file (kernel.cubin or kernel_lib.so) is expected.
            pass_configs (dict, optional): Configuration for compiler passes.
            func (Callable, optional): The original function.

        Returns:
            JITKernel: The loaded kernel if found, None otherwise (missing
            files, unreadable files, or empty source/parameters).
        """
        cache_path = self._get_cache_path(key)
        wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
        lib_name = KERNEL_CUBIN_PATH if execution_backend == "nvrtc" else KERNEL_LIB_PATH
        kernel_lib_path = os.path.join(cache_path, lib_name)
        params_path = os.path.join(cache_path, PARAMS_PATH)

        # All three files are needed to rebuild the kernel; a missing one is a
        # plain cache miss rather than an error.
        required_files = (wrapped_kernel_path, kernel_lib_path, params_path)
        if not all(os.path.exists(file) for file in required_files):
            return None

        kernel_global_source: Optional[str] = None
        kernel_params: Optional[List[KernelParam]] = None

        # Load the wrapped kernel source file
        try:
            with open(wrapped_kernel_path, "r") as f:
                kernel_global_source = f.read()
        except Exception as e:
            self.logger.error(f"Error loading wrapped kernel source code from disk: {e}")

        # Load kernel parameters
        # NOTE(security): unpickling can execute arbitrary code, so the cache
        # directory must only contain files written by a trusted process.
        try:
            with open(params_path, "rb") as f:
                kernel_params = cloudpickle.load(f)
        except Exception as e:
            self.logger.error(f"Error loading kernel parameters from disk: {e}")

        if kernel_global_source and kernel_params:
            return JITKernel.from_database(
                func=func,
                kernel_global_source=kernel_global_source,
                kernel_lib_path=kernel_lib_path,
                params=kernel_params,
                target=target,
                target_host=target_host,
                out_idx=out_idx,
                execution_backend=execution_backend,
                pass_configs=pass_configs,
            )
        return None

    def _clear_disk_cache(self):
        """
        Removes all cached kernels from disk.

        Note:
            This operation will delete the entire cache directory and recreate it empty.
            Use with caution as this operation cannot be undone.
            Errors are logged, not raised.
        """
        try:
            # Delete the entire cache directory (it may already be gone)
            if os.path.isdir(TILELANG_CACHE_DIR):
                shutil.rmtree(TILELANG_CACHE_DIR)

            # Re-create the cache directory
            KernelCache._create_dirs()
        except Exception as e:
            self.logger.error(f"Error clearing disk cache: {e}")