flashinfer.py 14.5 KB
Newer Older
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for FlashInfer API changes.

Users of vLLM should always import **only** these wrappers.
"""
7

8
9
10
11
import contextlib
import functools
import importlib
import importlib.util
12
import os
13
import shutil
14
15
from collections.abc import Callable
from typing import Any, NoReturn
16

17
import requests
18
import torch
19
20

import vllm.envs as envs
21
from vllm.logger import init_logger
22
23
24
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
25
from vllm.platforms import current_platform
26
27
28

# Module-level logger shared by every helper in this file.
logger = init_logger(__name__)

# Storage location of the FlashInfer cubins. Set the
# FLASHINFER_CUBINS_REPOSITORY environment variable to override it, e.g. with
# a local path for testing.
# Referenced from https://github.com/flashinfer-ai/flashinfer/blob/0c9a92c3d9a7e043ab6f3f7b2273269caf6ab044/flashinfer/jit/cubin_loader.py#L35  # noqa: E501
_DEFAULT_FLASHINFER_CUBINS_REPOSITORY = "https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/"  # noqa: E501
FLASHINFER_CUBINS_REPOSITORY = os.environ.get(
    "FLASHINFER_CUBINS_REPOSITORY", _DEFAULT_FLASHINFER_CUBINS_REPOSITORY
)

37
38
39

@functools.cache
def has_flashinfer() -> bool:
    """Return `True` if FlashInfer is usable.

    FlashInfer is considered usable when the ``flashinfer`` package can be
    found and ``nvcc`` is on ``PATH`` (it is needed to JIT compile FlashInfer
    kernels). The result is cached for the lifetime of the process.
    """
    # Use find_spec to check if the module exists without importing it.
    # This avoids potential CUDA initialization side effects.
    if importlib.util.find_spec("flashinfer") is None:
        logger.debug_once("FlashInfer unavailable since package was not found")
        return False
    # Also check if nvcc is available since it's required to JIT compile flashinfer
    if shutil.which("nvcc") is None:
        logger.debug_once("FlashInfer unavailable since nvcc was not found")
        return False
    return True
51
52
53
54
55
56
57


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Placeholder for an unavailable FlashInfer function.

    Accepts and ignores any positional/keyword arguments so it can stand in
    for any wrapped FlashInfer API.

    Raises:
        RuntimeError: Always, pointing the user at the FlashInfer project.
    """
    raise RuntimeError(
        "FlashInfer backend is not available. Please install the package "
        "to enable FlashInfer kernels: "
        "https://github.com/flashinfer-ai/flashinfer"
    )
60
61
62
63
64
65
66
67
68
69
70


def _get_submodule(module_name: str) -> Any | None:
    """Safely import a submodule and return it, or None if not available."""
    try:
        return importlib.import_module(module_name)
    except (ImportError, ModuleNotFoundError):
        return None


# General lazy import wrapper
def _lazy_import_wrapper(
    module_name: str, attr_name: str, fallback_fn: Callable[..., Any] = _missing
) -> Callable[..., Any]:
    """Create a lazy import wrapper for a specific FlashInfer function.

    The returned callable resolves ``module_name.attr_name`` on its first call
    (the lookup result is cached) and forwards all arguments to it. If
    FlashInfer is unavailable, the module cannot be imported, or the attribute
    does not exist, the call is forwarded to ``fallback_fn`` instead.

    Args:
        module_name: Dotted module path, e.g. ``"flashinfer.fused_moe"``.
        attr_name: Attribute to fetch from that module.
        fallback_fn: Called with the same arguments when the real
            implementation is unavailable. Defaults to ``_missing``, which
            raises ``RuntimeError``.

    Returns:
        A callable that forwards to the resolved implementation or fallback.
    """

    @functools.cache
    def _get_impl() -> Callable[..., Any] | None:
        # Deferred to the first call so creating the wrapper never imports
        # the target module.
        if not has_flashinfer():
            return None
        mod = _get_submodule(module_name)
        return getattr(mod, attr_name, None) if mod else None

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        impl = _get_impl()
        if impl is None:
            return fallback_fn(*args, **kwargs)
        return impl(*args, **kwargs)

    # Make the wrapper identifiable in tracebacks and reprs instead of showing
    # the generic name "wrapper".
    wrapper.__name__ = attr_name
    wrapper.__qualname__ = attr_name
    wrapper.__doc__ = f"Lazy wrapper around ``{module_name}.{attr_name}``."
    return wrapper


