# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrappers for FlashInfer API changes.

Users of vLLM should always import **only** these wrappers.
"""
import contextlib
import functools
import importlib
import importlib.util
12
import os
13
import shutil
14
15
from collections.abc import Callable
from typing import Any, NoReturn
16

17
import requests
18
import torch
19
20

import vllm.envs as envs
21
from vllm.logger import init_logger
22
23
24
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
25
from vllm.platforms import current_platform
26
27
28

logger = init_logger(__name__)

# Storage location of FlashInfer's cubins. Set the
# FLASHINFER_CUBINS_REPOSITORY environment variable to override it
# (for example with a local path when testing).
# Referenced from https://github.com/flashinfer-ai/flashinfer/blob/0c9a92c3d9a7e043ab6f3f7b2273269caf6ab044/flashinfer/jit/cubin_loader.py#L35  # noqa: E501
FLASHINFER_CUBINS_REPOSITORY = os.getenv(
    "FLASHINFER_CUBINS_REPOSITORY",
    "https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/",  # noqa: E501
)

@functools.cache
def has_flashinfer() -> bool:
    """Return `True` if FlashInfer is available.

    ``importlib.util.find_spec`` is used so the package is only located, not
    imported, which avoids potential CUDA initialization side effects. The
    result is cached.

    FlashInfer is reported unavailable when:
      * the ``flashinfer`` package cannot be found, or
      * pre-downloaded cubins are not in use (``VLLM_HAS_FLASHINFER_CUBIN`` is
        falsy) and ``nvcc`` is not on ``PATH``, because nvcc is then required
        to JIT-compile FlashInfer.
    """
    if importlib.util.find_spec("flashinfer") is None:
        logger.debug_once("FlashInfer unavailable since package was not found")
        return False

    # Without pre-downloaded cubins, kernels must be JIT-compiled with nvcc.
    if not envs.VLLM_HAS_FLASHINFER_CUBIN and shutil.which("nvcc") is None:
        logger.debug_once(
            "FlashInfer unavailable since nvcc was not found "
            "and not using pre-downloaded cubins"
        )
        return False
    return True


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Placeholder for unavailable FlashInfer backend."""
    raise RuntimeError(
        "FlashInfer backend is not available. Please install the package "
        "to enable FlashInfer kernels: "
62
63
        "https://github.com/flashinfer-ai/flashinfer"
    )
64
65
66
67
68
69
70
71
72
73
74


def _get_submodule(module_name: str) -> Any | None:
    """Safely import a submodule and return it, or None if not available."""
    try:
        return importlib.import_module(module_name)
    except (ImportError, ModuleNotFoundError):
        return None


# General lazy import wrapper
def _lazy_import_wrapper(
    module_name: str, attr_name: str, fallback_fn: Callable[..., Any] = _missing
):
    """Create a lazy import wrapper for a specific FlashInfer function.

    The target ``module_name.attr_name`` is looked up on the first call and
    the lookup result is cached, so creating the wrapper imports nothing.
    When FlashInfer is unavailable, the module cannot be imported, or the
    attribute is missing, every call is forwarded to ``fallback_fn``.

    Args:
        module_name: Dotted module path to import, e.g. ``"flashinfer"``.
        attr_name: Name of the callable to look up on that module.
        fallback_fn: Called with the same arguments when the target is
            unavailable. Defaults to ``_missing``, which raises RuntimeError.

    Returns:
        A callable that forwards ``*args``/``**kwargs`` to the resolved target
        (or to ``fallback_fn``) and returns its result.
    """

    @functools.cache
    def _get_impl():
        if not has_flashinfer():
            return None
        mod = _get_submodule(module_name)
        return getattr(mod, attr_name, None) if mod else None

    def wrapper(*args, **kwargs):
        impl = _get_impl()
        if impl is None:
            return fallback_fn(*args, **kwargs)
        return impl(*args, **kwargs)

    # Name the wrapper after its target so ``repr()``/``__name__`` identify
    # the wrapped FlashInfer function instead of a generic "wrapper".
    wrapper.__name__ = attr_name
    wrapper.__qualname__ = attr_name
    return wrapper


