flashinfer.py 13.9 KB
Newer Older
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for FlashInfer API changes.

Users of vLLM should always import **only** these wrappers.
"""
7

8
9
10
11
12
13
from __future__ import annotations

import contextlib
import functools
import importlib
import importlib.util
14
15
import os
from typing import Any, Callable, NoReturn, Optional
16

17
import requests
18
import torch
19
20

import vllm.envs as envs
21
from vllm.logger import init_logger
22
from vllm.platforms import current_platform
23
24
25

logger = init_logger(__name__)

# Where the FlashInfer cubins are stored. The default below can be overridden
# through the FLASHINFER_CUBINS_REPOSITORY environment variable, e.g. to point
# at a local path when testing.
# Referenced from https://github.com/flashinfer-ai/flashinfer/blob/0c9a92c3d9a7e043ab6f3f7b2273269caf6ab044/flashinfer/jit/cubin_loader.py#L35  # noqa: E501
FLASHINFER_CUBINS_REPOSITORY = os.getenv(
    "FLASHINFER_CUBINS_REPOSITORY",
    "https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/",  # noqa: E501
)

34
35
36
37
38
39
40
41
42
43
44
45
46
47

@functools.cache
def has_flashinfer() -> bool:
    """Report whether the ``flashinfer`` package can be located.

    Only a module-spec lookup is performed; the package itself is not
    imported, which avoids potential CUDA-initialization side effects.
    The answer is cached for the lifetime of the process.
    """
    spec = importlib.util.find_spec("flashinfer")
    return spec is not None


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Placeholder for unavailable FlashInfer backend."""
    raise RuntimeError(
        "FlashInfer backend is not available. Please install the package "
        "to enable FlashInfer kernels: "
48
49
        "https://github.com/flashinfer-ai/flashinfer"
    )
50
51
52
53
54
55
56
57
58
59
60


def _get_submodule(module_name: str) -> Any | None:
    """Safely import a submodule and return it, or None if not available."""
    try:
        return importlib.import_module(module_name)
    except (ImportError, ModuleNotFoundError):
        return None


# General lazy import wrapper
def _lazy_import_wrapper(
    module_name: str, attr_name: str, fallback_fn: Callable[..., Any] = _missing
):
    """Create a lazy import wrapper for a specific function.

    The returned callable resolves ``module_name.attr_name`` on first use and
    caches that lookup. If FlashInfer is not installed, the module cannot be
    imported, or the attribute does not exist, every call is forwarded to
    ``fallback_fn`` instead.

    Args:
        module_name: Dotted module path, e.g. ``"flashinfer.fused_moe"``.
        attr_name: Name of the callable to fetch from that module.
        fallback_fn: Called with the same arguments when the real
            implementation is unavailable; defaults to ``_missing``, which
            raises ``RuntimeError``.

    Returns:
        A callable that forwards ``*args``/``**kwargs`` to the resolved
        implementation, or to ``fallback_fn`` if none could be resolved.
    """

    @functools.cache
    def _get_impl():
        # Cached: the import/getattr is attempted only once per wrapper.
        if not has_flashinfer():
            return None
        mod = _get_submodule(module_name)
        return getattr(mod, attr_name, None) if mod else None

    def wrapper(*args, **kwargs):
        impl = _get_impl()
        if impl is None:
            return fallback_fn(*args, **kwargs)
        return impl(*args, **kwargs)

    # Name the wrapper after the function it stands for, so reprs, tracebacks
    # and profiles show e.g. "fp4_quantize" instead of a generic "wrapper".
    wrapper.__name__ = attr_name
    wrapper.__qualname__ = attr_name
    return wrapper


# Create lazy wrappers for each function. Each one raises (via ``_missing``)
# when called if FlashInfer or the target symbol is unavailable.
flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"
)
flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"
)
flashinfer_cutlass_fused_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "cutlass_fused_moe"
)
fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
nvfp4_block_scale_interleave = _lazy_import_wrapper(
    "flashinfer", "nvfp4_block_scale_interleave"
)
trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
    "flashinfer", "trtllm_fp4_block_scale_moe"
)

# Special case for autotune since it returns a context manager: without
# FlashInfer it falls back to a no-op context manager instead of raising, so
# ``with autotune(...):`` keeps working.
autotune = _lazy_import_wrapper(
    "flashinfer.autotuner",
    "autotune",
    fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(),
)
106
107


108
109
110
@functools.cache
def has_flashinfer_comm() -> bool:
    """Return ``True`` if FlashInfer comm module is available.

    The ``flashinfer.comm`` lookup only happens after the top-level package
    is known to exist. The result is cached.
    """
    # NOTE: find_spec on a dotted name imports the parent package
    # ("flashinfer") as part of the lookup.
    return (
        has_flashinfer()
        and importlib.util.find_spec("flashinfer.comm") is not None
    )
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134


