# SPDX-License-Identifier: Apache-2.0

import types
from importlib.util import find_spec

from vllm.logger import init_logger

# Module-level logger obtained from vLLM's ``init_logger`` helper; used by the
# Triton-availability messages emitted in this module.
logger = init_logger(__name__)

# True when ``importlib.util.find_spec`` can locate a Triton module. For these
# top-level names ``find_spec`` only locates the module; it does not import it.
HAS_TRITON = (
    find_spec("triton") is not None
    # NOTE(review): previously annotated only as "Not compatible". A name
    # containing hyphens is not a valid module name, so this lookup most
    # likely never succeeds -- confirm whether it is still needed.
    or find_spec("pytorch-triton-xpu") is not None
)

if not HAS_TRITON:
    # Informational only: nothing is raised here when Triton is missing.
    logger.info("Triton not installed or not compatible; certain GPU-related"
                " functions will not be available.")


class TritonPlaceholder(types.ModuleType):
    """Stand-in for the ``triton`` module when Triton is unavailable.

    Exposes ``jit``, ``autotune`` and ``heuristics`` as pass-through
    decorators and ``language`` as a ``TritonLanguagePlaceholder``. Functions
    decorated with these names are therefore returned unchanged. Constructing
    an instance logs a warning through ``logger.warning_once``.
    """

    def __init__(self):
        super().__init__("triton")
        # Same attributes, same order as the real entry points they mimic.
        for attr_name in ("jit", "autotune", "heuristics"):
            setattr(self, attr_name, self._dummy_decorator(attr_name))
        self.language = TritonLanguagePlaceholder()
        logger.warning_once(
            "Triton is not installed. Using dummy decorators. "
            "Install it via `pip install triton` to enable kernel"
            " compilation.")

    def _dummy_decorator(self, name):
        """Build a decorator that leaves the decorated function untouched.

        Supports both bare use (``@placeholder.jit``) and called use
        (``@placeholder.autotune(configs=...)``). If the first positional
        argument is callable it is returned as-is. Otherwise an identity
        decorator is returned. ``name`` is accepted but currently unused.
        """

        def passthrough(*args, **kwargs):
            used_bare = bool(args) and callable(args[0])
            return args[0] if used_bare else (lambda fn: fn)

        return passthrough


class TritonLanguagePlaceholder(types.ModuleType):
    """Stand-in for the ``triton.language`` module.

    Provides ``constexpr``, ``dtype`` and ``int64`` attributes, all set to
    ``None``. Presumably this lets annotations such as ``tl.constexpr``
    resolve without Triton installed; confirm against callers.
    """

    def __init__(self):
        super().__init__("triton.language")
        for attr_name in ("constexpr", "dtype", "int64"):
            setattr(self, attr_name, None)