flashinfer.py 9.67 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for FlashInfer API changes.

Users of vLLM should always import **only** these wrappers.
"""
from __future__ import annotations

import contextlib
import functools
import importlib
import importlib.util
13
14
import os
from typing import Any, Callable, NoReturn, Optional
15

16
import requests
17
import torch
18
19

import vllm.envs as envs
20
from vllm.logger import init_logger
21
from vllm.platforms import current_platform
22
23
24

# Module logger from vLLM's init_logger; this file uses its *_once helpers
# (debug_once / info_once / warning_once) to avoid repeated messages.
logger = init_logger(__name__)

25
26
27
28
29
30
31
32
# Location the FlashInfer cubins are fetched from. Set the
# FLASHINFER_CUBINS_REPOSITORY environment variable to override it
# (for example with a local path when testing).
# Default mirrors FlashInfer's own cubin loader:
# https://github.com/flashinfer-ai/flashinfer/blob/0c9a92c3d9a7e043ab6f3f7b2273269caf6ab044/flashinfer/jit/cubin_loader.py#L35  # noqa: E501
_DEFAULT_FLASHINFER_CUBINS_REPOSITORY = (
    "https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/"  # noqa: E501
)
FLASHINFER_CUBINS_REPOSITORY = os.environ.get(
    "FLASHINFER_CUBINS_REPOSITORY", _DEFAULT_FLASHINFER_CUBINS_REPOSITORY)

33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80

@functools.cache
def has_flashinfer() -> bool:
    """Report whether the ``flashinfer`` package can be located.

    Only the import system's spec lookup is consulted; the package is not
    imported, which sidesteps any CUDA initialization side effects.
    The answer is memoized by ``functools.cache``.
    """
    spec = importlib.util.find_spec("flashinfer")
    return spec is not None


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Stand-in callable used when a FlashInfer function is unavailable.

    Any positional or keyword arguments are accepted and ignored, so this
    can replace an arbitrary wrapped function.

    Raises:
        RuntimeError: always, with a pointer to the FlashInfer project.
    """
    message = ("FlashInfer backend is not available. Please install the package "
               "to enable FlashInfer kernels: "
               "https://github.com/flashinfer-ai/flashinfer")
    raise RuntimeError(message)