# Create lazy wrappers for each function. Each one resolves its FlashInfer
# implementation on first call and raises RuntimeError (via _missing) if it is
# unavailable.
flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"
)
flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"
)
flashinfer_cutlass_fused_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "cutlass_fused_moe"
)
flashinfer_fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
nvfp4_block_scale_interleave = _lazy_import_wrapper(
    "flashinfer", "nvfp4_block_scale_interleave"
)
trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
    "flashinfer", "trtllm_fp4_block_scale_moe"
)

# Special case for autotune since it returns a context manager: when FlashInfer
# is unavailable, fall back to a no-op context manager instead of raising.
autotune = _lazy_import_wrapper(
    "flashinfer.autotuner",
    "autotune",
    fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(),
)
116
117


118
119
@functools.cache
def has_flashinfer_comm() -> bool:
    """Return `True` if the FlashInfer comm module is available.

    Requires `has_flashinfer()` and a discoverable ``flashinfer.comm``
    submodule. The result is cached for the lifetime of the process.
    """
    if not has_flashinfer():
        return False
    # NOTE: find_spec on a dotted name imports the parent package
    # (``flashinfer``), so it is only reached once has_flashinfer() has
    # confirmed that package exists.
    return importlib.util.find_spec("flashinfer.comm") is not None
122
123
124
125


@functools.cache
def has_flashinfer_all2all() -> bool:
    """Return `True` if FlashInfer MNNVL all2all is available.

    Requires `has_flashinfer_comm()` and that every (module, attribute) pair
    listed below can be imported. The result is cached for the lifetime of the
    process.
    """
    if not has_flashinfer_comm():
        return False

    # Check if all required symbols are available.
    required_symbols = (
        ("flashinfer.comm", "Mapping"),
        ("flashinfer.comm.mnnvl", "MnnvlMemory"),
        ("flashinfer.comm.trtllm_alltoall", "MnnvlMoe"),
        ("flashinfer.comm.trtllm_alltoall", "MoEAlltoallInfo"),
    )

    for module_name, attr_name in required_symbols:
        mod = _get_submodule(module_name)
        if mod is None or not hasattr(mod, attr_name):
            return False
    return True


145
146
@functools.cache
def has_flashinfer_moe() -> bool:
    """Return `True` if the FlashInfer MoE module is available.

    Requires `has_flashinfer()` and a discoverable ``flashinfer.fused_moe``
    submodule. The result is cached for the lifetime of the process.
    """
    return (
        has_flashinfer()
        and importlib.util.find_spec("flashinfer.fused_moe") is not None
    )
152
153


154
155
@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
    """Return `True` if FlashInfer CUTLASS fused MoE is available.

    Requires `has_flashinfer_moe()` and that every (module, attribute) pair
    listed below can be imported. The result is cached for the lifetime of the
    process.
    """
    if not has_flashinfer_moe():
        return False

    # Check if all required functions are available.
    required_functions = (
        ("flashinfer.fused_moe", "cutlass_fused_moe"),
        ("flashinfer", "fp4_quantize"),
        ("flashinfer", "nvfp4_block_scale_interleave"),
        ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
    )

    for module_name, attr_name in required_functions:
        mod = _get_submodule(module_name)
        if mod is None or not hasattr(mod, attr_name):
            return False
    return True


