flashinfer.py 16.5 KB
Newer Older
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for FlashInfer API changes.

Users of vLLM should always import **only** these wrappers.
"""
7

8
9
10
11
import contextlib
import functools
import importlib
import importlib.util
12
import os
13
import shutil
14
15
from collections.abc import Callable
from typing import Any, NoReturn
16

17
import requests
18
import torch
19
20

import vllm.envs as envs
21
from vllm.logger import init_logger
22
23
24
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
25
from vllm.platforms import current_platform
26
27
28

logger = init_logger(__name__)

# Root URL of the artifactory that hosts FlashInfer's pre-built cubins.
# Override it through the FLASHINFER_CUBINS_REPOSITORY environment variable,
# e.g. to point at a local path/mirror while testing.
# Referenced from https://github.com/flashinfer-ai/flashinfer/blob/0c9a92c3d9a7e043ab6f3f7b2273269caf6ab044/flashinfer/jit/cubin_loader.py#L35  # noqa: E501
FLASHINFER_CUBINS_REPOSITORY = os.getenv(
    "FLASHINFER_CUBINS_REPOSITORY",
    "https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/",  # noqa: E501
)

37

38
39
40
41
42
43
44
45
46
47
48
@functools.cache
def has_flashinfer_cubin() -> bool:
    """Report whether pre-built FlashInfer cubins are available.

    A truthy ``envs.VLLM_HAS_FLASHINFER_CUBIN`` short-circuits the check;
    otherwise the ``flashinfer_cubin`` package must be discoverable.
    """
    found = bool(envs.VLLM_HAS_FLASHINFER_CUBIN) or (
        importlib.util.find_spec("flashinfer_cubin") is not None
    )
    if not found:
        logger.debug_once("flashinfer-cubin package was not found")
    return found


49
50
@functools.cache
def has_flashinfer() -> bool:
    """Report whether the flashinfer-python package is usable.

    The ``flashinfer`` package must be discoverable. Unless pre-built cubins
    are available (``has_flashinfer_cubin()``), ``nvcc`` must also be on
    ``PATH`` because FlashInfer then needs it to JIT-compile kernels.
    """
    # find_spec checks for the module without importing it, which avoids
    # potential CUDA initialization side effects.
    if importlib.util.find_spec("flashinfer") is None:
        logger.debug_once("FlashInfer unavailable since package was not found")
        return False

    # nvcc is only consulted when no pre-built cubins are present.
    if has_flashinfer_cubin() or shutil.which("nvcc") is not None:
        return True

    logger.debug_once(
        "FlashInfer unavailable since nvcc was not found "
        "and not using pre-downloaded cubins"
    )
    return False
66
67
68
69
70
71
72


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Stand-in used when a FlashInfer function cannot be resolved.

    Accepts and ignores any arguments.

    Raises:
        RuntimeError: always, pointing the user at the FlashInfer project.
    """
    message = (
        "FlashInfer backend is not available. Please install the package "
        "to enable FlashInfer kernels: "
        "https://github.com/flashinfer-ai/flashinfer"
    )
    raise RuntimeError(message)
75
76
77
78
79
80
81
82
83
84
85


def _get_submodule(module_name: str) -> Any | None:
    """Safely import a submodule and return it, or None if not available."""
    try:
        return importlib.import_module(module_name)
    except (ImportError, ModuleNotFoundError):
        return None


# General lazy import wrapper
def _lazy_import_wrapper(
    module_name: str, attr_name: str, fallback_fn: Callable[..., Any] = _missing
):
    """Build a callable that forwards to ``module_name.attr_name`` on demand.

    Creating the wrapper does not import anything; the target is looked up on
    the first call and the lookup result is cached. If FlashInfer is
    unavailable, the module cannot be imported, or it lacks ``attr_name``,
    every call is forwarded to ``fallback_fn`` instead.

    Args:
        module_name: Dotted module path, e.g. ``"flashinfer.fused_moe"``.
        attr_name: Name of the function to fetch from that module.
        fallback_fn: Invoked with the same arguments when the target is
            missing; defaults to ``_missing``, which raises ``RuntimeError``.

    Returns:
        A wrapper accepting arbitrary positional and keyword arguments.
    """

    @functools.cache
    def _resolve():
        if not has_flashinfer():
            return None
        module = _get_submodule(module_name)
        if module is None:
            return None
        return getattr(module, attr_name, None)

    def wrapper(*args, **kwargs):
        target = _resolve()
        chosen = fallback_fn if target is None else target
        return chosen(*args, **kwargs)

    return wrapper