def _get_submodule(module_name: str) -> Any | None:
    """Import ``module_name`` and return the module object.

    Returns ``None`` instead of raising when the import fails with an
    ``ImportError`` (which includes ``ModuleNotFoundError``).
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:  # ModuleNotFoundError is a subclass of ImportError
        return None
    return module


# General lazy import wrapper
def _lazy_import_wrapper(module_name: str,
                         attr_name: str,
                         fallback_fn: Callable[..., Any] = _missing):
    """Create a callable that forwards to ``module_name.attr_name`` on use.

    Resolution is deferred to the first call, so creating the wrapper does
    not import FlashInfer. The cached resolver then performs the lookup only
    once.

    Args:
        module_name: Dotted module path expected to expose ``attr_name``.
        attr_name: Name of the callable to forward to.
        fallback_fn: Invoked with the same arguments when FlashInfer is not
            installed, the module cannot be imported, or the attribute is
            missing. Defaults to ``_missing``, which raises ``RuntimeError``.

    Returns:
        A function that passes ``*args``/``**kwargs`` through to the resolved
        callable (or to ``fallback_fn``) and returns its result.
    """

    @functools.cache
    def _get_impl():
        if not has_flashinfer():
            return None
        mod = _get_submodule(module_name)
        return getattr(mod, attr_name, None) if mod else None

    def wrapper(*args, **kwargs):
        impl = _get_impl()
        if impl is None:
            return fallback_fn(*args, **kwargs)
        return impl(*args, **kwargs)

    # Name the wrapper after its target so tracebacks, profiles and logs show
    # e.g. ``cutlass_fused_moe`` rather than an anonymous ``wrapper``.
    wrapper.__name__ = attr_name
    wrapper.__qualname__ = attr_name
    wrapper.__doc__ = (f"Lazy wrapper around ``{module_name}.{attr_name}``; "
                       "uses a fallback when FlashInfer is unavailable.")
    return wrapper


# Create lazy wrappers for each function
flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe")
flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe")
flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
                                                    "cutlass_fused_moe")
fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
nvfp4_block_scale_interleave = _lazy_import_wrapper(
    "flashinfer", "nvfp4_block_scale_interleave")
# Resolve from ``flashinfer.fused_moe``, the same module that
# ``has_flashinfer_cutlass_fused_moe`` checks for this symbol. A positive
# availability check then guarantees this wrapper resolves to a real kernel.
trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp4_block_scale_moe")

# Special case for autotune since it returns a context manager: without
# FlashInfer it degrades to a no-op ``contextlib.nullcontext()``.
autotune = _lazy_import_wrapper(
    "flashinfer.autotuner",
    "autotune",
    fallback_fn=lambda *args, **kwargs: contextlib.nullcontext())


100
101
102
@functools.cache
def has_flashinfer_moe() -> bool:
    """Return ``True`` if FlashInfer MoE module is available.

    ``importlib.util.find_spec`` imports the parent package when given a
    dotted name. Unlike ``has_flashinfer``, this probe therefore does import
    ``flashinfer``. If that import fails (e.g. a broken install), the MoE
    module is reported as unavailable rather than letting the error escape
    from what is meant to be a boolean check.
    """
    if not has_flashinfer():
        return False
    try:
        return importlib.util.find_spec("flashinfer.fused_moe") is not None
    except (ImportError, ValueError):
        # ImportError: importing the parent ``flashinfer`` package failed.
        # ValueError: a module in sys.modules has ``__spec__`` set to None.
        return False
105
106


107
108
109
@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
    """Report whether every symbol used by the CUTLASS fused MoE path exists.

    The FlashInfer MoE module must be present first. After that, each
    ``(module, attribute)`` pair below must be importable and resolvable.
    """
    if not has_flashinfer_moe():
        return False

    needed = (
        ("flashinfer.fused_moe", "cutlass_fused_moe"),
        ("flashinfer", "fp4_quantize"),
        ("flashinfer", "nvfp4_block_scale_interleave"),
        ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
    )

    def _provides(module_name: str, attr_name: str) -> bool:
        module = _get_submodule(module_name)
        return bool(module) and hasattr(module, attr_name)

    # all() stops at the first missing symbol, just like an early return.
    return all(_provides(mod_name, attr) for mod_name, attr in needed)


128
129
130
@functools.cache
def has_nvidia_artifactory() -> bool:
    """Probe ``FLASHINFER_CUBINS_REPOSITORY`` and report whether it responds.

    By default this is NVIDIA's kernel inference library artifactory, which
    hosts cubins that certain kernels (such as TRTLLM FMHA) need to download.

    Returns:
        ``True`` only for an HTTP 200 response. Any other status code, or any
        exception raised while probing, is logged and yields ``False``. The
        result is memoized by ``functools.cache``.
    """
    try:
        # A short timeout keeps an unreachable host from blocking for long.
        status = requests.get(FLASHINFER_CUBINS_REPOSITORY,
                              timeout=5).status_code
        if status != 200:
            logger.warning_once(
                "NVIDIA artifactory returned failed status code: %d", status)
            return False
        logger.debug_once("NVIDIA artifactory is accessible")
        return True
    except Exception as exc:
        logger.warning_once("Failed to connect to NVIDIA artifactory: %s", exc)
        return False


151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
@functools.cache
def supports_trtllm_attention() -> tuple[bool, Optional[str]]:
    """Decide whether the environment permits TRTLLM attention.

    Returns:
        ``(allowed, env_value)``, where ``env_value`` is the raw
        ``VLLM_USE_TRTLLM_ATTENTION`` setting (``None`` when unset).
        ``allowed`` is ``False`` unless the device reports capability 100
        and the NVIDIA artifactory is reachable. When the variable is set,
        only ``"1"`` enables TRTLLM. When it is unset, ``(True, None)`` is
        returned and the caller applies its own auto-detection.

    Cached, since the answer depends only on the environment.
    """
    # The envs attribute is backed by a lambda; read it exactly once.
    env_value = envs.VLLM_USE_TRTLLM_ATTENTION

    # Requires SM100 and a reachable artifactory to download the cubins.
    platform_ok = (current_platform.is_device_capability(100)
                   and has_nvidia_artifactory())
    if not platform_ok:
        return False, env_value

    if env_value is None:
        # Unset: permitted; the per-call heuristic makes the final choice.
        return True, None

    logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
    # An explicit setting is respected; anything other than "1" disables it.
    enabled = env_value == "1"
    if enabled:
        logger.info_once("Using TRTLLM attention.")
    return enabled, env_value


176
def use_trtllm_attention(
    num_tokens: int,
    max_seq_len: int,
    kv_cache_dtype: str,
    num_qo_heads: Optional[int],
    num_kv_heads: Optional[int],
    attn_head_size: Optional[int],
    has_sinks: bool = False,
) -> bool:
    """Return ``True`` if TRTLLM attention should be used for this call.

    Args:
        num_tokens: Number of tokens; used only by auto-detection.
        max_seq_len: Maximum sequence length; used only by auto-detection.
        kv_cache_dtype: KV-cache dtype string. Auto-detection enables
            TRTLLM only for ``"auto"``.
        num_qo_heads: Number of query/output heads, or ``None`` if unknown.
        num_kv_heads: Number of KV heads, or ``None`` if unknown.
        attn_head_size: Attention head size, or ``None`` if unknown.
        has_sinks: Whether attention sinks are used. TRTLLM is then required.

    Returns:
        ``False`` if ``supports_trtllm_attention`` rules TRTLLM out, or if the
        head configuration is unknown or invalid. Otherwise ``True`` when
        sinks are used or the variable was explicitly set to ``"1"``. When the
        variable is unset, the auto-detection heuristic decides.
    """
    use_trtllm, env_value = supports_trtllm_attention()
    if not use_trtllm:
        return False

    # Check if the dimensions are supported by TRTLLM decode attention.
    # Head counts must be known, num_kv_heads must be positive (zero would
    # otherwise raise ZeroDivisionError), and query heads must split evenly
    # across KV heads.
    if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None
            or num_kv_heads <= 0 or num_qo_heads % num_kv_heads != 0):
        return False

    # If sinks are being used, we must use TRTLLM attention as it's
    # the only backend that supports them
    if has_sinks:
        logger.info_once(
            "Using TRTLLM attention (required for attention sinks).")
        return True

    if env_value is None:
        # Environment variable not set - use auto-detection
        use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
                      and kv_cache_dtype == "auto")
        if use_trtllm:
            logger.warning_once("Using TRTLLM attention (auto-detected).")
        return use_trtllm

    # Environment variable is set to 1 - respect it
    return True

212

213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
# The custom op is only defined when FlashInfer is installed. Without it,
# ``flashinfer_mm_fp4`` does not exist, and calling
# ``flashinfer_scaled_fp4_mm`` would fail with a NameError.
if has_flashinfer():

    # Register as a torch custom op named "vllm::flashinfer_mm_fp4" for CUDA
    # devices; ``mutates_args=[]`` declares that no input is modified.
    @torch.library.custom_op(
        "vllm::flashinfer_mm_fp4",
        mutates_args=[],
        device_types="cuda",
    )
    def flashinfer_mm_fp4(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        """Forward to FlashInfer's ``mm_fp4`` with a fixed ``block_size=16``.

        The FlashInfer import happens at call time, not at module import.
        """
        from flashinfer import mm_fp4 as flashinfer_mm_fp4_
        return flashinfer_mm_fp4_(A,
                                  B,
                                  A_scale,
                                  B_scale,
                                  g_scale,
                                  dtype,
                                  block_size=16,
                                  backend=backend)

    # Fake implementation for tracing: it allocates an output of the expected
    # shape/dtype/device without running the kernel.
    @torch.library.register_fake("vllm::flashinfer_mm_fp4", )
    def flashinfer_mm_fp4_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        """Return an uninitialized ``(A.shape[0], B.shape[1])`` tensor.

        NOTE(review): assumes ``B`` is already laid out as (k, n), as
        ``flashinfer_scaled_fp4_mm`` passes ``b.t()``; confirm for any other
        caller.
        """
        return torch.empty(A.shape[0],
                           B.shape[1],
                           dtype=dtype,
                           device=A.device)


def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor,
                             block_scale_a: torch.Tensor,
                             block_scale_b: torch.Tensor, alpha: torch.Tensor,
                             out_dtype: torch.dtype,
                             backend: str) -> torch.Tensor:
    """Multiply ``a`` by ``b`` transposed via the ``flashinfer_mm_fp4`` op.

    ``a`` and ``b`` (and their block scales) must be 2-D. The operands must be
    contiguous in their last dimension and agree on ``shape[1]``. Each scale
    tensor's ``shape[1]`` must be ``operand.shape[1] // 8``.
    NOTE(review): presumably two FP4 values are packed per byte with one
    scale per 16 values; confirm against the FlashInfer docs.

    ``b`` and ``block_scale_b`` are transposed before the call. For the
    ``"cutlass"`` backend, both scale tensors are reinterpreted as ``uint8``.

    Raises:
        AssertionError: if any shape or stride precondition is violated.
    """
    for tensor in (a, b, block_scale_a, block_scale_b):
        assert tensor.ndim == 2
    for operand in (a, b):
        assert operand.stride(-1) == 1
    assert a.shape[1] == b.shape[1]
    for scales, operand in ((block_scale_a, a), (block_scale_b, b)):
        assert scales.shape[1] == operand.shape[1] // 8

    if backend == "cutlass":
        # The CUTLASS path is handed the scales reinterpreted as uint8 bytes.
        block_scale_a, block_scale_b = (block_scale_a.view(torch.uint8),
                                        block_scale_b.view(torch.uint8))

    return flashinfer_mm_fp4(a, b.t(), block_scale_a, block_scale_b.t(),
                             alpha, out_dtype, backend=backend)


282
283
# Public API of this module. ``flashinfer_mm_fp4`` is intentionally omitted:
# it is only defined when FlashInfer is installed, so listing it would break
# ``from ... import *`` on systems without FlashInfer.
__all__ = [
    "has_flashinfer",
    "flashinfer_trtllm_fp8_block_scale_moe",
    "flashinfer_trtllm_fp8_per_tensor_scale_moe",
    "flashinfer_cutlass_fused_moe",
    "fp4_quantize",
    "nvfp4_block_scale_interleave",
    "trtllm_fp4_block_scale_moe",
    "autotune",
    "has_flashinfer_moe",
    "has_flashinfer_cutlass_fused_moe",
    "has_nvidia_artifactory",
    "supports_trtllm_attention",
    "use_trtllm_attention",
    "flashinfer_scaled_fp4_mm",
]