flashinfer.py 7.09 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for FlashInfer API changes.

Users of vLLM should always import **only** these wrappers.
"""
from __future__ import annotations

import contextlib
import functools
import importlib
import importlib.util
13
14
import os
from typing import Any, Callable, NoReturn, Optional
15

16
17
18
import requests

import vllm.envs as envs
19
from vllm.logger import init_logger
20
from vllm.platforms import current_platform
21
22
23

logger = init_logger(__name__)

24
25
26
27
28
29
30
31
# Base location from which FlashInfer cubins are fetched. Setting the
# FLASHINFER_CUBINS_REPOSITORY environment variable overrides it, e.g. with a
# local path for testing.
# Referenced from https://github.com/flashinfer-ai/flashinfer/blob/0c9a92c3d9a7e043ab6f3f7b2273269caf6ab044/flashinfer/jit/cubin_loader.py#L35  # noqa: E501
FLASHINFER_CUBINS_REPOSITORY = os.getenv(
    "FLASHINFER_CUBINS_REPOSITORY",
    "https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/",  # noqa: E501
)

32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79

@functools.cache
def has_flashinfer() -> bool:
    """Report whether the ``flashinfer`` package can be located.

    Only the import system's finders are consulted; the package itself is
    not imported, which avoids potential CUDA initialization side effects.
    The answer is memoized for the lifetime of the process.
    """
    spec = importlib.util.find_spec("flashinfer")
    return spec is not None


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Stand-in called when a FlashInfer function cannot be resolved.

    Any positional or keyword arguments are accepted and ignored; the call
    always raises ``RuntimeError`` pointing at the FlashInfer project.
    """
    message = (
        "FlashInfer backend is not available. Please install the package "
        "to enable FlashInfer kernels: "
        "https://github.com/flashinfer-ai/flashinfer")
    raise RuntimeError(message)


