flashinfer.py 24.6 KB
Newer Older
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for FlashInfer API changes.

Users of vLLM should always import **only** these wrappers.
"""
7

8
9
10
11
import contextlib
import functools
import importlib
import importlib.util
12
import os
13
import shutil
14
15
from collections.abc import Callable
from typing import Any, NoReturn
16

17
import requests
18
import torch
19
20

import vllm.envs as envs
21
from vllm.logger import init_logger
22
from vllm.platforms import current_platform
23
24
25

# Module-level logger shared by every helper in this file.
logger = init_logger(__name__)

26
27
28
29
30
31
32
33
# Base location of the FlashInfer cubin store. Setting the
# FLASHINFER_CUBINS_REPOSITORY environment variable overrides it (for example
# with a local path when testing).
# Mirrors https://github.com/flashinfer-ai/flashinfer/blob/0c9a92c3d9a7e043ab6f3f7b2273269caf6ab044/flashinfer/jit/cubin_loader.py#L35  # noqa: E501
FLASHINFER_CUBINS_REPOSITORY = os.getenv(
    "FLASHINFER_CUBINS_REPOSITORY",
    "https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/",  # noqa: E501
)

34

35
36
37
38
39
40
41
42
43
44
45
@functools.cache
def has_flashinfer_cubin() -> bool:
    """Report whether pre-built FlashInfer cubins are available.

    The ``envs.VLLM_HAS_FLASHINFER_CUBIN`` setting short-circuits the check.
    Otherwise the ``flashinfer_cubin`` package must be locatable via
    ``importlib.util.find_spec``. The answer is cached for the process lifetime.
    """
    found = envs.VLLM_HAS_FLASHINFER_CUBIN or (
        importlib.util.find_spec("flashinfer_cubin") is not None
    )
    if not found:
        logger.debug_once("flashinfer-cubin package was not found")
    return bool(found)


46
47
@functools.cache
def has_flashinfer() -> bool:
    """Report whether the flashinfer-python package is usable.

    Two conditions must hold:

    1. The ``flashinfer`` package can be located. ``find_spec`` is used so the
       package is not actually imported, which avoids potential CUDA
       initialization side effects.
    2. Either pre-downloaded cubins are available (``has_flashinfer_cubin``) or
       ``nvcc`` is on ``PATH`` so FlashInfer can JIT-compile its kernels.

    The answer is cached for the process lifetime.
    """
    if importlib.util.find_spec("flashinfer") is None:
        logger.debug_once("FlashInfer unavailable since package was not found")
        return False

    # Pre-built cubins remove the need for a compiler.
    if has_flashinfer_cubin():
        return True

    # Without cubins, nvcc is needed to JIT-compile the kernels.
    if shutil.which("nvcc") is not None:
        return True

    logger.debug_once(
        "FlashInfer unavailable since nvcc was not found "
        "and not using pre-downloaded cubins"
    )
    return False
63
64
65
66
67
68
69


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Placeholder for unavailable FlashInfer backend."""
    raise RuntimeError(
        "FlashInfer backend is not available. Please install the package "
        "to enable FlashInfer kernels: "
70
71
        "https://github.com/flashinfer-ai/flashinfer"
    )
72
73
74
75
76
77
78
79
80
81
82


def _get_submodule(module_name: str) -> Any | None:
    """Safely import a submodule and return it, or None if not available."""
    try:
        return importlib.import_module(module_name)
    except (ImportError, ModuleNotFoundError):
        return None


# General lazy import wrapper
def _lazy_import_wrapper(
    module_name: str, attr_name: str, fallback_fn: Callable[..., Any] = _missing
):
    """Build a proxy that resolves ``module_name.attr_name`` on first call.

    The lookup is cached after its first evaluation. The proxy forwards every
    call to *fallback_fn* when:

    - FlashInfer is unavailable (``has_flashinfer()`` is false),
    - the module cannot be imported, or
    - the module has no attribute *attr_name*.

    Otherwise it forwards the call to the resolved callable.

    Args:
        module_name: Dotted module path, e.g. ``"flashinfer.fused_moe"``.
        attr_name: Name of the callable inside that module.
        fallback_fn: Receives the same arguments when the target cannot be
            resolved. Defaults to ``_missing``, which raises ``RuntimeError``.

    Returns:
        A callable accepting arbitrary positional and keyword arguments.
    """

    @functools.cache
    def _resolve():
        if not has_flashinfer():
            return None
        module = _get_submodule(module_name)
        if not module:
            return None
        return getattr(module, attr_name, None)

    def wrapper(*args, **kwargs):
        target = _resolve()
        call = fallback_fn if target is None else target
        return call(*args, **kwargs)

    return wrapper


