flashinfer.py 15 KB
Newer Older
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for FlashInfer API changes.

Users of vLLM should always import **only** these wrappers.
"""
7

8
9
10
11
import contextlib
import functools
import importlib
import importlib.util
12
import os
13
import shutil
14
15
from collections.abc import Callable
from typing import Any, NoReturn
16

17
import requests
18
import torch
19
20

import vllm.envs as envs
21
from vllm.logger import init_logger
22
23
24
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
25
from vllm.platforms import current_platform
26
27
28

logger = init_logger(__name__)

# Location of the FlashInfer cubin storage. It can be overridden through the
# FLASHINFER_CUBINS_REPOSITORY environment variable, e.g. to point at a local
# path for testing.
# Referenced from https://github.com/flashinfer-ai/flashinfer/blob/0c9a92c3d9a7e043ab6f3f7b2273269caf6ab044/flashinfer/jit/cubin_loader.py#L35  # noqa: E501
FLASHINFER_CUBINS_REPOSITORY = os.getenv(
    "FLASHINFER_CUBINS_REPOSITORY",
    "https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/",  # noqa: E501
)

37

38
39
40
41
42
43
44
45
46
47
48
@functools.cache
def has_flashinfer_cubin() -> bool:
    """Report whether pre-built FlashInfer cubins can be used.

    Either the ``envs.VLLM_HAS_FLASHINFER_CUBIN`` flag or a discoverable
    ``flashinfer_cubin`` package is sufficient. The flag is checked first,
    so the package lookup is skipped when it is set. Cached per process.
    """
    flag_set = envs.VLLM_HAS_FLASHINFER_CUBIN
    if flag_set or importlib.util.find_spec("flashinfer_cubin") is not None:
        return True
    logger.debug_once("flashinfer-cubin package was not found")
    return False


49
50
@functools.cache
def has_flashinfer() -> bool:
    """Return `True` if flashinfer-python package is available.

    The ``flashinfer`` package must be discoverable. In addition, unless
    pre-built cubins are available (`has_flashinfer_cubin`), ``nvcc`` must be
    on ``PATH`` because it is required to JIT compile FlashInfer. The result
    is cached for the lifetime of the process.
    """
    # Use find_spec to check if the module exists without importing it.
    # This avoids potential CUDA initialization side effects.
    if importlib.util.find_spec("flashinfer") is None:
        logger.debug_once("FlashInfer unavailable since package was not found")
        return False
    # When not using flashinfer cubins, nvcc is required to JIT compile
    # flashinfer.
    if not has_flashinfer_cubin() and shutil.which("nvcc") is None:
        logger.debug_once(
            "FlashInfer unavailable since nvcc was not found "
            "and not using pre-downloaded cubins"
        )
        return False
    return True
66
67
68
69
70
71
72


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Placeholder for unavailable FlashInfer backend."""
    raise RuntimeError(
        "FlashInfer backend is not available. Please install the package "
        "to enable FlashInfer kernels: "
73
74
        "https://github.com/flashinfer-ai/flashinfer"
    )
75
76
77
78
79
80
81
82
83
84
85


def _get_submodule(module_name: str) -> Any | None:
    """Safely import a submodule and return it, or None if not available."""
    try:
        return importlib.import_module(module_name)
    except (ImportError, ModuleNotFoundError):
        return None


# General lazy import wrapper
86
87
88
def _lazy_import_wrapper(
    module_name: str, attr_name: str, fallback_fn: Callable[..., Any] = _missing
) -> Callable[..., Any]:
    """Create a lazy import wrapper for a specific function.

    The returned callable resolves ``module_name.attr_name`` on its first
    call and forwards all arguments to it. The resolution result is cached,
    including a failed resolution. If FlashInfer is unavailable, the module
    cannot be imported, or the attribute is missing, ``fallback_fn`` is
    called with the same arguments instead.

    Args:
        module_name: Fully qualified module to import, e.g.
            ``"flashinfer.fused_moe"``.
        attr_name: Name of the callable inside ``module_name``.
        fallback_fn: Called when the implementation cannot be resolved.
            Defaults to `_missing`, which raises `RuntimeError`.
    """

    @functools.cache
    def _get_impl() -> Callable[..., Any] | None:
        if not has_flashinfer():
            return None
        mod = _get_submodule(module_name)
        return getattr(mod, attr_name, None) if mod else None

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        impl = _get_impl()
        if impl is None:
            return fallback_fn(*args, **kwargs)
        return impl(*args, **kwargs)

    return wrapper


# Create lazy wrappers for each function
flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"
)
flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"
)
flashinfer_cutlass_fused_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "cutlass_fused_moe"
)
flashinfer_fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
nvfp4_block_scale_interleave = _lazy_import_wrapper(
    "flashinfer", "nvfp4_block_scale_interleave"
)
# Resolved from flashinfer.fused_moe: the same module that
# has_flashinfer_cutlass_fused_moe() probes for this symbol, so a passing
# availability check guarantees this wrapper resolves.
trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"
)