@functools.cache
def has_flashinfer_all2all() -> bool:
    """Report whether FlashInfer's MNNVL all-to-all support is usable.

    This requires the comm module plus every ``(module, attribute)`` pair
    below to resolve. Pairs are probed in order, and probing stops at the
    first one that is missing. The result is cached.
    """
    if not has_flashinfer_comm():
        return False

    needed = (
        ("flashinfer.comm", "Mapping"),
        ("flashinfer.comm.mnnvl", "MnnvlMemory"),
        ("flashinfer.comm.trtllm_alltoall", "MnnvlMoe"),
        ("flashinfer.comm.trtllm_alltoall", "MoEAlltoallInfo"),
    )

    def _resolves(module_name: str, attr_name: str) -> bool:
        module = _get_submodule(module_name)
        return bool(module) and hasattr(module, attr_name)

    return all(_resolves(name, attr) for name, attr in needed)


135
136
137
@functools.cache
def has_flashinfer_moe() -> bool:
    """Return ``True`` if FlashInfer MoE module is available.

    The ``flashinfer.fused_moe`` lookup only happens after the top-level
    package is known to exist. The result is cached.
    """
    # NOTE: find_spec on a dotted name imports the parent package
    # ("flashinfer") as part of the lookup.
    return (
        has_flashinfer()
        and importlib.util.find_spec("flashinfer.fused_moe") is not None
    )
142
143


144
145
146
@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
    """Return ``True`` if FlashInfer CUTLASS fused MoE is available.

    This requires the FlashInfer MoE module, and every ``(module, attribute)``
    pair listed below must be importable. The result is cached.
    """
    if not has_flashinfer_moe():
        return False

    # Check if all required functions are available
    required_functions = [
        ("flashinfer.fused_moe", "cutlass_fused_moe"),
        ("flashinfer", "fp4_quantize"),
        ("flashinfer", "nvfp4_block_scale_interleave"),
        ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
    ]

    for module_name, attr_name in required_functions:
        mod = _get_submodule(module_name)
        if not mod or not hasattr(mod, attr_name):
            return False
    return True


165
166
167
@functools.cache
def has_nvidia_artifactory() -> bool:
    """Return ``True`` if NVIDIA's artifactory is accessible.

    This checks connectivity to the kernel inference library artifactory
    (``FLASHINFER_CUBINS_REPOSITORY``). That artifactory is required for
    downloading certain cubin kernels such as TRTLLM FMHA.

    The result is cached, so the network probe runs at most once per process.
    A failure to reach the repository is logged and reported as ``False``;
    it is not raised.
    """
    # When VLLM_HAS_FLASHINFER_CUBIN is set, the cubins are assumed to be
    # pre-downloaded locally (presumably under FLASHINFER_CUBIN_DIR), so the
    # artifactory does not need to be reachable.
    if envs.VLLM_HAS_FLASHINFER_CUBIN:
        return True

    try:
        # Use a short timeout to avoid blocking for too long
        response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5)
    except Exception as e:
        # Deliberately best-effort: any failure (DNS, timeout, bad URL, ...)
        # simply means the artifactory is treated as unreachable.
        logger.warning_once("Failed to connect to NVIDIA artifactory: %s", e)
        return False

    accessible = response.status_code == 200
    if accessible:
        logger.debug_once("NVIDIA artifactory is accessible")
    else:
        logger.warning_once(
            "NVIDIA artifactory returned failed status code: %d",
            response.status_code,
        )
    return accessible


194
@functools.cache
def supports_trtllm_attention() -> bool:
    """
    TRTLLM attention is supported if the platform is SM100 and
    NVIDIA artifactory is accessible.

    The result is cached for the lifetime of the process.
    """
    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
    return current_platform.is_device_capability(100) and has_nvidia_artifactory()
202

203
204
205
206

@functools.cache
def _force_use_trtllm_attention(env_value: Optional[bool]) -> Optional[bool]:
    """Cache the env value for VLLM_USE_TRTLLM_ATTENTION.

    Because the function is cached on ``env_value``, the info message is
    emitted only the first time a given value is seen.

    Returns:
        ``env_value`` unchanged.
    """
    if env_value is not None:
        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
    return env_value
210

211
212
213
214
215
216
217
218

def force_use_trtllm_attention() -> Optional[bool]:
    """
    Report the user's explicit choice for TRTLLM attention.

    Gives ``None`` when VLLM_USE_TRTLLM_ATTENTION is unset. Gives ``True``
    when the variable forces TRTLLM attention on, and ``False`` when it
    forces TRTLLM attention off.
    """
    env_value = envs.VLLM_USE_TRTLLM_ATTENTION
    return _force_use_trtllm_attention(env_value)