# Lazily resolved FlashInfer entry points. Each name below is a proxy returned
# by _lazy_import_wrapper. The FlashInfer symbol is looked up on first call.
# If it cannot be resolved, the call goes to the fallback, which by default is
# _missing and raises RuntimeError.

# MoE / grouped-GEMM kernels.
flashinfer_trtllm_bf16_moe = _lazy_import_wrapper(
    module_name="flashinfer.fused_moe", attr_name="trtllm_bf16_moe"
)
flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
    module_name="flashinfer.fused_moe", attr_name="trtllm_fp8_block_scale_moe"
)
flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper(
    module_name="flashinfer.fused_moe",
    attr_name="trtllm_fp8_per_tensor_scale_moe",
)
flashinfer_cutlass_fused_moe = _lazy_import_wrapper(
    module_name="flashinfer.fused_moe", attr_name="cutlass_fused_moe"
)
flashinfer_cutedsl_grouped_gemm_nt_masked = _lazy_import_wrapper(
    module_name="flashinfer.cute_dsl.blockscaled_gemm",
    attr_name="grouped_gemm_nt_masked",
)

# FP4 quantization helpers and the FP4 block-scale MoE kernel.
flashinfer_fp4_quantize = _lazy_import_wrapper(
    module_name="flashinfer", attr_name="fp4_quantize"
)
nvfp4_batched_quantize = _lazy_import_wrapper(
    module_name="flashinfer", attr_name="nvfp4_batched_quantize"
)
silu_and_mul_scaled_nvfp4_experts_quantize = _lazy_import_wrapper(
    module_name="flashinfer",
    attr_name="silu_and_mul_scaled_nvfp4_experts_quantize",
)
scaled_fp4_grouped_quantize = _lazy_import_wrapper(
    module_name="flashinfer", attr_name="scaled_fp4_grouped_quantize"
)
nvfp4_block_scale_interleave = _lazy_import_wrapper(
    module_name="flashinfer.fp4_quantization", attr_name="block_scale_interleave"
)
trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
    module_name="flashinfer", attr_name="trtllm_fp4_block_scale_moe"
)

# autotune returns a context manager, so when FlashInfer is unavailable the
# fallback hands back a no-op context manager instead of raising.
autotune = _lazy_import_wrapper(
    module_name="flashinfer.autotuner",
    attr_name="autotune",
    fallback_fn=lambda *_args, **_kwargs: contextlib.nullcontext(),
)
# Module-level flag, not modified in this chunk.
# NOTE(review): presumably toggled by callers while FlashInfer autotuning is
# running — confirm.
_is_fi_autotuning: bool = False
141
142


143
144
@functools.cache
def has_flashinfer_comm() -> bool:
    """Report whether FlashInfer's ``flashinfer.comm`` module can be located."""
    if not has_flashinfer():
        return False
    return importlib.util.find_spec("flashinfer.comm") is not None
147
148
149


@functools.cache
def has_flashinfer_nvlink_two_sided() -> bool:
    """Report whether FlashInfer's MNNVL all-to-all API is available.

    Requires ``flashinfer.comm`` plus every (module, symbol) pair listed below
    to be importable. Checks stop at the first missing symbol.
    """
    if not has_flashinfer_comm():
        return False

    needed = (
        ("flashinfer.comm", "Mapping"),
        ("flashinfer.comm.mnnvl", "MnnvlMemory"),
        ("flashinfer.comm.trtllm_alltoall", "MnnvlMoe"),
        ("flashinfer.comm.trtllm_alltoall", "MoEAlltoallInfo"),
    )

    def _provides(module_path: str, symbol: str) -> bool:
        module = _get_submodule(module_path)
        return bool(module) and hasattr(module, symbol)

    return all(_provides(path, symbol) for path, symbol in needed)