# Special case for autotune since it returns a context manager: without
# FlashInfer it degrades to a no-op context manager instead of raising.
autotune = _lazy_import_wrapper(
    "flashinfer.autotuner",
    "autotune",
    fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(),
)
131
132


133
134
@functools.cache
def has_flashinfer_comm() -> bool:
    """Return `True` if FlashInfer comm module is available.

    Requires `has_flashinfer()` and a discoverable ``flashinfer.comm``
    submodule. Cached for the lifetime of the process.
    """
    return has_flashinfer() and importlib.util.find_spec("flashinfer.comm") is not None
137
138
139
140


@functools.cache
def has_flashinfer_all2all() -> bool:
    """Return `True` if FlashInfer mnnvl all2all is available.

    Requires `has_flashinfer_comm()` plus every (module, attribute) pair
    listed below to be importable. Cached for the lifetime of the process.
    """
    if not has_flashinfer_comm():
        return False

    # Check if all required functions are available
    required_functions = [
        ("flashinfer.comm", "Mapping"),
        ("flashinfer.comm.mnnvl", "MnnvlMemory"),
        ("flashinfer.comm.trtllm_alltoall", "MnnvlMoe"),
        ("flashinfer.comm.trtllm_alltoall", "MoEAlltoallInfo"),
    ]

    for module_name, attr_name in required_functions:
        mod = _get_submodule(module_name)
        if not mod or not hasattr(mod, attr_name):
            return False
    return True


160
161
@functools.cache
def has_flashinfer_moe() -> bool:
    """Return `True` if FlashInfer MoE module is available.

    Requires `has_flashinfer()` and a discoverable ``flashinfer.fused_moe``
    submodule. Cached for the lifetime of the process.
    """
    return (
        has_flashinfer()
        and importlib.util.find_spec("flashinfer.fused_moe") is not None
    )
167
168


169
170
@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
    """Return `True` if FlashInfer CUTLASS fused MoE is available.

    Requires `has_flashinfer_moe()` plus every (module, attribute) pair
    listed below to be importable. Cached for the lifetime of the process.
    """
    if not has_flashinfer_moe():
        return False

    # Check if all required functions are available
    required_functions = [
        ("flashinfer.fused_moe", "cutlass_fused_moe"),
        ("flashinfer", "fp4_quantize"),
        ("flashinfer", "nvfp4_block_scale_interleave"),
        ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
    ]

    for module_name, attr_name in required_functions:
        mod = _get_submodule(module_name)
        if not mod or not hasattr(mod, attr_name):
            return False
    return True


190
191
@functools.cache
def has_nvidia_artifactory() -> bool:
    """Return `True` if NVIDIA's artifactory is accessible.

    This checks connectivity to the kernel inference library artifactory
    (``FLASHINFER_CUBINS_REPOSITORY``), which is required for downloading
    certain cubin kernels like TRTLLM FMHA. When pre-downloaded cubins are
    available, no network request is made. Because the function is cached,
    the probe runs at most once per process.

    Failures are non-fatal: a non-200 status or any exception raised by the
    request is logged as a warning and reported as ``False``.
    """
    # If we have pre-downloaded cubins, we can assume the cubins are available.
    if has_flashinfer_cubin():
        return True

    try:
        # Use a short timeout to avoid blocking for too long
        response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5)
    except Exception as e:
        # Deliberately broad: this is a best-effort connectivity probe.
        logger.warning_once("Failed to connect to NVIDIA artifactory: %s", e)
        return False

    accessible = response.status_code == 200
    if accessible:
        logger.debug_once("NVIDIA artifactory is accessible")
    else:
        logger.warning_once(
            "NVIDIA artifactory returned failed status code: %d",
            response.status_code,
        )
    return accessible


218
@functools.cache
def supports_trtllm_attention() -> bool:
    """Return `True` if TRTLLM attention is supported.

    TRTLLM attention is supported if the platform is SM100, NVIDIA
    artifactory is accessible, and batch-invariant mode is not enabled.
    The result is cached, so these conditions are evaluated once per process.
    """
    # Batch-invariant mode disables TRTLLM attention
    if vllm_is_batch_invariant():
        return False

    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
    return current_platform.is_device_capability(100) and has_nvidia_artifactory()
230

231
232

@functools.cache
def _force_use_trtllm_attention(env_value: bool | None) -> bool | None:
    """Cache the env value for VLLM_USE_TRTLLM_ATTENTION.

    Memoized on ``env_value``, so the info log is emitted at most once per
    distinct value. Returns ``env_value`` unchanged.
    """
    if env_value is not None:
        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
    return env_value