# Create lazy wrappers for each function. Each target is resolved on its
# first call; when FlashInfer or the symbol is unavailable the call goes to
# the wrapper's fallback (by default `_missing`, which raises RuntimeError).
flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"
)
flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"
)
flashinfer_cutlass_fused_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "cutlass_fused_moe"
)
flashinfer_fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
nvfp4_block_scale_interleave = _lazy_import_wrapper(
    "flashinfer", "nvfp4_block_scale_interleave"
)
trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
    "flashinfer", "trtllm_fp4_block_scale_moe"
)

# Special case for autotune: it returns a context manager, so without
# FlashInfer it degrades to a no-op `contextlib.nullcontext()` instead of
# raising.
autotune = _lazy_import_wrapper(
    "flashinfer.autotuner",
    "autotune",
    fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(),
)


@functools.cache
def has_flashinfer_comm() -> bool:
    """Return `True` if FlashInfer comm module is available.

    ``find_spec`` on a dotted name imports the parent ``flashinfer`` package;
    an ``ImportError`` raised while doing so is reported as "not available"
    (consistent with ``_get_submodule``) instead of propagating.
    """
    if not has_flashinfer():
        return False
    try:
        return importlib.util.find_spec("flashinfer.comm") is not None
    except ImportError:
        logger.debug_once("FlashInfer comm unavailable: failed to import flashinfer")
        return False


@functools.cache
def has_flashinfer_all2all() -> bool:
    """Return `True` if FlashInfer mnnvl all2all is available.

    Requires the FlashInfer comm module and every (module, attribute) pair in
    ``required_functions`` to be importable. The result is cached.
    """
    if not has_flashinfer_comm():
        return False

    # Check if all required functions are available
    required_functions = [
        ("flashinfer.comm", "Mapping"),
        ("flashinfer.comm.mnnvl", "MnnvlMemory"),
        ("flashinfer.comm.trtllm_alltoall", "MnnvlMoe"),
        ("flashinfer.comm.trtllm_alltoall", "MoEAlltoallInfo"),
    ]

    for module_name, attr_name in required_functions:
        mod = _get_submodule(module_name)
        if not mod or not hasattr(mod, attr_name):
            return False
    return True


@functools.cache
def has_flashinfer_moe() -> bool:
    """Return `True` if FlashInfer MoE module is available.

    ``find_spec`` on a dotted name imports the parent ``flashinfer`` package;
    an ``ImportError`` raised while doing so is reported as "not available"
    (consistent with ``_get_submodule``) instead of propagating.
    """
    if not has_flashinfer():
        return False
    try:
        return importlib.util.find_spec("flashinfer.fused_moe") is not None
    except ImportError:
        logger.debug_once("FlashInfer MoE unavailable: failed to import flashinfer")
        return False


@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
    """Return `True` if FlashInfer CUTLASS fused MoE is available.

    Requires the FlashInfer MoE module and every (module, attribute) pair in
    ``required_functions`` to be importable. The result is cached.
    """
    if not has_flashinfer_moe():
        return False

    # Check if all required functions are available
    required_functions = [
        ("flashinfer.fused_moe", "cutlass_fused_moe"),
        ("flashinfer", "fp4_quantize"),
        ("flashinfer", "nvfp4_block_scale_interleave"),
        ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
    ]

    for module_name, attr_name in required_functions:
        mod = _get_submodule(module_name)
        if not mod or not hasattr(mod, attr_name):
            return False
    return True