# Lazily-resolved FlashInfer entry points. Each name forwards to the FlashInfer
# function named by its second argument once resolvable; otherwise it calls
# the wrapper's fallback (``_missing`` by default, which raises).

# --- flashinfer.fused_moe ---
flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"
)
flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"
)
flashinfer_cutlass_fused_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "cutlass_fused_moe"
)

# --- flashinfer.cute_dsl ---
flashinfer_cutedsl_grouped_gemm_nt_masked = _lazy_import_wrapper(
    "flashinfer.cute_dsl.blockscaled_gemm", "grouped_gemm_nt_masked"
)

# --- top-level flashinfer ---
flashinfer_fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
nvfp4_batched_quantize = _lazy_import_wrapper("flashinfer", "nvfp4_batched_quantize")
silu_and_mul_scaled_nvfp4_experts_quantize = _lazy_import_wrapper(
    "flashinfer", "silu_and_mul_scaled_nvfp4_experts_quantize"
)
scaled_fp4_grouped_quantize = _lazy_import_wrapper(
    "flashinfer", "scaled_fp4_grouped_quantize"
)
nvfp4_block_scale_interleave = _lazy_import_wrapper(
    "flashinfer", "nvfp4_block_scale_interleave"
)
trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
    "flashinfer", "trtllm_fp4_block_scale_moe"
)

# autotune is used as a context manager, so when it cannot be resolved it
# degrades to a no-op ``contextlib.nullcontext()`` instead of raising.
autotune = _lazy_import_wrapper(
    "flashinfer.autotuner",
    "autotune",
    fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(),
)
141
142


143
144
@functools.cache
def has_flashinfer_comm() -> bool:
    """Report whether FlashInfer is usable and ``flashinfer.comm`` can be located."""
    if not has_flashinfer():
        return False
    return importlib.util.find_spec("flashinfer.comm") is not None
147
148
149
150


@functools.cache
def has_flashinfer_all2all() -> bool:
    """Report whether FlashInfer's MNNVL all-to-all primitives are importable.

    Requires ``has_flashinfer_comm()`` and every (module, attribute) pair in
    ``needed`` to be importable/present; stops at the first missing one.
    """
    if not has_flashinfer_comm():
        return False

    needed = (
        ("flashinfer.comm", "Mapping"),
        ("flashinfer.comm.mnnvl", "MnnvlMemory"),
        ("flashinfer.comm.trtllm_alltoall", "MnnvlMoe"),
        ("flashinfer.comm.trtllm_alltoall", "MoEAlltoallInfo"),
    )
    return all(
        (module := _get_submodule(mod_name)) is not None and hasattr(module, attr)
        for mod_name, attr in needed
    )


170
171
@functools.cache
def has_flashinfer_moe() -> bool:
    """Report whether FlashInfer is usable and ``flashinfer.fused_moe`` can be located."""
    if not has_flashinfer():
        return False
    return importlib.util.find_spec("flashinfer.fused_moe") is not None
177
178


179
180
181
182
183
184
185
186
@functools.cache
def has_flashinfer_cutedsl() -> bool:
    """Report whether FlashInfer is usable and ``flashinfer.cute_dsl`` can be located."""
    if not has_flashinfer():
        return False
    return importlib.util.find_spec("flashinfer.cute_dsl") is not None


187
188
@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
    """Report whether FlashInfer's CUTLASS fused MoE path is available.

    Requires ``has_flashinfer_moe()`` and every (module, attribute) pair in
    ``needed``; stops at the first missing one.
    """
    if not has_flashinfer_moe():
        return False

    needed = (
        ("flashinfer.fused_moe", "cutlass_fused_moe"),
        ("flashinfer", "fp4_quantize"),
        ("flashinfer", "nvfp4_block_scale_interleave"),
        ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
    )
    return all(
        (module := _get_submodule(mod_name)) is not None and hasattr(module, attr)
        for mod_name, attr in needed
    )