238

239

240
def force_use_trtllm_attention() -> bool | None:
    """Return the user override for TRTLLM attention.

    Return `None` if VLLM_USE_TRTLLM_ATTENTION is not set,
    return `True` if TRTLLM attention is forced to be used,
    return `False` if TRTLLM attention is forced to be not used.
    """
    return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
247
248


249
250
def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
    """Check if the current configuration supports TRTLLM attention.

    Args:
        num_qo_heads: Number of query heads.
        num_kv_heads: Number of KV heads; ``num_qo_heads`` must be a multiple
            of it.

    Returns:
        ``False`` if VLLM_USE_TRTLLM_ATTENTION explicitly disables TRTLLM.
        Otherwise, whether `supports_trtllm_attention()` holds and the head
        counts divide evenly.
    """
    if force_use_trtllm_attention() is False:
        return False
    has_trtllm = supports_trtllm_attention()
    return has_trtllm and (num_qo_heads % num_kv_heads == 0)


257
def use_trtllm_attention(
    num_qo_heads: int,
    num_kv_heads: int,
    num_tokens: int,
    max_seq_len: int,
    kv_cache_dtype: str,
    q_dtype: torch.dtype,
    is_prefill: bool,
    has_sinks: bool = False,
    has_spec: bool = False,
) -> bool:
    """Return `True` if TRTLLM attention is used.

    Decision order:
      1. VLLM_USE_TRTLLM_ATTENTION explicitly set to a false value disables
         TRTLLM.
      2. An unsupported platform, or ``num_qo_heads % num_kv_heads != 0``,
         disables it (with a warning if it was forced on).
      3. Speculative decoding during decode, an FP8 query dtype, or attention
         sinks require TRTLLM.
      4. If the env var is unset, auto-detect. Prefill uses TRTLLM when
         ``max_seq_len <= 131072`` and ``kv_cache_dtype == "auto"``. Decode
         additionally requires ``num_tokens <= 256``.
      5. Otherwise the env var is set to a true value, so TRTLLM is used.
    """
    force_use_trtllm = force_use_trtllm_attention()

    # Environment variable is set to 0 - respect it
    if force_use_trtllm is not None and not force_use_trtllm:
        return False

    # The platform is not supported
    if not supports_trtllm_attention():
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported on this platform, "
                "but VLLM_USE_TRTLLM_ATTENTION is set to 1"
            )
        return False

    # The combination of query and key heads is not supported
    if num_qo_heads % num_kv_heads != 0:
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported for this combination of "
                "query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
            )
        return False

    if has_spec and not is_prefill:
        # Speculative decoding requires TRTLLM attention for decodes
        logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
        return True

    # Must use TRTLLM attention if query is FP8 quantized
    if q_dtype == current_platform.fp8_dtype():
        logger.info_once("Using TRTLLM attention (query is quantized).")
        return True

    # If sinks are being used, we must use TRTLLM attention as it's
    # the only backend that supports them
    if has_sinks:
        logger.info_once("Using TRTLLM attention (required for attention sinks).")
        return True

    if force_use_trtllm is None:
        # Environment variable not set - use auto-detection
        if is_prefill:
            # Prefill auto-detection
            use_trtllm = max_seq_len <= 131072 and kv_cache_dtype == "auto"
            if use_trtllm:
                logger.warning_once("Using TRTLLM prefill attention (auto-detected).")
        else:
            # Decode auto-detection
            use_trtllm = (
                num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto"
            )
            if use_trtllm:
                logger.warning_once("Using TRTLLM decode attention (auto-detected).")
        return use_trtllm

    # Environment variable is set to 1 - respect it
    logger.info_once("Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
    return True

329

330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
# The custom ops below are only defined/registered when FlashInfer is
# available.
if has_flashinfer():

    @torch.library.custom_op(
        "vllm::flashinfer_mm_fp4",
        mutates_args=[],
        device_types="cuda",
    )
    def flashinfer_mm_fp4(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        """Run FlashInfer's ``mm_fp4`` as the ``vllm::flashinfer_mm_fp4`` op.

        Always passes ``block_size=16``. ``flashinfer.mm_fp4`` is imported
        at call time.
        """
        from flashinfer import mm_fp4 as flashinfer_mm_fp4_

        return flashinfer_mm_fp4_(
            A, B, A_scale, B_scale, g_scale, dtype, block_size=16, backend=backend
        )

    @torch.library.register_fake(
        "vllm::flashinfer_mm_fp4",
    )
    def flashinfer_mm_fp4_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        """Fake implementation: an uninitialized ``(A.shape[0], B.shape[1])``
        tensor of ``dtype`` on ``A.device``."""
        return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device)

    @torch.library.custom_op(
        "vllm::bmm_fp8",
        mutates_args=[],
        device_types="cuda",
    )
    def bmm_fp8(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        """Run FlashInfer's ``bmm_fp8`` as the ``vllm::bmm_fp8`` op.

        The sixth positional argument passed to FlashInfer is ``None``.
        NOTE(review): presumably this is the optional preallocated output;
        confirm against FlashInfer's ``bmm_fp8`` signature.
        """
        from flashinfer import bmm_fp8 as bmm_fp8_

        return bmm_fp8_(A, B, A_scale, B_scale, dtype, None, backend)

    @torch.library.register_fake(
        "vllm::bmm_fp8",
    )
    def bmm_fp8_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        """Fake implementation: an uninitialized
        ``(A.shape[0], A.shape[1], B.shape[2])`` tensor of ``dtype`` on
        ``A.device``."""
        return torch.empty(
            A.shape[0], A.shape[1], B.shape[2], dtype=dtype, device=A.device
        )


def flashinfer_scaled_fp4_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    alpha: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str,
) -> torch.Tensor:
    """Compute a block-scaled FP4 matmul of ``a`` and ``b`` via FlashInfer.

    ``a``, ``b`` and both block scales must be 2-D. ``a`` and ``b`` must be
    contiguous in their last dimension and share their second dimension.
    ``b`` and ``block_scale_b`` are transposed before calling
    `flashinfer_mm_fp4`, so the output is presumably
    ``(a.shape[0], b.shape[0])`` (TODO confirm). For the ``"cutlass"``
    backend the block scales are reinterpreted as ``torch.uint8``.

    Requires FlashInfer: `flashinfer_mm_fp4` is only defined when
    `has_flashinfer()` is true.
    """
    assert a.ndim == 2 and b.ndim == 2
    assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2
    assert a.stride(-1) == 1 and b.stride(-1) == 1
    assert a.shape[1] == b.shape[1]

    if backend == "cutlass":
        block_scale_a = block_scale_a.view(torch.uint8)
        block_scale_b = block_scale_b.view(torch.uint8)

    return flashinfer_mm_fp4(
        a,
        b.t(),
        block_scale_a,
        block_scale_b.t(),
        alpha,
        out_dtype,
        backend=backend,
    )


428
def flashinfer_scaled_fp8_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute ``a @ b`` for FP8 (e4m3) inputs with per-tensor scales.

    ``a`` is ``(M, K)`` and ``b`` is ``(K, N)``. Both must be
    ``torch.float8_e4m3fn`` CUDA tensors. ``scale_a`` and ``scale_b`` must be
    single-element float32 CUDA tensors. The product runs through the
    `bmm_fp8` custom op with a batch of one and backend ``"auto"``, then is
    viewed back to ``(M, N)``. ``bias``, if given, is added afterwards.

    Requires FlashInfer: `bmm_fp8` is only defined when `has_flashinfer()`
    is true.
    """
    assert a.ndim == 2 and b.ndim == 2
    assert a.shape[1] == b.shape[0]
    assert scale_a.numel() == 1 and scale_b.numel() == 1
    assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn
    assert a.device.type == "cuda" and b.device.type == "cuda"
    assert scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32
    assert scale_a.device.type == "cuda" and scale_b.device.type == "cuda"

    # Add a leading batch dimension of 1 for the batched kernel, then drop it.
    output = bmm_fp8(
        a.unsqueeze(0),
        b.unsqueeze(0),
        scale_a,
        scale_b,
        out_dtype,
        "auto",
    ).view(a.shape[0], b.shape[1])

    if bias is not None:
        output = output + bias
    return output


458
459
460
461
462
463
@functools.cache
def flashinfer_disable_q_quantization() -> bool:
    """Return the ``VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION`` setting.

    The value is read from ``vllm.envs`` on the first call and memoized,
    since it depends only on the environment.
    """
    disabled = envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
    return disabled


464
465
# Public API of this module. Keep in sync with the public (non-underscore)
# definitions above.
__all__ = [
    "has_flashinfer",
    "has_flashinfer_cubin",
    "flashinfer_trtllm_fp8_block_scale_moe",
    "flashinfer_trtllm_fp8_per_tensor_scale_moe",
    "flashinfer_cutlass_fused_moe",
    "flashinfer_fp4_quantize",
    "nvfp4_block_scale_interleave",
    "trtllm_fp4_block_scale_moe",
    "autotune",
    "has_flashinfer_moe",
    "has_flashinfer_comm",
    "has_flashinfer_all2all",
    "has_flashinfer_cutlass_fused_moe",
    "has_nvidia_artifactory",
    "supports_trtllm_attention",
    "can_use_trtllm_attention",
    "force_use_trtllm_attention",
    "use_trtllm_attention",
    "flashinfer_disable_q_quantization",
    "flashinfer_scaled_fp4_mm",
    "flashinfer_scaled_fp8_mm",
]