@functools.cache
def has_nvidia_artifactory() -> bool:
    """Return `True` if NVIDIA's artifactory is accessible.

    This checks connectivity to the kernel inference library artifactory
    (``FLASHINFER_CUBINS_REPOSITORY``), which is required for downloading
    certain cubin kernels like TRTLLM FMHA. The result is cached, so the
    network probe runs at most once per process.

    Any exception raised while probing is logged as a warning and treated
    as "not accessible".
    """
    # VLLM_HAS_FLASHINFER_CUBIN indicates the cubins were pre-downloaded, so
    # assume they are available without touching the network.
    if envs.VLLM_HAS_FLASHINFER_CUBIN:
        return True

    try:
        # Use a short timeout to avoid blocking for too long. Only the status
        # code is needed, so stream the response (the body is never
        # downloaded) and close it via the context manager.
        with requests.get(
            FLASHINFER_CUBINS_REPOSITORY, timeout=5, stream=True
        ) as response:
            accessible = response.status_code == 200
        if accessible:
            logger.debug_once("NVIDIA artifactory is accessible")
        else:
            logger.warning_once(
                "NVIDIA artifactory returned failed status code: %d",
                response.status_code,
            )
        return accessible
    except Exception as e:
        # Best-effort probe: any failure (DNS, timeout, bad URL, ...) means
        # the artifactory is treated as unreachable.
        logger.warning_once("Failed to connect to NVIDIA artifactory: %s", e)
        return False


@functools.cache
def supports_trtllm_attention() -> bool:
    """Return `True` if TRTLLM attention is supported.

    TRTLLM attention is supported if the platform is SM100 and NVIDIA's
    artifactory is accessible. The capability check runs first, so the
    artifactory probe only happens on SM100 devices. The result is cached.
    """
    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
    return current_platform.is_device_capability(100) and has_nvidia_artifactory()

@functools.cache
def _force_use_trtllm_attention(env_value: bool | None) -> bool | None:
    """Cache the env value for VLLM_USE_TRTLLM_ATTENTION.

    Returns ``env_value`` unchanged. Because the function is cached per
    argument, the info log below runs at most once for each distinct value.
    """
    if env_value is not None:
        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
    return env_value

def force_use_trtllm_attention() -> bool | None:
    """
    Return `None` if VLLM_USE_TRTLLM_ATTENTION is not set,
    return `True` if TRTLLM attention is forced to be used,
    return `False` if TRTLLM attention is forced to be not used.

    Batch-invariant mode always yields `False`, regardless of the env var.
    """
    if vllm_is_batch_invariant():
        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is disabled for batch-invariant")
        return False
    return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)


def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
    """Check if the current configuration supports TRTLLM attention.

    Returns `False` when TRTLLM attention is explicitly disabled (see
    `force_use_trtllm_attention`). Otherwise returns `True` only if the
    platform supports it and ``num_qo_heads`` is a multiple of
    ``num_kv_heads``.
    """
    if force_use_trtllm_attention() is False:
        return False
    has_trtllm = supports_trtllm_attention()
    return has_trtllm and (num_qo_heads % num_kv_heads == 0)