208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
@functools.cache
def has_flashinfer_cutedsl_grouped_gemm_nt_masked() -> bool:
    """Return ``True`` if FlashInfer's CuteDSL masked grouped GEMM is available.

    Requires ``has_flashinfer_cutedsl()`` plus the FlashInfer functions behind
    this module's ``flashinfer_cutedsl_grouped_gemm_nt_masked``,
    ``scaled_fp4_grouped_quantize`` and
    ``silu_and_mul_scaled_nvfp4_experts_quantize`` wrappers.
    """
    if not has_flashinfer_cutedsl():
        return False

    # Check if all required functions are available. The attribute names must
    # match the ones the lazy wrappers resolve; a misspelled name here makes
    # this check report False even when the real function exists.
    required_functions = [
        ("flashinfer.cute_dsl.blockscaled_gemm", "grouped_gemm_nt_masked"),
        ("flashinfer", "scaled_fp4_grouped_quantize"),
        ("flashinfer", "silu_and_mul_scaled_nvfp4_experts_quantize"),
    ]

    for module_name, attr_name in required_functions:
        mod = _get_submodule(module_name)
        if mod is None or not hasattr(mod, attr_name):
            return False
    return True


228
229
@functools.cache
def has_nvidia_artifactory() -> bool:
    """Report whether NVIDIA's kernel inference library artifactory is reachable.

    The artifactory is needed to download certain cubin kernels such as
    TRTLLM FMHA. When pre-built cubins are installed, no network request is
    made and the artifactory is treated as available.

    Returns:
        ``True`` with pre-built cubins or when ``FLASHINFER_CUBINS_REPOSITORY``
        answers with HTTP 200; ``False`` for any other status code or if any
        exception is raised while probing.
    """
    if has_flashinfer_cubin():
        return True

    try:
        # Short timeout to avoid blocking for too long on an unreachable host.
        response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5)
        if response.status_code != 200:
            logger.warning_once(
                "NVIDIA artifactory returned failed status code: %d",
                response.status_code,
            )
            return False
        logger.debug_once("NVIDIA artifactory is accessible")
        return True
    except Exception as exc:
        # Best-effort probe: any failure just means "not accessible".
        logger.warning_once("Failed to connect to NVIDIA artifactory: %s", exc)
        return False


256
@functools.cache
def supports_trtllm_attention() -> bool:
    """Report whether TRTLLM attention can run on this platform.

    All must hold: batch-invariant mode is off, the device is SM100
    (``current_platform.is_device_capability(100)``), and NVIDIA's
    artifactory is accessible so the required cubins can be downloaded.
    """
    if vllm_is_batch_invariant():
        # Batch-invariant mode disables TRTLLM attention.
        return False
    return current_platform.is_device_capability(100) and has_nvidia_artifactory()
268

269

270
def force_use_trtllm_attention() -> bool | None:
    """Return the ``--attention-config.use_trtllm_attention`` override.

    Call this only during initialization, once the vLLM config is set.

    Returns:
        ``None`` if the option is not set, ``True`` if TRTLLM attention is
        forced on, ``False`` if it is forced off.
    """
    # Function-level import, presumably to avoid an import cycle with
    # vllm.config — TODO confirm.
    from vllm.config import get_current_vllm_config

    return get_current_vllm_config().attention_config.use_trtllm_attention
282
283


284
285
def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
    """Check whether TRTLLM attention may be used for this head configuration.

    Returns ``False`` when TRTLLM attention is forced off, when the platform
    does not support it, or when ``num_qo_heads`` is not a multiple of
    ``num_kv_heads``.
    """
    if force_use_trtllm_attention() is False:
        return False
    if not supports_trtllm_attention():
        return False
    return num_qo_heads % num_kv_heads == 0