def _get_submodule(module_name: str) -> Any | None:
    """Import ``module_name`` and return the module, or ``None`` on failure.

    Only import failures are swallowed (``ModuleNotFoundError`` is a
    subclass of ``ImportError``); any other exception raised while
    importing propagates to the caller.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        module = None
    return module


# General lazy import wrapper
def _lazy_import_wrapper(module_name: str,
                         attr_name: str,
                         fallback_fn: Callable[..., Any] = _missing):
    """Build a callable that forwards to ``module_name.attr_name``.

    The target is looked up on the first call and memoized. When FlashInfer
    is not installed, the module cannot be imported, or the module lacks
    ``attr_name``, every call is routed to ``fallback_fn`` instead.
    """

    @functools.cache
    def _resolve():
        # Skip the import attempt entirely when FlashInfer is absent.
        if has_flashinfer():
            module = _get_submodule(module_name)
            if module:
                return getattr(module, attr_name, None)
        return None

    def wrapper(*args, **kwargs):
        target = _resolve()
        call = fallback_fn if target is None else target
        return call(*args, **kwargs)

    return wrapper


# Create lazy wrappers for each function. Each wrapper resolves the real
# FlashInfer function on its first call; if FlashInfer or the attribute is
# unavailable, calls go to the wrapper's fallback (``_missing`` by default,
# which raises RuntimeError).
flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe")
flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe")
flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
                                                    "cutlass_fused_moe")
fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
nvfp4_block_scale_interleave = _lazy_import_wrapper(
    "flashinfer", "nvfp4_block_scale_interleave")
# Resolved from ``flashinfer.fused_moe``: this is the module that
# ``has_flashinfer_cutlass_fused_moe`` checks for this function, and it
# matches where the other TRTLLM MoE wrappers above are looked up.
trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp4_block_scale_moe")

# Special case for autotune since it returns a context manager: without
# FlashInfer it degrades to a no-op ``contextlib.nullcontext()``.
autotune = _lazy_import_wrapper(
    "flashinfer.autotuner",
    "autotune",
    fallback_fn=lambda *args, **kwargs: contextlib.nullcontext())


99
100
101
@functools.cache
def has_flashinfer_moe() -> bool:
    """Return ``True`` if FlashInfer MoE module is available.

    Unlike the plain package probe, this check may import the top-level
    ``flashinfer`` package: ``importlib.util.find_spec`` imports the parent
    package when given a dotted submodule name. If that import fails, or
    the module's spec cannot be determined, the MoE module is reported as
    unavailable instead of raising. The result is memoized.
    """
    if not has_flashinfer():
        return False
    try:
        return importlib.util.find_spec("flashinfer.fused_moe") is not None
    except (ImportError, ValueError):
        # ImportError: the parent ``flashinfer`` package failed to import.
        # ValueError: the module is in sys.modules with ``__spec__`` None.
        return False
104
105


106
107
108
@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
    """Return ``True`` if FlashInfer CUTLASS fused MoE is available.

    Requires the FlashInfer MoE module plus every kernel entry point used by
    the CUTLASS fused-MoE path to be importable. The result is memoized.
    """
    if not has_flashinfer_moe():
        return False

    # (module, attribute) pairs that must all resolve. ``_get_submodule``
    # returns None when a module cannot be imported.
    required_functions = [
        ("flashinfer.fused_moe", "cutlass_fused_moe"),
        ("flashinfer", "fp4_quantize"),
        ("flashinfer", "nvfp4_block_scale_interleave"),
        ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
    ]

    for module_name, attr_name in required_functions:
        mod = _get_submodule(module_name)
        if mod is None or not hasattr(mod, attr_name):
            return False
    return True


127
128
129
@functools.cache
def has_nvidia_artifactory() -> bool:
    """Return ``True`` if NVIDIA's artifactory is accessible.

    This checks connectivity to the kernel inference library artifactory
    which is required for downloading certain cubin kernels like TRTLLM FMHA.

    ``FLASHINFER_CUBINS_REPOSITORY`` may be overridden with a local path for
    testing; when it names an existing directory, no HTTP request is made
    and the repository is reported as accessible.

    Otherwise a single GET request with a 5 second timeout is issued. Any
    non-200 status or exception is logged once and reported as inaccessible.
    The result is cached for the lifetime of the process.
    """
    # An HTTP GET cannot fetch a filesystem path, so probing a local
    # override over the network would always fail and report it unavailable.
    if os.path.isdir(FLASHINFER_CUBINS_REPOSITORY):
        logger.debug_once("Using local cubin repository: %s",
                          FLASHINFER_CUBINS_REPOSITORY)
        return True

    try:
        # Use a short timeout to avoid blocking for too long
        response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5)
    except Exception as e:
        # Best-effort probe: any failure (DNS, TLS, timeout, bad URL, ...)
        # means the artifactory is treated as unavailable.
        logger.warning_once("Failed to connect to NVIDIA artifactory: %s", e)
        return False

    accessible = response.status_code == 200
    if accessible:
        logger.debug_once("NVIDIA artifactory is accessible")
    else:
        logger.warning_once(
            "NVIDIA artifactory returned failed status code: %d",
            response.status_code)
    return accessible


150
def use_trtllm_attention(
    num_tokens: int,
    max_seq_len: int,
    kv_cache_dtype: str,
    num_qo_heads: Optional[int],
    num_kv_heads: Optional[int],
    attn_head_size: Optional[int],
    has_sinks: bool = False,
) -> bool:
    """Decide whether the TRTLLM attention kernels should be used.

    Args:
        num_tokens: Number of tokens; used by the auto-detection heuristic.
        max_seq_len: Maximum sequence length; used by the auto-detection
            heuristic.
        kv_cache_dtype: KV-cache dtype string; auto-detection only selects
            TRTLLM when it is ``"auto"``.
        num_qo_heads: Number of query/output heads, or ``None`` if unknown.
        num_kv_heads: Number of KV heads, or ``None`` if unknown.
        attn_head_size: Attention head size, or ``None`` if unknown.
        has_sinks: Whether attention sinks are used. If the platform and
            head configuration are supported, sinks force TRTLLM on.

    Returns:
        ``True`` if TRTLLM attention should be used.
    """
    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
    if not (current_platform.is_device_capability(100)
            and has_nvidia_artifactory()):
        return False

    # Check if the dimensions are supported by TRTLLM decode attention: all
    # head counts must be known, and query heads must be an exact multiple of
    # KV heads. A zero KV-head count is rejected before the modulo, which
    # would otherwise raise ZeroDivisionError.
    if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None
            or num_kv_heads == 0 or num_qo_heads % num_kv_heads != 0):
        return False

    # If sinks are being used, we must use TRTLLM attention as it's
    # the only backend that supports them
    if has_sinks:
        logger.info_once(
            "Using TRTLLM attention (required for attention sinks).")
        return True

    env_value = envs.VLLM_USE_TRTLLM_ATTENTION
    if env_value is not None:
        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
        # The environment variable takes precedence over auto-detection:
        # only "1" enables TRTLLM; any other value (e.g. "0") disables it,
        # even when the auto-detection conditions below would be satisfied.
        use_trtllm = env_value == "1"
        if use_trtllm:
            logger.info_once("Using TRTLLM attention.")
        return use_trtllm

    # Environment variable not set - use auto-detection
    use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
                  and kv_cache_dtype == "auto")
    if use_trtllm:
        # Selecting TRTLLM here is expected behavior, not a problem, so log
        # at info level like the other branches.
        logger.info_once("Using TRTLLM attention (auto-detected).")
    return use_trtllm


196
197
# Public API of this module. Every module-level wrapper and availability
# check defined above is exported, including the FP8 per-tensor-scale MoE
# wrapper.
__all__ = [
    "has_flashinfer",
    "flashinfer_trtllm_fp8_block_scale_moe",
    "flashinfer_trtllm_fp8_per_tensor_scale_moe",
    "flashinfer_cutlass_fused_moe",
    "fp4_quantize",
    "nvfp4_block_scale_interleave",
    "trtllm_fp4_block_scale_moe",
    "autotune",
    "has_flashinfer_moe",
    "has_flashinfer_cutlass_fused_moe",
    "has_nvidia_artifactory",
    "use_trtllm_attention",
]