# SPDX-License-Identifier: Apache-2.0
"""Detect whether Triton is available and, if not, install a placeholder
``triton`` module so that code decorating kernels at import time can still be
imported."""

import sys
import types
from importlib.util import find_spec

from vllm.logger import init_logger

logger = init_logger(__name__)

# True when an importable Triton module can be located without importing it.
HAS_TRITON = (
    find_spec("triton") is not None
    # NOTE(review): find_spec expects an importable module name; this
    # hyphenated string looks like a distribution name, so the probe
    # presumably never matches (hence "Not compatible") — confirm before
    # relying on it.
    or find_spec("pytorch-triton-xpu") is not None  # Not compatible
)

if not HAS_TRITON:
    logger.info("Triton not installed or not compatible; certain GPU-related"
                " functions will not be available.")

    class TritonPlaceholder(types.ModuleType):
        """Stand-in for the ``triton`` module when Triton is unavailable.

        Provides no-op ``jit`` / ``autotune`` / ``heuristics`` decorators and a
        ``language`` sub-module placeholder, so modules that define kernels at
        import time can still be imported (the kernels cannot run).
        """

        def __init__(self):
            super().__init__("triton")
            self.jit = self._dummy_decorator("jit")
            self.autotune = self._dummy_decorator("autotune")
            self.heuristics = self._dummy_decorator("heuristics")
            self.language = TritonLanguagePlaceholder()
            logger.warning_once(
                "Triton is not installed. Using dummy decorators. "
                "Install it via `pip install triton` to enable kernel "
                "compilation.")

        def _dummy_decorator(self, name):
            """Return a no-op decorator usable bare or with arguments.

            ``name`` identifies the stubbed attribute; it is currently unused.
            """

            def decorator(func=None, *args, **kwargs):
                # Bare use (``@triton.jit``): the decorated function is the
                # sole positional argument; hand it back unchanged.
                if callable(func) and not args:
                    return func
                # Factory use (``@triton.jit(...)``,
                # ``@triton.heuristics({...})``,
                # ``@triton.autotune(configs, key)``): return an identity
                # decorator. Testing callable() instead of ``is None`` keeps a
                # positional non-function argument (dict, list) from being
                # returned as if it were the decorated object.
                return lambda f: f

            return decorator

    class TritonLanguagePlaceholder(types.ModuleType):
        """Stand-in for ``triton.language``; its attributes are ``None``."""

        def __init__(self):
            super().__init__("triton.language")
            self.constexpr = None
            self.dtype = None

    sys.modules['triton'] = TritonPlaceholder()
    # Register the same object the parent exposes, so that
    # ``import triton.language`` and ``triton.language`` are identical.
    sys.modules['triton.language'] = sys.modules['triton'].language

# Only report a replacement when Triton was unavailable (so the placeholder
# was installed); a real ``triton`` in sys.modules must not trigger this.
if not HAS_TRITON and 'triton' in sys.modules:
    logger.info("Triton module has been replaced with a placeholder.")