292
def use_trtllm_attention(
    num_qo_heads: int,
    num_kv_heads: int,
    num_tokens: int,
    max_seq_len: int,
    dcp_world_size: int,
    kv_cache_dtype: str,
    q_dtype: torch.dtype,
    is_prefill: bool,
    # None means auto-detection, True means force on, False means force off
    force_use_trtllm: bool | None = None,
    has_sinks: bool = False,
    has_spec: bool = False,
) -> bool:
    """Decide whether TRTLLM attention is used.

    Rules, applied in order:

    1. A non-``None`` falsy ``force_use_trtllm`` (CLI set to 0) disables it.
    2. ``dcp_world_size > 1`` disables it (TRTLLM cannot return the LSE).
    3. An unsupported platform, or ``num_qo_heads`` not divisible by
       ``num_kv_heads``, disables it; a warning is logged if it was forced on.
    4. Speculative decoding during decode, an FP8 query, or attention sinks
       require it.
    5. If ``force_use_trtllm`` is set (truthy), it is used. Otherwise it is
       auto-detected: requires ``kv_cache_dtype == "auto"``, and for decode
       additionally ``num_tokens <= 256``.

    ``max_seq_len`` is accepted but not consulted.
    """
    # Rule 1: explicit opt-out always wins.
    if not force_use_trtllm and force_use_trtllm is not None:
        return False

    # Rule 2: decode context parallelism needs the LSE.
    if dcp_world_size > 1:
        logger.warning_once(
            "Trtllm does not support returning LSE and as a result "
            "does not support DCP, reverting to FlashInfer"
        )
        return False

    # Rule 3: hard requirements. Only complain if the user forced it on.
    unsupported_msg = None
    if not supports_trtllm_attention():
        unsupported_msg = (
            "TRTLLM attention is not supported on this platform, "
            "but --attention-config.use_trtllm_attention is set to 1"
        )
    elif num_qo_heads % num_kv_heads != 0:
        unsupported_msg = (
            "TRTLLM attention is not supported for this combination of "
            "query and key heads, but --attention-config.use_trtllm_attention is "
            "set to 1"
        )
    if unsupported_msg is not None:
        if force_use_trtllm:
            logger.warning_once(unsupported_msg)
        return False

    # Rule 4: features that only TRTLLM attention provides.
    if has_spec and not is_prefill:
        logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
        return True
    if q_dtype == current_platform.fp8_dtype():
        logger.info_once("Using TRTLLM attention (query is quantized).")
        return True
    if has_sinks:
        logger.info_once("Using TRTLLM attention (required for attention sinks).")
        return True

    # Rule 5a: forced on by the user (CLI argument set to 1).
    if force_use_trtllm is not None:
        logger.info_once(
            "Using TRTLLM attention (--attention-config.use_trtllm_attention is set to 1)"
        )
        return True

    # Rule 5b: auto-detection.
    use_trtllm = kv_cache_dtype == "auto" and (is_prefill or num_tokens <= 256)
    if use_trtllm:
        logger.warning_once(
            "Using TRTLLM prefill attention (auto-detected)."
            if is_prefill
            else "Using TRTLLM decode attention (auto-detected)."
        )
    return use_trtllm

375