170
171
172
173
174
175
176
177
@functools.cache
def has_flashinfer_nvlink_one_sided() -> bool:
    """Report whether ``flashinfer.comm.trtllm_moe_alltoall`` can be located."""
    return has_flashinfer_comm() and (
        importlib.util.find_spec("flashinfer.comm.trtllm_moe_alltoall") is not None
    )


178
179
@functools.cache
def has_flashinfer_moe() -> bool:
    """Report whether FlashInfer's ``flashinfer.fused_moe`` module can be located."""
    if not has_flashinfer():
        return False
    return importlib.util.find_spec("flashinfer.fused_moe") is not None
185
186


187
188
189
190
191
192
193
194
@functools.cache
def has_flashinfer_cutedsl() -> bool:
    """Report whether FlashInfer's ``flashinfer.cute_dsl`` module can be located."""
    if not has_flashinfer():
        return False
    return importlib.util.find_spec("flashinfer.cute_dsl") is not None


195
196
197
198
199
200
201
202
203
@functools.cache
def has_flashinfer_trtllm_fused_moe() -> bool:
    """Report whether FlashInfer exposes all TRTLLM fused-MoE kernels vLLM uses.

    Every kernel name listed below must be an attribute of
    ``flashinfer.fused_moe``.
    """
    if not has_flashinfer_moe():
        return False

    fused_moe = _get_submodule("flashinfer.fused_moe")
    if not fused_moe:
        return False

    kernels = (
        "trtllm_fp8_block_scale_moe",
        "trtllm_fp8_per_tensor_scale_moe",
        "trtllm_fp4_block_scale_moe",
        "trtllm_mxint4_block_scale_moe",
    )
    return all(hasattr(fused_moe, kernel) for kernel in kernels)


213
214
@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
    """Report whether FlashInfer's CUTLASS fused-MoE path is importable.

    Besides ``flashinfer.fused_moe.cutlass_fused_moe``, the FP4 helpers listed
    below must also be present. The first missing symbol makes the answer
    ``False``.
    """
    if not has_flashinfer_moe():
        return False

    for module_path, symbol in (
        ("flashinfer.fused_moe", "cutlass_fused_moe"),
        ("flashinfer", "fp4_quantize"),
        ("flashinfer", "nvfp4_block_scale_interleave"),
        ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
    ):
        module = _get_submodule(module_path)
        if module is None or not hasattr(module, symbol):
            return False
    return True


234
235
236
237
238
239
240
241
242
243
@functools.cache
def has_flashinfer_cutedsl_grouped_gemm_nt_masked() -> bool:
    """Return ``True`` if FlashInfer's CuTe-DSL masked grouped GEMM is available.

    Requires ``flashinfer.cute_dsl`` (``has_flashinfer_cutedsl``) plus:

    - ``grouped_gemm_nt_masked`` from ``flashinfer.cute_dsl.blockscaled_gemm``,
    - the NVFP4 quantization helpers ``scaled_fp4_grouped_quantize`` and
      ``silu_and_mul_scaled_nvfp4_experts_quantize`` from ``flashinfer``.

    The result is cached for the process lifetime.
    """
    if not has_flashinfer_cutedsl():
        return False

    # Check if all required functions are available
    required_functions = [
        ("flashinfer.cute_dsl.blockscaled_gemm", "grouped_gemm_nt_masked"),
        ("flashinfer", "scaled_fp4_grouped_quantize"),
        ("flashinfer", "silu_and_mul_scaled_nvfp4_experts_quantize"),
    ]

    for module_name, attr_name in required_functions:
        mod = _get_submodule(module_name)
        if not mod or not hasattr(mod, attr_name):
            return False
    return True


254
255
@functools.cache
def has_nvidia_artifactory() -> bool:
    """Return `True` if NVIDIA's artifactory is accessible.

    This checks connectivity to the kernel inference library artifactory
    (``FLASHINFER_CUBINS_REPOSITORY``), which is required for downloading
    certain cubin kernels such as the TRTLLM FMHA (fused multi-head attention)
    kernels.

    - If pre-downloaded cubins are available (``has_flashinfer_cubin``), no
      network request is made and ``True`` is returned.
    - Otherwise a single GET with a 5 second timeout is issued. Only HTTP 200
      counts as accessible; other status codes are logged as a warning.
    - Any exception raised by the request is logged and treated as
      inaccessible.

    The result is cached for the process lifetime.
    """
    # If we have pre-downloaded cubins, we can assume the cubins are available.
    if has_flashinfer_cubin():
        return True

    try:
        # Use a short timeout to avoid blocking for too long
        response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5)
        accessible = response.status_code == 200
        if accessible:
            logger.debug_once("NVIDIA artifactory is accessible")
        else:
            logger.warning_once(
                "NVIDIA artifactory returned failed status code: %d",
                response.status_code,
            )
        return accessible
    except Exception as e:
        # Best-effort probe: any failure just means "not accessible".
        logger.warning_once("Failed to connect to NVIDIA artifactory: %s", e)
        return False


