importing.py 2.56 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import types
5
6
7
8
9
10
from importlib.util import find_spec

from vllm.logger import init_logger

logger = init_logger(__name__)

11
12
# Import-time flag: True when `find_spec` can locate a module under either
# name below. `find_spec` returns None when nothing is found, so this only
# shows that a package is present, not that it is usable.
HAS_TRITON = (
    find_spec("triton") is not None
    or find_spec("pytorch-triton-xpu") is not None  # Not compatible
)
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
if HAS_TRITON:
    # Locating the package is not sufficient: the flag stays set only when
    # exactly one backend driver reports itself active. Any other outcome,
    # including a failure of the probe itself, clears it.
    try:
        from triton.backends import backends

        # Each backend is expected to carry a `driver` exposing
        # `is_active()`; the truthiness test on `driver` skips backends
        # whose driver is missing or falsy.
        active_drivers = [
            backend.driver
            for backend in backends.values()
            if backend.driver and backend.driver.is_active()
        ]
        if len(active_drivers) != 1:
            logger.info(
                "Triton is installed but %d active driver(s) found "
                "(expected 1). Disabling Triton to prevent runtime errors.",
                len(active_drivers),
            )
            HAS_TRITON = False
    except ImportError:
        # `triton.backends` could not be imported, e.g. because the Triton
        # installation is incomplete.
        logger.warning(
            "Triton is installed, but `triton.backends` could not be imported. "
            "Disabling Triton.")
        HAS_TRITON = False
    except Exception as exc:
        # Best-effort probe: any other failure is logged and treated the
        # same as "Triton unavailable" rather than propagated.
        logger.warning(
            "An unexpected error occurred while checking Triton active drivers:"
            " %s. Disabling Triton.", exc)
        HAS_TRITON = False
45
46

if not HAS_TRITON:
    # Logged once at import time whenever the flag is false at this point,
    # whether Triton was never found or was disabled earlier in this module.
    logger.info("Triton not installed or not compatible; certain GPU-related"
                " functions will not be available.")
49

50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80

class TritonPlaceholder(types.ModuleType):
    """Module object named ``triton`` whose decorators do nothing.

    Instances expose ``jit``, ``autotune`` and ``heuristics`` as pass-through
    decorators and ``language`` as a ``TritonLanguagePlaceholder``.
    Constructing an instance also emits a one-time warning through the
    module logger.
    """

    def __init__(self):
        super().__init__("triton")
        # Assigned in this order: jit, autotune, heuristics, then language.
        for attr in ("jit", "autotune", "heuristics"):
            setattr(self, attr, self._dummy_decorator(attr))
        self.language = TritonLanguagePlaceholder()
        logger.warning_once(
            "Triton is not installed. Using dummy decorators. "
            "Install it via `pip install triton` to enable kernel"
            " compilation.")

    def _dummy_decorator(self, name):
        """Return a decorator that leaves the decorated function untouched.

        Args:
            name: Name of the Triton decorator being stood in for. It is
                accepted for the caller's benefit and is not used.

        Returns:
            A callable usable both bare (``@dec``) and with arguments
            (``@dec(...)``); in both cases the decorated function is
            returned unchanged.
        """

        def decorator(*args, **kwargs):
            # Bare use passes the function itself as the first positional
            # argument; hand it straight back.
            used_bare = bool(args) and callable(args[0])
            # Otherwise this was a configuring call, so return an identity
            # decorator for the function that follows.
            return args[0] if used_bare else (lambda f: f)

        return decorator


class TritonLanguagePlaceholder(types.ModuleType):
    """Module object named ``triton.language`` with a few names set to None.

    ``constexpr``, ``dtype`` and ``int64`` exist as attributes so that
    looking them up does not raise ``AttributeError``; each is ``None``.
    """

    def __init__(self):
        super().__init__("triton.language")
        # Assigned in this order: constexpr, dtype, int64.
        for attr in ("constexpr", "dtype", "int64"):
            setattr(self, attr, None)