376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
# The custom ops below are only registered when FlashInfer is usable; without
# it, ``flashinfer_mm_fp4`` and ``bmm_fp8`` are not defined in this module.
if has_flashinfer():

    @torch.library.custom_op(
        "vllm::flashinfer_mm_fp4",
        mutates_args=[],
        device_types="cuda",
    )
    def flashinfer_mm_fp4(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        # Forwards to FlashInfer's mm_fp4 with a fixed block size of 16.
        from flashinfer import mm_fp4 as _fi_mm_fp4

        return _fi_mm_fp4(
            A, B, A_scale, B_scale, g_scale, dtype, block_size=16, backend=backend
        )

    @torch.library.register_fake("vllm::flashinfer_mm_fp4")
    def flashinfer_mm_fp4_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        # Shape/dtype-only implementation used during tracing.
        rows, cols = A.shape[0], B.shape[1]
        return torch.empty(rows, cols, dtype=dtype, device=A.device)

    @torch.library.custom_op(
        "vllm::bmm_fp8",
        mutates_args=[],
        device_types="cuda",
    )
    def bmm_fp8(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        from flashinfer import bmm_fp8 as _fi_bmm_fp8

        # The positional None presumably fills FlashInfer's optional output
        # buffer argument — TODO confirm against the FlashInfer signature.
        return _fi_bmm_fp8(A, B, A_scale, B_scale, dtype, None, backend)

    @torch.library.register_fake("vllm::bmm_fp8")
    def bmm_fp8_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        # Shape/dtype-only implementation used during tracing.
        batch, rows, cols = A.shape[0], A.shape[1], B.shape[2]
        return torch.empty(batch, rows, cols, dtype=dtype, device=A.device)


def flashinfer_scaled_fp4_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    alpha: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str,
) -> torch.Tensor:
    """Run FlashInfer's FP4 matmul on ``a`` and the transpose of ``b``.

    ``b`` and ``block_scale_b`` are transposed before the call, so the result
    has ``a.shape[0]`` rows and ``b.shape[0]`` columns. ``alpha`` is passed
    through as the global scale.

    Raises:
        AssertionError: if any input is not 2-D, ``a``/``b`` are not
            contiguous in their last dimension, or their last dimensions differ.
    """
    assert all(t.ndim == 2 for t in (a, b, block_scale_a, block_scale_b))
    assert all(t.stride(-1) == 1 for t in (a, b))
    assert a.shape[1] == b.shape[1]

    if backend == "cutlass":
        # The CUTLASS backend receives the block scales reinterpreted as bytes.
        block_scale_a, block_scale_b = (
            block_scale_a.view(torch.uint8),
            block_scale_b.view(torch.uint8),
        )

    b_t = b.t()
    scale_b_t = block_scale_b.t()
    return flashinfer_mm_fp4(
        a, b_t, block_scale_a, scale_b_t, alpha, out_dtype, backend=backend
    )


474
def flashinfer_scaled_fp8_mm(
475
476
477
478
479
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
480
    bias: torch.Tensor | None = None,
481
) -> torch.Tensor:
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
    assert a.ndim == 2 and b.ndim == 2
    assert a.shape[1] == b.shape[0]
    assert scale_a.numel() == 1 and scale_b.numel() == 1
    assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn
    assert a.device.type == "cuda" and b.device.type == "cuda"
    assert scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32
    assert scale_a.device.type == "cuda" and scale_b.device.type == "cuda"

    output = bmm_fp8(
        a.unsqueeze(0),
        b.unsqueeze(0),
        scale_a,
        scale_b,
        out_dtype,
        "auto",
    ).view(a.shape[0], b.shape[1])

    if bias is not None:
        output = output + bias
    return output


504
505
# Public API of this wrapper module. The custom ops ``flashinfer_mm_fp4`` and
# ``bmm_fp8`` are intentionally omitted: they only exist when
# ``has_flashinfer()`` was true at import time, and listing them would break
# ``from ... import *`` otherwise.
__all__ = [
    "has_flashinfer",
    "has_flashinfer_cubin",
    "flashinfer_trtllm_fp8_block_scale_moe",
    "flashinfer_trtllm_fp8_per_tensor_scale_moe",
    "flashinfer_cutlass_fused_moe",
    "flashinfer_cutedsl_grouped_gemm_nt_masked",
    "flashinfer_fp4_quantize",
    "nvfp4_batched_quantize",
    "silu_and_mul_scaled_nvfp4_experts_quantize",
    "scaled_fp4_grouped_quantize",
    "nvfp4_block_scale_interleave",
    "trtllm_fp4_block_scale_moe",
    "autotune",
    "has_flashinfer_moe",
    "has_flashinfer_comm",
    "has_flashinfer_all2all",
    "has_flashinfer_cutedsl",
    "has_flashinfer_cutlass_fused_moe",
    "has_flashinfer_cutedsl_grouped_gemm_nt_masked",
    "has_nvidia_artifactory",
    "supports_trtllm_attention",
    "force_use_trtllm_attention",
    "can_use_trtllm_attention",
    "use_trtllm_attention",
    "flashinfer_scaled_fp4_mm",
    "flashinfer_scaled_fp8_mm",
]