282
@functools.cache
def supports_trtllm_attention() -> bool:
    """Report whether TRTLLM attention can be used on this machine.

    All of the following must hold:

    - batch-invariant mode (``envs.VLLM_BATCH_INVARIANT``) is off,
    - the device has compute capability 10.0 (SM100),
    - NVIDIA's artifactory is reachable, or cubins are pre-downloaded
      (``has_nvidia_artifactory``).
    """
    if envs.VLLM_BATCH_INVARIANT:
        # Batch-invariant mode disables TRTLLM attention.
        return False

    # Only validated on SM100 (CC 10.0); SM103 (GB300) hangs with
    # FlashInfer >= 0.6.7.
    # See: https://github.com/flashinfer-ai/flashinfer/issues/2939
    if not current_platform.is_device_capability(100):
        return False
    return has_nvidia_artifactory()
296

297

298
def force_use_trtllm_attention() -> bool | None:
    """Return the user's explicit TRTLLM-attention override, if any.

    Reads ``attention_config.use_trtllm_attention`` from the current vLLM
    config. Call it only during initialization, once the vLLM config is set.

    Returns:
        ``None`` if ``--attention-config.use_trtllm_attention`` is not set,
        ``True`` if TRTLLM attention is forced on, and ``False`` if it is
        forced off.
    """
    from vllm.config import get_current_vllm_config

    return get_current_vllm_config().attention_config.use_trtllm_attention
310
311


312
313
def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
    """Report whether TRTLLM attention is possible for this head layout.

    Returns ``False`` in any of these cases:

    - TRTLLM attention was explicitly forced off,
    - the platform does not support it,
    - ``num_qo_heads`` is not a multiple of ``num_kv_heads``.
    """
    if force_use_trtllm_attention() is False:
        return False
    if not supports_trtllm_attention():
        return False
    return num_qo_heads % num_kv_heads == 0
318
319


320
def use_trtllm_attention(
    num_qo_heads: int,
    num_kv_heads: int,
    num_tokens: int,
    max_seq_len: int,
    dcp_world_size: int,
    kv_cache_dtype: str,
    q_dtype: torch.dtype,
    is_prefill: bool,
    # None means auto-detection, True means force on, False means force off
    force_use_trtllm: bool | None = None,
    has_sinks: bool = False,
    has_spec: bool = False,
) -> bool:
    """Return `True` if TRTLLM attention is used.

    Rules are evaluated in order and the first match wins:

    1. ``force_use_trtllm`` set to a falsy value (e.g. ``False``/``0``) -> ``False``.
    2. Decode context parallelism (``dcp_world_size > 1``) -> ``False``.
    3. ``supports_trtllm_attention()`` is false -> ``False``.
    4. ``num_qo_heads`` is not a multiple of ``num_kv_heads`` -> ``False``.
    5. Speculative decoding on a decode step (``has_spec and not is_prefill``)
       -> ``True``.
    6. FP8 query (``q_dtype == current_platform.fp8_dtype()``) -> ``True``.
    7. Attention sinks (``has_sinks``) -> ``True``.
    8. No override (``force_use_trtllm is None``): prefill uses TRTLLM when
       ``kv_cache_dtype == "auto"``; decode additionally requires
       ``num_tokens <= 256``.
    9. Otherwise (forced on) -> ``True``.

    Rules 3 and 4 log a warning when the user forced TRTLLM on.

    Note: ``max_seq_len`` is accepted but not used by the current logic.
    """
    # CLI argument is set to 0 - respect it. `not` (rather than `is False`)
    # so that a falsy non-bool such as 0 is also honoured.
    if force_use_trtllm is not None and not force_use_trtllm:
        return False

    # Decode context parallel is not supported
    if dcp_world_size > 1:
        logger.warning_once(
            "Trtllm does not support returning LSE and as a result "
            "does not support DCP, reverting to FlashInfer"
        )
        return False

    # The platform is not supported
    if not supports_trtllm_attention():
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported on this platform, "
                "but --attention-config.use_trtllm_attention is set to 1"
            )
        return False

    # The combination of query and key heads is not supported
    if num_qo_heads % num_kv_heads != 0:
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported for this combination of "
                "query and key heads, but --attention-config.use_trtllm_attention is "
                "set to 1"
            )
        return False

    if has_spec and not is_prefill:
        # Speculative decoding requires TRTLLM attention for decodes
        logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
        return True

    # Must use TRTLLM attention if query is FP8 quantized
    if q_dtype == current_platform.fp8_dtype():
        logger.info_once("Using TRTLLM attention (query is quantized).")
        return True

    # If sinks are being used, we must use TRTLLM attention as it's
    # the only backend that supports them
    if has_sinks:
        logger.info_once("Using TRTLLM attention (required for attention sinks).")
        return True

    if force_use_trtllm is None:
        # CLI argument not set - use auto-detection. Auto-selection is an
        # expected, non-problematic outcome, so it is logged at info level
        # like the other "Using TRTLLM attention" messages.
        if is_prefill:
            # Prefill auto-detection
            use_trtllm = kv_cache_dtype == "auto"
            if use_trtllm:
                logger.info_once("Using TRTLLM prefill attention (auto-detected).")
        else:
            # Decode auto-detection: only small decode batches qualify
            use_trtllm = num_tokens <= 256 and kv_cache_dtype == "auto"
            if use_trtllm:
                logger.info_once("Using TRTLLM decode attention (auto-detected).")
        return use_trtllm

    # CLI argument is set to 1 - respect it
    logger.info_once(
        "Using TRTLLM attention (--attention-config.use_trtllm_attention is set to 1)"
    )
    return True