175
176
@functools.cache
def has_nvidia_artifactory() -> bool:
    """Return `True` if NVIDIA's artifactory is accessible.

    This checks connectivity to the kernel inference library artifactory
    (``FLASHINFER_CUBINS_REPOSITORY``), which is required for downloading
    certain cubin kernels like TRTLLM FMHA.

    The network probe is skipped, and `True` returned, when
    ``envs.VLLM_HAS_FLASHINFER_CUBIN`` is set. Any error while probing is
    logged and treated as "not accessible". The result is cached for the
    lifetime of the process, so a failed probe is not retried.
    """
    # VLLM_HAS_FLASHINFER_CUBIN signals that the cubins were pre-downloaded
    # (presumably into the path given by FLASHINFER_CUBIN_DIR), so we can
    # assume they are available without touching the network.
    if envs.VLLM_HAS_FLASHINFER_CUBIN:
        return True

    try:
        # Use a short timeout to avoid blocking for too long
        response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5)
    except Exception as e:
        # Best effort: any failure (DNS, timeout, proxy, ...) just means the
        # artifactory is not reachable from here.
        logger.warning_once("Failed to connect to NVIDIA artifactory: %s", e)
        return False

    accessible = response.status_code == 200
    if accessible:
        logger.debug_once("NVIDIA artifactory is accessible")
    else:
        logger.warning_once(
            "NVIDIA artifactory returned failed status code: %d",
            response.status_code,
        )
    return accessible


204
@functools.cache
def supports_trtllm_attention() -> bool:
    """Return `True` if TRTLLM attention is supported on this platform.

    TRTLLM attention is supported if the device reports compute capability
    100 (SM100) and NVIDIA's artifactory is accessible, since the TRTLLM cubins
    are downloaded from there. The result is cached for the lifetime of the
    process.
    """
    # Device check first: `and` short-circuits, so the network probe in
    # has_nvidia_artifactory() is skipped on other devices.
    return current_platform.is_device_capability(100) and has_nvidia_artifactory()
212

213
214

@functools.cache
def _force_use_trtllm_attention(env_value: bool | None) -> bool | None:
    """Return ``env_value`` (VLLM_USE_TRTLLM_ATTENTION) unchanged.

    Cached per distinct ``env_value``. When the value is set, it is also
    logged via ``logger.info_once``.
    """
    if env_value is not None:
        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
    return env_value
220

221

222
def force_use_trtllm_attention() -> bool | None:
    """Return the user's TRTLLM attention override, if any.

    Returns:
        `None` if VLLM_USE_TRTLLM_ATTENTION is not set,
        `True` if TRTLLM attention is forced to be used,
        `False` if TRTLLM attention is forced to be not used.
        Batch-invariant mode always yields `False`, regardless of the
        environment variable.
    """
    if vllm_is_batch_invariant():
        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is disabled for batch-invariant")
        return False
    return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
232
233


234
235
def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
    """Check if the current configuration supports TRTLLM attention.

    Args:
        num_qo_heads: Number of query/output attention heads.
        num_kv_heads: Number of key/value attention heads.

    Returns:
        `False` if TRTLLM attention is explicitly disabled (see
        `force_use_trtllm_attention`). Otherwise `True` iff the platform
        supports it and ``num_qo_heads`` is a multiple of ``num_kv_heads``.
    """
    if force_use_trtllm_attention() is False:
        return False
    return supports_trtllm_attention() and num_qo_heads % num_kv_heads == 0