219
220


221
222
223
224
225
226
def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
    """Check whether TRTLLM attention fits this platform and head layout.

    Two conditions must hold. The platform must support TRTLLM attention,
    and the number of query heads must be an exact multiple of the number of
    KV heads. The divisibility check is skipped when the platform is
    unsupported.
    """
    if not supports_trtllm_attention():
        return False
    return num_qo_heads % num_kv_heads == 0


227
def use_trtllm_attention(
    num_qo_heads: int,
    num_kv_heads: int,
    num_tokens: int,
    max_seq_len: int,
    kv_cache_dtype: str,
    q_dtype: torch.dtype,
    is_prefill: bool,
    has_sinks: bool = False,
    has_spec: bool = False,
) -> bool:
    """Return ``True`` if TRTLLM attention is used.

    The checks run in order and the first match decides:

    1. ``VLLM_USE_TRTLLM_ATTENTION`` set to 0 -> ``False``.
    2. Platform unsupported (``supports_trtllm_attention``) -> ``False``.
    3. ``num_qo_heads`` not a multiple of ``num_kv_heads`` -> ``False``.
    4. Speculative decoding during a decode step -> ``True``.
    5. FP8 query -> ``True`` (raises if attention sinks are also used).
    6. Attention sinks -> ``True``.
    7. Env var unset -> auto-detect. The result is ``True`` only when all of
       ``num_tokens <= 256``, ``max_seq_len <= 131072`` and
       ``kv_cache_dtype == "auto"`` hold.
    8. Env var set to 1 -> ``True``.

    Steps 2 and 3 log a warning if the env var tried to force TRTLLM on.

    Raises:
        RuntimeError: if the query is FP8-quantized and attention sinks are
            requested.
    """
    force_use_trtllm = force_use_trtllm_attention()

    # Environment variable is set to 0 - respect it
    if force_use_trtllm is not None and not force_use_trtllm:
        return False

    # The platform is not supported
    if not supports_trtllm_attention():
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported on this platform, "
                "but VLLM_USE_TRTLLM_ATTENTION is set to 1"
            )
        return False

    # The combination of query and key heads is not supported
    if num_qo_heads % num_kv_heads != 0:
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported for this combination of "
                "query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
            )
        return False

    if has_spec and not is_prefill:
        # Speculative decoding requires TRTLLM attention for decodes
        logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
        return True

    # Must use TRTLLM attention if query is FP8 quantized
    if q_dtype == current_platform.fp8_dtype():
        if has_sinks:
            raise RuntimeError(
                "TRTLLM FP8-qkv kernel is not supported for attention sinks. "
                "Use kv_cache_dtype=auto for now."
            )
        logger.info_once("Using TRTLLM attention (query is quantized).")
        return True

    # If sinks are being used, we must use TRTLLM attention as it's
    # the only backend that supports them
    if has_sinks:
        logger.info_once("Using TRTLLM attention (required for attention sinks).")
        return True

    if force_use_trtllm is None:
        # Environment variable not set - use auto-detection
        use_trtllm = (
            num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto"
        )
        if use_trtllm:
            logger.warning_once("Using TRTLLM attention (auto-detected).")
        return use_trtllm

    # Environment variable is set to 1 - respect it
    logger.info_once("Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
    return True

297

298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
if has_flashinfer():
    # The custom ops below are only registered when FlashInfer is installed.
    # Otherwise ``flashinfer_mm_fp4`` and ``bmm_fp8`` stay undefined in this
    # module.

    @torch.library.custom_op(
        "vllm::flashinfer_mm_fp4",
        mutates_args=[],
        device_types="cuda",
    )
    def flashinfer_mm_fp4(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        """FP4 GEMM via ``flashinfer.mm_fp4`` with a fixed block size of 16.

        Registered as the ``vllm::flashinfer_mm_fp4`` torch custom op. Its
        fake implementation is registered right below.
        """
        # Imported at call time rather than at module import.
        from flashinfer import mm_fp4 as flashinfer_mm_fp4_

        return flashinfer_mm_fp4_(
            A, B, A_scale, B_scale, g_scale, dtype, block_size=16, backend=backend
        )

    @torch.library.register_fake(
        "vllm::flashinfer_mm_fp4",
    )
    def flashinfer_mm_fp4_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        """Fake implementation of ``vllm::flashinfer_mm_fp4``.

        Allocates an uninitialized ``(A.shape[0], B.shape[1])`` tensor of
        ``dtype`` on ``A``'s device. No computation is performed.
        """
        return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device)

    @torch.library.custom_op(
        "vllm::bmm_fp8",
        mutates_args=[],
        device_types="cuda",
    )
    def bmm_fp8(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        """Batched FP8 matmul via ``flashinfer.bmm_fp8``.

        Registered as the ``vllm::bmm_fp8`` torch custom op.
        """
        # Imported at call time rather than at module import.
        from flashinfer import bmm_fp8 as bmm_fp8_

        # NOTE(review): the positional ``None`` is presumably the optional
        # output buffer (let FlashInfer allocate it) - confirm against the
        # FlashInfer signature.
        return bmm_fp8_(A, B, A_scale, B_scale, dtype, None, backend)

    @torch.library.register_fake(
        "vllm::bmm_fp8",
    )
    def bmm_fp8_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        """Fake implementation of ``vllm::bmm_fp8``.

        Allocates an uninitialized ``(A.shape[0], A.shape[1], B.shape[2])``
        tensor of ``dtype`` on ``A``'s device.
        """
        return torch.empty(
            A.shape[0], A.shape[1], B.shape[2], dtype=dtype, device=A.device
        )


def flashinfer_scaled_fp4_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    alpha: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str,
) -> torch.Tensor:
    """Block-scaled FP4 matmul of ``a`` with ``b.t()``.

    The computation runs through the ``vllm::flashinfer_mm_fp4`` custom op.

    Args:
        a: 2-D tensor, contiguous in its last dimension.
        b: 2-D tensor with the same last dimension as ``a``. It is transposed
            before being passed to the kernel.
        block_scale_a: 2-D block scales for ``a``; last dim is
            ``a.shape[1] // 8``. This presumably reflects packed FP4 (two
            values per byte) with one scale per 16 values - TODO confirm.
        block_scale_b: 2-D block scales for ``b`` with the analogous shape.
            They are transposed along with ``b``.
        alpha: global scale, forwarded as the op's ``g_scale``.
        out_dtype: dtype of the result.
        backend: FlashInfer backend name. For ``"cutlass"``, the block scales
            are reinterpreted as ``uint8``.

    Returns:
        The matmul result in ``out_dtype``.

    Raises:
        AssertionError: if the tensor ranks, strides or shapes are inconsistent.
    """
    assert a.ndim == 2 and b.ndim == 2
    assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2
    assert a.stride(-1) == 1 and b.stride(-1) == 1
    assert a.shape[1] == b.shape[1]
    assert block_scale_a.shape[1] == a.shape[1] // 8
    assert block_scale_b.shape[1] == b.shape[1] // 8

    if backend == "cutlass":
        # The CUTLASS path takes the block scales as raw bytes.
        block_scale_a = block_scale_a.view(torch.uint8)
        block_scale_b = block_scale_b.view(torch.uint8)

    return flashinfer_mm_fp4(
        a,
        b.t(),
        block_scale_a,
        block_scale_b.t(),
        alpha,
        out_dtype,
        backend=backend,
    )


398
def flashinfer_scaled_fp8_mm(
399
400
401
402
403
404
405
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
    assert a.ndim == 2 and b.ndim == 2
    assert a.shape[1] == b.shape[0]
    assert scale_a.numel() == 1 and scale_b.numel() == 1
    assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn
    assert a.device.type == "cuda" and b.device.type == "cuda"
    assert scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32
    assert scale_a.device.type == "cuda" and scale_b.device.type == "cuda"

    output = bmm_fp8(
        a.unsqueeze(0),
        b.unsqueeze(0),
        scale_a,
        scale_b,
        out_dtype,
        "auto",
    ).view(a.shape[0], b.shape[1])

    if bias is not None:
        output = output + bias
    return output


428
429
430
431
432
433
@functools.cache
def flashinfer_disable_q_quantization() -> bool:
    """Tell whether FlashInfer query quantization has been disabled.

    The answer depends only on ``VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION``.
    It is therefore read once and cached for the life of the process.
    """
    disabled = envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
    return disabled


434
435
# Public API of this compatibility module.
__all__ = [
    "has_flashinfer",
    "flashinfer_trtllm_fp8_block_scale_moe",
    "flashinfer_trtllm_fp8_per_tensor_scale_moe",
    "flashinfer_cutlass_fused_moe",
    "fp4_quantize",
    "nvfp4_block_scale_interleave",
    "trtllm_fp4_block_scale_moe",
    "autotune",
    "has_flashinfer_moe",
    "has_flashinfer_comm",
    "has_flashinfer_all2all",
    "has_flashinfer_cutlass_fused_moe",
    "has_nvidia_artifactory",
    "supports_trtllm_attention",
    "can_use_trtllm_attention",
    "use_trtllm_attention",
    "flashinfer_disable_q_quantization",
    "flashinfer_scaled_fp4_mm",
    "flashinfer_scaled_fp8_mm",
]