# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import types
from importlib.util import find_spec

from vllm.logger import init_logger

logger = init_logger(__name__)

# Probe for Triton with ``find_spec`` rather than importing it. For a
# top-level name, ``find_spec`` only locates the module; it does not execute
# it, so this check has no import side effects.
#
# NOTE: only the importable module name ``triton`` is probed. A former
# ``find_spec("pytorch-triton-xpu")`` probe was removed:
#   - that string is a pip *distribution* name containing hyphens, so it can
#     never be the name of an importable module;
#   - it could therefore only "succeed" by matching a stray file or directory
#     of that exact name on ``sys.path``, i.e. a false positive.
# (Presumably the XPU build exposes the ``triton`` module itself and is
# covered by the probe below — confirm if XPU support matters here.)
HAS_TRITON = find_spec("triton") is not None

if not HAS_TRITON:
    logger.info("Triton not installed or not compatible; certain GPU-related"
                " functions will not be available.")

class TritonPlaceholder(types.ModuleType):
    """Module object standing in for ``triton`` when Triton is unavailable.

    Provides no-op ``jit``, ``autotune`` and ``heuristics`` decorators and a
    ``language`` attribute holding a ``TritonLanguagePlaceholder``.
    Presumably this lets code that decorates kernels at import time still be
    imported without Triton — confirm against callers.

    Constructing an instance logs a warning via ``logger.warning_once``.
    """

    def __init__(self):
        super().__init__("triton")
        # Every supported decorator is replaced by the same no-op factory.
        for decorator_name in ("jit", "autotune", "heuristics"):
            setattr(self, decorator_name,
                    self._dummy_decorator(decorator_name))
        self.language = TritonLanguagePlaceholder()
        logger.warning_once(
            "Triton is not installed. Using dummy decorators. "
            "Install it via `pip install triton` to enable kernel"
            " compilation.")

    def _dummy_decorator(self, name):
        """Return a no-op stand-in for the Triton decorator ``name``.

        The stand-in handles both ways a decorator can be applied:

        * Bare (``@triton.jit``): called with the decorated function as its
          first positional argument. That function is returned unchanged.
        * Parameterised (``@triton.autotune(configs=..., key=...)``): called
          with anything else, or with nothing. An identity decorator is
          returned.

        ``name`` is currently unused.
        """

        def decorator(*args, **kwargs):
            # A callable first positional argument signals bare usage.
            is_bare = len(args) > 0 and callable(args[0])
            if is_bare:
                return args[0]
            return lambda fn: fn

        return decorator


class TritonLanguagePlaceholder(types.ModuleType):
    """Module object standing in for ``triton.language`` without Triton.

    Only ``constexpr``, ``dtype`` and ``int64`` are defined, and each is
    ``None``. Presumably this lets annotations such as ``x: tl.constexpr``
    still evaluate — confirm against callers. Any other attribute access
    raises ``AttributeError``.
    """

    def __init__(self):
        super().__init__("triton.language")
        for attr_name in ("constexpr", "dtype", "int64"):
            setattr(self, attr_name, None)