242
def use_trtllm_attention(
    num_qo_heads: int,
    num_kv_heads: int,
    num_tokens: int,
    max_seq_len: int,
    kv_cache_dtype: str,
    q_dtype: torch.dtype,
    is_prefill: bool,
    has_sinks: bool = False,
    has_spec: bool = False,
) -> bool:
    """Return `True` if TRTLLM attention should be used.

    Decision order:
      1. TRTLLM explicitly disabled (VLLM_USE_TRTLLM_ATTENTION=0 or
         batch-invariant mode) -> `False`.
      2. Unsupported platform, or ``num_qo_heads`` not a multiple of
         ``num_kv_heads`` -> `False` (with a warning if the user forced it on).
      3. Speculative decoding in the decode phase, FP8-quantized queries, or
         attention sinks -> `True`.
      4. Override unset -> auto-detect. Prefill uses TRTLLM when
         ``max_seq_len <= 131072`` and ``kv_cache_dtype == "auto"``; decode
         additionally requires ``num_tokens <= 256``.
      5. Override set to 1 -> `True`.

    Args:
        num_qo_heads: Number of query/output attention heads.
        num_kv_heads: Number of key/value attention heads.
        num_tokens: Token count; only used by decode auto-detection.
        max_seq_len: Maximum sequence length; used by auto-detection.
        kv_cache_dtype: KV-cache dtype string; auto-detection only selects
            TRTLLM for ``"auto"``.
        q_dtype: Query dtype; the platform FP8 dtype forces TRTLLM.
        is_prefill: Whether this is a prefill (as opposed to decode) call.
        has_sinks: Whether attention sinks are used.
        has_spec: Whether speculative decoding is active.
    """
    force_use_trtllm = force_use_trtllm_attention()

    # Environment variable is set to 0 (or batch-invariant mode) - respect it.
    if force_use_trtllm is not None and not force_use_trtllm:
        return False

    # The platform is not supported.
    if not supports_trtllm_attention():
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported on this platform, "
                "but VLLM_USE_TRTLLM_ATTENTION is set to 1"
            )
        return False

    # The combination of query and key heads is not supported.
    if num_qo_heads % num_kv_heads != 0:
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported for this combination of "
                "query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
            )
        return False

    if has_spec and not is_prefill:
        # Speculative decoding requires TRTLLM attention for decodes.
        logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
        return True

    # Must use TRTLLM attention if query is FP8 quantized.
    if q_dtype == current_platform.fp8_dtype():
        logger.info_once("Using TRTLLM attention (query is quantized).")
        return True

    # If sinks are being used, we must use TRTLLM attention as it's
    # the only backend that supports them.
    if has_sinks:
        logger.info_once("Using TRTLLM attention (required for attention sinks).")
        return True

    if force_use_trtllm is None:
        # Environment variable not set - use auto-detection.
        # Auto-detection limits: longest sequence, and largest decode batch
        # (in tokens), for which TRTLLM is selected automatically.
        auto_max_seq_len = 131072
        auto_max_decode_tokens = 256
        if is_prefill:
            # Prefill auto-detection
            use_trtllm = max_seq_len <= auto_max_seq_len and kv_cache_dtype == "auto"
            if use_trtllm:
                logger.warning_once("Using TRTLLM prefill attention (auto-detected).")
        else:
            # Decode auto-detection
            use_trtllm = (
                num_tokens <= auto_max_decode_tokens
                and max_seq_len <= auto_max_seq_len
                and kv_cache_dtype == "auto"
            )
            if use_trtllm:
                logger.warning_once("Using TRTLLM decode attention (auto-detected).")
        return use_trtllm

    # Environment variable is set to 1 - respect it.
    logger.info_once("Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
    return True

314

315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
if has_flashinfer():
    # The FlashInfer GEMMs are registered as torch custom ops. Each
    # `register_fake` implementation only allocates an output of the right
    # shape/dtype without running the kernel, for use during tracing
    # (e.g. torch.compile).

    @torch.library.custom_op(
        "vllm::flashinfer_mm_fp4",
        mutates_args=[],
        device_types="cuda",
    )
    def flashinfer_mm_fp4(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        """Run ``flashinfer.mm_fp4`` with a scale block size of 16."""
        from flashinfer import mm_fp4 as flashinfer_mm_fp4_

        return flashinfer_mm_fp4_(
            A, B, A_scale, B_scale, g_scale, dtype, block_size=16, backend=backend
        )

    @torch.library.register_fake(
        "vllm::flashinfer_mm_fp4",
    )
    def flashinfer_mm_fp4_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        # Output shape is (A.shape[0], B.shape[1]) in the requested dtype.
        return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device)

    @torch.library.custom_op(
        "vllm::bmm_fp8",
        mutates_args=[],
        device_types="cuda",
    )
    def bmm_fp8(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        """Run ``flashinfer.bmm_fp8`` and return a newly produced output."""
        from flashinfer import bmm_fp8 as bmm_fp8_

        # NOTE(review): the positional `None` is presumably FlashInfer's `out`
        # argument (let it allocate the result) - confirm against its API.
        return bmm_fp8_(A, B, A_scale, B_scale, dtype, None, backend)

    @torch.library.register_fake(
        "vllm::bmm_fp8",
    )
    def bmm_fp8_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        # Output shape is (A.shape[0], A.shape[1], B.shape[2]) in `dtype`.
        return torch.empty(
            A.shape[0], A.shape[1], B.shape[2], dtype=dtype, device=A.device
        )


def flashinfer_scaled_fp4_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    alpha: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str,
) -> torch.Tensor:
    """Block-scaled FP4 GEMM ``a @ b.T`` via the ``vllm::flashinfer_mm_fp4`` op.

    Requires FlashInfer to have been available at import time, because the
    underlying custom op is only registered in that case.

    Args:
        a: 2-D packed FP4 activations with a unit stride on the last dim.
        b: 2-D packed FP4 weights, same last-dim size as ``a``; transposed
            before the call.
        block_scale_a: 2-D block scales for ``a``.
        block_scale_b: 2-D block scales for ``b``; transposed before the call.
        alpha: Passed as the op's ``g_scale`` argument.
        out_dtype: Dtype of the result.
        backend: FlashInfer backend name, e.g. ``"cutlass"``.

    Returns:
        The GEMM result in ``out_dtype``.
    """
    assert a.ndim == 2 and b.ndim == 2
    assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2
    assert a.stride(-1) == 1 and b.stride(-1) == 1
    assert a.shape[1] == b.shape[1]

    if backend == "cutlass":
        # Reinterpret (no copy) the scales as raw bytes; presumably the
        # CUTLASS path expects uint8 scale tensors - confirm with FlashInfer.
        block_scale_a = block_scale_a.view(torch.uint8)
        block_scale_b = block_scale_b.view(torch.uint8)

    return flashinfer_mm_fp4(
        a,
        b.t(),
        block_scale_a,
        block_scale_b.t(),
        alpha,
        out_dtype,
        backend=backend,
    )


413
def flashinfer_scaled_fp8_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Per-tensor-scaled FP8 GEMM ``a @ b`` (plus optional bias) via FlashInfer.

    Runs as a single-batch ``vllm::bmm_fp8`` call. That op is only registered
    when FlashInfer was available at import time.

    Args:
        a: 2-D ``float8_e4m3fn`` CUDA tensor of shape (M, K).
        b: 2-D ``float8_e4m3fn`` CUDA tensor of shape (K, N).
        scale_a: Single-element ``float32`` CUDA scale for ``a``.
        scale_b: Single-element ``float32`` CUDA scale for ``b``.
        out_dtype: Dtype of the result.
        bias: Optional tensor added to the (M, N) result.

    Returns:
        Tensor of shape (M, N) in ``out_dtype``.
    """
    assert a.ndim == 2 and b.ndim == 2
    assert a.shape[1] == b.shape[0]
    assert scale_a.numel() == 1 and scale_b.numel() == 1
    assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn
    assert a.device.type == "cuda" and b.device.type == "cuda"
    assert scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32
    assert scale_a.device.type == "cuda" and scale_b.device.type == "cuda"

    # Add a batch dim of 1 for the batched kernel, then drop it again.
    output = bmm_fp8(
        a.unsqueeze(0),
        b.unsqueeze(0),
        scale_a,
        scale_b,
        out_dtype,
        "auto",
    ).view(a.shape[0], b.shape[1])

    if bias is not None:
        output = output + bias
    return output


443
444
445
446
447
448
@functools.cache
def flashinfer_disable_q_quantization() -> bool:
    """Return whether FlashInfer query quantization is disabled.

    Reads ``envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION`` on first call; the
    value is cached for the rest of the process.
    """
    disabled = envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
    return disabled


449
450
# Public API of this module. Users of vLLM should import only these wrappers.
__all__ = [
    "has_flashinfer",
    "flashinfer_trtllm_fp8_block_scale_moe",
    "flashinfer_trtllm_fp8_per_tensor_scale_moe",
    "flashinfer_cutlass_fused_moe",
    "flashinfer_fp4_quantize",
    "nvfp4_block_scale_interleave",
    "trtllm_fp4_block_scale_moe",
    "autotune",
    "has_flashinfer_moe",
    "has_flashinfer_comm",
    "has_flashinfer_all2all",
    "has_flashinfer_cutlass_fused_moe",
    "has_nvidia_artifactory",
    "supports_trtllm_attention",
    "force_use_trtllm_attention",
    "can_use_trtllm_attention",
    "use_trtllm_attention",
    "flashinfer_disable_q_quantization",
    "flashinfer_scaled_fp4_mm",
    "flashinfer_scaled_fp8_mm",
]