def use_trtllm_attention(
    num_qo_heads: int,
    num_kv_heads: int,
    num_tokens: int,
    max_seq_len: int,
    kv_cache_dtype: str,
    q_dtype: torch.dtype,
    is_prefill: bool,
    has_sinks: bool = False,
    has_spec: bool = False,
) -> bool:
    """Return `True` if TRTLLM attention is used.

    Decision order (first match wins):
      1. TRTLLM explicitly disabled via `force_use_trtllm_attention()`
         (``VLLM_USE_TRTLLM_ATTENTION=0`` or batch-invariant mode) -> `False`.
      2. Platform unsupported, or ``num_qo_heads`` not a multiple of
         ``num_kv_heads`` -> `False` (with a warning if TRTLLM was forced on).
      3. Speculative decoding on a decode step, an FP8 query, or attention
         sinks -> `True`, since these require TRTLLM attention.
      4. Env var unset -> auto-detect from ``num_tokens``, ``max_seq_len``
         and ``kv_cache_dtype``.
      5. Env var set to 1 -> `True`.

    Args:
        num_qo_heads: Number of query/output heads.
        num_kv_heads: Number of key/value heads.
        num_tokens: Token count; only used by decode auto-detection.
        max_seq_len: Maximum sequence length; auto-detection requires it to
            be at most 131072.
        kv_cache_dtype: KV cache dtype string; auto-detection requires "auto".
        q_dtype: Query dtype; the platform FP8 dtype forces TRTLLM.
        is_prefill: Whether this is a prefill (as opposed to decode) call.
        has_sinks: Whether attention sinks are used.
        has_spec: Whether speculative decoding is active.
    """
    force_use_trtllm = force_use_trtllm_attention()

    # Environment variable is set to 0 - respect it
    if force_use_trtllm is not None and not force_use_trtllm:
        return False

    # The platform is not supported
    if not supports_trtllm_attention():
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported on this platform, "
                "but VLLM_USE_TRTLLM_ATTENTION is set to 1"
            )
        return False

    # The combination of query and key heads is not supported
    if num_qo_heads % num_kv_heads != 0:
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported for this combination of "
                "query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
            )
        return False

    if has_spec and not is_prefill:
        # Speculative decoding requires TRTLLM attention for decodes
        logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
        return True

    # Must use TRTLLM attention if query is FP8 quantized
    if q_dtype == current_platform.fp8_dtype():
        logger.info_once("Using TRTLLM attention (query is quantized).")
        return True

    # If sinks are being used, we must use TRTLLM attention as it's
    # the only backend that supports them
    if has_sinks:
        logger.info_once("Using TRTLLM attention (required for attention sinks).")
        return True

    if force_use_trtllm is None:
        # Environment variable not set - use auto-detection. Both paths need
        # max_seq_len <= 131072 and kv_cache_dtype == "auto"; decode also
        # needs num_tokens <= 256.
        if is_prefill:
            # Prefill auto-detection
            use_trtllm = max_seq_len <= 131072 and kv_cache_dtype == "auto"
            if use_trtllm:
                logger.warning_once("Using TRTLLM prefill attention (auto-detected).")
        else:
            # Decode auto-detection
            use_trtllm = (
                num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto"
            )
            if use_trtllm:
                logger.warning_once("Using TRTLLM decode attention (auto-detected).")
        return use_trtllm

    # Environment variable is set to 1 - respect it
    logger.info_once("Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
    return True

# These custom ops call into FlashInfer, so they are only registered when
# FlashInfer is available; otherwise `flashinfer_mm_fp4` and `bmm_fp8` are
# not defined in this module.
if has_flashinfer():

    @torch.library.custom_op(
        "vllm::flashinfer_mm_fp4",
        mutates_args=[],
        device_types="cuda",
    )
    def flashinfer_mm_fp4(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        """FP4 GEMM registered as ``vllm::flashinfer_mm_fp4``.

        Forwards to ``flashinfer.mm_fp4`` with ``block_size=16``; FlashInfer
        is imported inside the op body rather than at module load.
        """
        from flashinfer import mm_fp4 as flashinfer_mm_fp4_

        return flashinfer_mm_fp4_(
            A, B, A_scale, B_scale, g_scale, dtype, block_size=16, backend=backend
        )

    @torch.library.register_fake(
        "vllm::flashinfer_mm_fp4",
    )
    def flashinfer_mm_fp4_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        """Fake implementation: allocates the ``(A.shape[0], B.shape[1])``
        output with ``dtype`` on ``A``'s device without computing it."""
        return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device)

    @torch.library.custom_op(
        "vllm::bmm_fp8",
        mutates_args=[],
        device_types="cuda",
    )
    def bmm_fp8(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        """Batched FP8 GEMM registered as ``vllm::bmm_fp8``.

        Forwards to ``flashinfer.bmm_fp8``, passing ``None`` as its sixth
        positional argument (NOTE(review): presumably the optional output
        tensor, letting FlashInfer allocate it — confirm against FlashInfer).
        """
        from flashinfer import bmm_fp8 as bmm_fp8_

        return bmm_fp8_(A, B, A_scale, B_scale, dtype, None, backend)

    @torch.library.register_fake(
        "vllm::bmm_fp8",
    )
    def bmm_fp8_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        """Fake implementation: allocates the
        ``(A.shape[0], A.shape[1], B.shape[2])`` output with ``dtype`` on
        ``A``'s device without computing it."""
        return torch.empty(
            A.shape[0], A.shape[1], B.shape[2], dtype=dtype, device=A.device
        )


def flashinfer_scaled_fp4_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    alpha: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str,
) -> torch.Tensor:
    """Compute a block-scaled FP4 matmul of ``a`` and ``b.T`` with FlashInfer.

    Args:
        a: 2-D operand with a contiguous last dimension.
        b: 2-D operand with a contiguous last dimension and the same size as
            ``a`` along dim 1; it is transposed before the GEMM.
        block_scale_a: 2-D block scales for ``a``.
        block_scale_b: 2-D block scales for ``b`` (transposed along with it).
        alpha: Global scale, forwarded as the op's ``g_scale``.
        out_dtype: Dtype of the result.
        backend: FlashInfer GEMM backend; for ``"cutlass"`` both block-scale
            tensors are reinterpreted (``view``) as ``torch.uint8``.

    Returns:
        The output of ``vllm::flashinfer_mm_fp4``; its fake implementation
        allocates shape ``(a.shape[0], b.shape[0])``.

    Raises:
        AssertionError: if the rank/stride/shape preconditions are violated.
        NameError: if FlashInfer was unavailable at import time, since
            ``flashinfer_mm_fp4`` is then never defined.
    """
    assert a.ndim == 2 and b.ndim == 2
    assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2
    assert a.stride(-1) == 1 and b.stride(-1) == 1
    assert a.shape[1] == b.shape[1]

    if backend == "cutlass":
        block_scale_a = block_scale_a.view(torch.uint8)
        block_scale_b = block_scale_b.view(torch.uint8)

    return flashinfer_mm_fp4(
        a,
        b.t(),
        block_scale_a,
        block_scale_b.t(),
        alpha,
        out_dtype,
        backend=backend,
    )


417
def flashinfer_scaled_fp8_mm(
418
419
420
421
422
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
423
    bias: torch.Tensor | None = None,
424
) -> torch.Tensor:
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
    assert a.ndim == 2 and b.ndim == 2
    assert a.shape[1] == b.shape[0]
    assert scale_a.numel() == 1 and scale_b.numel() == 1
    assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn
    assert a.device.type == "cuda" and b.device.type == "cuda"
    assert scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32
    assert scale_a.device.type == "cuda" and scale_b.device.type == "cuda"

    output = bmm_fp8(
        a.unsqueeze(0),
        b.unsqueeze(0),
        scale_a,
        scale_b,
        out_dtype,
        "auto",
    ).view(a.shape[0], b.shape[1])

    if bias is not None:
        output = output + bias
    return output


@functools.cache
def flashinfer_disable_q_quantization() -> bool:
    """Return whether FlashInfer query quantization is disabled.

    Reads ``envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION``; the value is
    cached because it depends only on the environment.
    """
    disabled = envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
    return disabled


# Public API of this module. Every lazy wrapper defined above is exported,
# including the FP8 per-tensor-scale MoE wrapper.
__all__ = [
    "has_flashinfer",
    "flashinfer_trtllm_fp8_block_scale_moe",
    "flashinfer_trtllm_fp8_per_tensor_scale_moe",
    "flashinfer_cutlass_fused_moe",
    "flashinfer_fp4_quantize",
    "nvfp4_block_scale_interleave",
    "trtllm_fp4_block_scale_moe",
    "autotune",
    "has_flashinfer_moe",
    "has_flashinfer_comm",
    "has_flashinfer_all2all",
    "has_flashinfer_cutlass_fused_moe",
    "has_nvidia_artifactory",
    "supports_trtllm_attention",
    "can_use_trtllm_attention",
    "use_trtllm_attention",
    "flashinfer_disable_q_quantization",
    "flashinfer_scaled_fp4_mm",
    "flashinfer_scaled_fp8_mm",
]