403

404
if has_flashinfer():
    # The torch custom ops below are registered only when FlashInfer is
    # usable. Each op body imports the FlashInfer function it wraps at call time.
    from vllm.utils.torch_utils import direct_register_custom_op

    def _flashinfer_concat_mla_k(
        k: torch.Tensor,
        k_nope: torch.Tensor,
        k_pe: torch.Tensor,
    ) -> None:
        """Custom op wrapper for flashinfer's concat_mla_k.

        This is an in-place operation that concatenates k_nope and k_pe into k.

        The kernel is optimized for DeepSeek V3 dimensions:
        - num_heads=128
        - nope_dim=128
        - rope_dim=64

        Key optimizations:
        - Warp-based processing with software pipelining
        - Vectorized memory access (int2 for nope, int for rope)
        - L2 prefetching for next row while processing current
        - Register reuse for rope values across all heads

        Args:
            k: Output tensor, shape [num_tokens, num_heads, nope_dim + rope_dim].
                Modified in-place.
            k_nope: The nope part of k, shape [num_tokens, num_heads, nope_dim].
            k_pe: The rope part of k (shared), shape [num_tokens, 1, rope_dim].
                  This is broadcast to all heads.
        """
        from flashinfer.concat_ops import concat_mla_k

        concat_mla_k(k, k_nope, k_pe)

    def _flashinfer_concat_mla_k_fake(
        k: torch.Tensor,
        k_nope: torch.Tensor,
        k_pe: torch.Tensor,
    ) -> None:
        # Fake implementation: the real op only mutates ``k`` in place and
        # returns nothing, so there is no output to allocate.
        return

    # Register flashinfer concat_mla_k custom op
    direct_register_custom_op(
        op_name="flashinfer_concat_mla_k",
        op_func=_flashinfer_concat_mla_k,
        mutates_args=["k"],  # k tensor is modified in-place
        fake_impl=_flashinfer_concat_mla_k_fake,
    )

    @torch.library.custom_op(
        "vllm::flashinfer_mm_fp4",
        mutates_args=[],
        device_types="cuda",
    )
    def flashinfer_mm_fp4(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        use_8x4_sf_layout: bool,
        backend: str,
    ) -> torch.Tensor:
        """FP4 GEMM through ``flashinfer.mm_fp4`` with a fixed block size of 16."""
        from flashinfer import mm_fp4 as flashinfer_mm_fp4_

        return flashinfer_mm_fp4_(
            A,
            B,
            A_scale,
            B_scale,
            g_scale,
            dtype,
            block_size=16,
            use_8x4_sf_layout=use_8x4_sf_layout,
            backend=backend,
        )

    @torch.library.register_fake(
        "vllm::flashinfer_mm_fp4",
    )
    def flashinfer_mm_fp4_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        use_8x4_sf_layout: bool,
        backend: str,
    ) -> torch.Tensor:
        # Output shape is [A.shape[0], B.shape[1]] in the requested dtype.
        return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device)

    @torch.library.custom_op(
        "vllm::bmm_fp8",
        mutates_args=[],
        device_types="cuda",
    )
    def bmm_fp8(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        """Batched FP8 GEMM through ``flashinfer.bmm_fp8`` (no preallocated output)."""
        from flashinfer import bmm_fp8 as bmm_fp8_

        # The positional ``None`` is passed where an output tensor would go.
        return bmm_fp8_(A, B, A_scale, B_scale, dtype, None, backend)

    @torch.library.register_fake(
        "vllm::bmm_fp8",
    )
    def bmm_fp8_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        # Batched output: [A.shape[0], A.shape[1], B.shape[2]].
        return torch.empty(
            A.shape[0], A.shape[1], B.shape[2], dtype=dtype, device=A.device
        )

    @torch.library.custom_op(
        "vllm::flashinfer_nvfp4_quantize",
        mutates_args=[],
        device_types="cuda",
    )
    def flashinfer_nvfp4_quantize(
        a: torch.Tensor, a_global_sf: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """NVFP4-quantize ``a`` with the 8x4 scale-factor layout and no shuffle."""
        from flashinfer import SfLayout
        from flashinfer import nvfp4_quantize as nvfp4_quantize_

        return nvfp4_quantize_(
            a, a_global_sf, sfLayout=SfLayout.layout_8x4, do_shuffle=False
        )

    @torch.library.register_fake(
        "vllm::flashinfer_nvfp4_quantize",
    )
    def flashinfer_nvfp4_quantize_fake(
        a: torch.Tensor, a_global_sf: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # For a 2D input [m, n], the outputs are:
        #   quantized data: [m, n // 2] uint8 (FP4 values packed two per byte)
        #   block scales:   [round_up(m, 8), round_up(n // 16, 4)] uint8
        m, n = a.shape

        def _round_up(x: int, multiple: int) -> int:
            return (x + multiple - 1) // multiple * multiple

        rounded_m = _round_up(m, 8)
        scale_n = n // 16
        rounded_n = _round_up(scale_n, 4)

        quantized = torch.empty(m, n // 2, dtype=torch.uint8, device=a.device)
        scales = torch.empty(rounded_m, rounded_n, dtype=torch.uint8, device=a.device)
        return quantized, scales

    @torch.library.custom_op(
        "vllm::mm_mxfp8",
        mutates_args=[],
        device_types="cuda",
    )
    def mm_mxfp8(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        out_dtype: torch.dtype,
        backend: str = "cutlass",
    ) -> torch.Tensor:
        """MXFP8 GEMM through ``flashinfer.mm_mxfp8`` (no preallocated output)."""
        from flashinfer import mm_mxfp8 as mm_mxfp8_

        return mm_mxfp8_(
            A,
            B,
            A_scale,
            B_scale,
            out=None,
            out_dtype=out_dtype,
            backend=backend,
        )

    @torch.library.register_fake(
        "vllm::mm_mxfp8",
    )
    def mm_mxfp8_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        out_dtype: torch.dtype,
        backend: str = "cutlass",
    ) -> torch.Tensor:
        # A is [m, k], B is [k, n] -> output [m, n]
        return torch.empty(A.shape[0], B.shape[1], dtype=out_dtype, device=A.device)


def flashinfer_mm_mxfp8(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str = "cutlass",
) -> torch.Tensor:
    """MXFP8 MM helper - mirrors flashinfer_scaled_fp4_mm API.

    Takes non-transposed weights (``b`` laid out as ``[N, K]``) and transposes
    them internally before calling the ``mm_mxfp8`` op.

    CRITICAL: mm_mxfp8 CUTLASS kernel requires SWIZZLED 1D scales for optimal
    performance and accuracy. Both input and weight scales should be in
    swizzled format from FlashInfer's mxfp8_quantize(is_sf_swizzled_layout=True).

    Args:
        a: Activations, shape ``[M, K]``.
        b: Weights, shape ``[N, K]`` (not transposed).
        block_scale_a: Block scales for ``a``.
        block_scale_b: Block scales for ``b``; must be 1D.
        out_dtype: dtype of the result.
        backend: FlashInfer backend name, ``"cutlass"`` by default.

    Returns:
        The ``[M, N]`` product.

    Raises:
        AssertionError: (when assertions are enabled) if ``a``/``b`` are not
            2D or their K dimensions differ.
        ValueError: if ``block_scale_b`` is not 1D.
    """
    # a shape [M, K]
    # b shape [N, K]  (non-transposed weight; K is the second dim of both)
    assert a.ndim == 2 and b.ndim == 2
    assert a.shape[1] == b.shape[1]  # K dimension must match

    if block_scale_b.ndim != 1:
        raise ValueError(
            "mm_mxfp8 expects 1D swizzled weight scales for CUTLASS; "
            f"got shape={tuple(block_scale_b.shape)}"
        )

    # Output tensor [M, N]
    return mm_mxfp8(
        a,
        b.t(),  # Transpose weight: [N, K] -> [K, N]
        block_scale_a,
        block_scale_b,
        out_dtype,
        backend=backend,
    )

639
640
641
642
643
644
645
646
647
648

def flashinfer_scaled_fp4_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    alpha: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str,
) -> torch.Tensor:
    """FP4 GEMM helper built on the ``flashinfer_mm_fp4`` custom op.

    ``a`` and ``b`` must be 2D, contiguous in their last dimension, and share
    the same second dimension. ``b`` is the non-transposed weight; it and
    ``block_scale_b`` are transposed before the call. ``alpha`` is forwarded as
    the op's ``g_scale`` argument.

    Backend-specific handling:

    - ``"cutlass"`` / ``"cudnn"``: block scales are reinterpreted as ``uint8``
      via ``view`` (no copy).
    - ``"trtllm"`` with at most 32 rows in ``a``: the 8x4 scale-factor layout
      is requested.

    Raises:
        AssertionError: (when assertions are enabled) on rank, layout or shape
            mismatches.
    """
    assert a.ndim == 2 and b.ndim == 2
    assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2
    assert a.stride(-1) == 1 and b.stride(-1) == 1
    assert a.shape[1] == b.shape[1]

    if backend in ("cutlass", "cudnn"):
        block_scale_a = block_scale_a.view(torch.uint8)
        block_scale_b = block_scale_b.view(torch.uint8)

    # Small-M TRTLLM calls use the 8x4 scale-factor layout.
    use_8x4_sf_layout = backend == "trtllm" and a.shape[0] <= 32

    return flashinfer_mm_fp4(
        a,
        b.t(),
        block_scale_a,
        block_scale_b.t(),
        alpha,
        out_dtype,
        use_8x4_sf_layout=use_8x4_sf_layout,
        backend=backend,
    )


672
def flashinfer_scaled_fp8_mm(
673
674
675
676
677
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
678
    bias: torch.Tensor | None = None,
679
) -> torch.Tensor:
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
    assert a.ndim == 2 and b.ndim == 2
    assert a.shape[1] == b.shape[0]
    assert scale_a.numel() == 1 and scale_b.numel() == 1
    assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn
    assert a.device.type == "cuda" and b.device.type == "cuda"
    assert scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32
    assert scale_a.device.type == "cuda" and scale_b.device.type == "cuda"

    output = bmm_fp8(
        a.unsqueeze(0),
        b.unsqueeze(0),
        scale_a,
        scale_b,
        out_dtype,
        "auto",
    ).view(a.shape[0], b.shape[1])

    if bias is not None:
        output = output + bias
    return output


702
703
704
705
706
707
def flashinfer_quant_nvfp4_8x4_sf_layout(
    a: torch.Tensor, a_global_sf: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """NVFP4-quantize ``a`` using the 8x4 scale-factor layout.

    Thin public alias for the ``flashinfer_nvfp4_quantize`` custom op, which
    is only registered when FlashInfer is available.

    Returns:
        ``(quantized, block_scales)`` as produced by the custom op.
    """
    quantize_op = flashinfer_nvfp4_quantize
    return quantize_op(a, a_global_sf)


708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
# Lazily resolved SM90 block-scale FP8 GEMM from flashinfer.gemm. Calls fall
# back to _missing (which raises) when it cannot be resolved.
flashinfer_fp8_blockscale_gemm = _lazy_import_wrapper(
    module_name="flashinfer.gemm",
    attr_name="fp8_blockscale_gemm_sm90",
)


@functools.cache
def has_flashinfer_fp8_blockscale_gemm() -> bool:
    """Report whether FlashInfer's block-scale FP8 GEMM can run here.

    Requires all of the following:

    - FlashInfer is usable,
    - the device has compute capability 9.0,
    - ``flashinfer.gemm`` exports ``fp8_blockscale_gemm_sm90``.
    """
    if not has_flashinfer():
        return False
    if not current_platform.is_device_capability(90):
        return False
    # hasattr(None, ...) is False, so a failed import also yields False.
    gemm_module = _get_submodule("flashinfer.gemm")
    return hasattr(gemm_module, "fp8_blockscale_gemm_sm90")


@functools.cache
def is_flashinfer_fp8_blockscale_gemm_supported() -> bool:
    """Report whether FlashInfer block-scale FP8 GEMM is enabled and available.

    ``envs.VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER`` must be truthy and
    ``has_flashinfer_fp8_blockscale_gemm()`` must hold.
    """
    enabled = envs.VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER
    if not enabled:
        return enabled
    return has_flashinfer_fp8_blockscale_gemm()


def should_use_flashinfer_for_blockscale_fp8_gemm(
    is_flashinfer_supported: bool,
    output_dtype: torch.dtype,
    input: torch.Tensor,
    weight: torch.Tensor,
):
    if not is_flashinfer_supported:
        return False

    # Verify DeepGEMM N/K dims requirements
    # NOTE: Also synchronized with test_w8a8_block_fp8_deep_gemm_matmul
743
    # test inside kernels/quantization/test_block_fp8.py
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
    N_MULTIPLE = 64
    K_MULTIPLE = 128

    weight_dtype = weight.dtype
    input_dtype = input.dtype

    should_use_flashinfer = (
        output_dtype == torch.bfloat16
        and input_dtype == torch.bfloat16
        and weight_dtype == torch.float8_e4m3fn
        and weight.shape[0] % N_MULTIPLE == 0
        and weight.shape[1] % K_MULTIPLE == 0
    )

    return should_use_flashinfer


761
762
# Public API of this module. Only names that are always defined are listed.
# The custom ops created under ``if has_flashinfer():`` exist only when
# FlashInfer is available, so listing them would break star-imports otherwise.
__all__ = [
    "has_flashinfer",
    "has_flashinfer_cubin",
    "flashinfer_trtllm_bf16_moe",
    "flashinfer_trtllm_fp8_block_scale_moe",
    "flashinfer_trtllm_fp8_per_tensor_scale_moe",
    "flashinfer_cutlass_fused_moe",
    "flashinfer_cutedsl_grouped_gemm_nt_masked",
    "flashinfer_fp4_quantize",
    "nvfp4_batched_quantize",
    "silu_and_mul_scaled_nvfp4_experts_quantize",
    "scaled_fp4_grouped_quantize",
    "nvfp4_block_scale_interleave",
    "trtllm_fp4_block_scale_moe",
    "autotune",
    "has_flashinfer_moe",
    "has_flashinfer_comm",
    "has_flashinfer_nvlink_two_sided",
    "has_flashinfer_nvlink_one_sided",
    "has_flashinfer_cutedsl",
    "has_flashinfer_trtllm_fused_moe",
    "has_flashinfer_cutlass_fused_moe",
    "has_flashinfer_cutedsl_grouped_gemm_nt_masked",
    "has_flashinfer_fp8_blockscale_gemm",
    "has_nvidia_artifactory",
    "supports_trtllm_attention",
    "force_use_trtllm_attention",
    "can_use_trtllm_attention",
    "use_trtllm_attention",
    "flashinfer_mm_mxfp8",
    "flashinfer_scaled_fp4_mm",
    "flashinfer_scaled_fp8_mm",
    "flashinfer_quant_nvfp4_8x4_sf_layout",
    "flashinfer_fp8_blockscale_gemm",
    "should_use_flashinfer_for_blockscale_fp8_gemm",
    "is_flashinfer_fp8_blockscale_gemm_supported",
]