# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for FlashInfer API changes.

Code in vLLM should always access FlashInfer **only** through these wrappers.
"""
import contextlib
import functools
import importlib
import importlib.util
import os
import shutil
from collections.abc import Callable
from typing import Any, NoReturn

import requests
import torch

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
from vllm.platforms import current_platform

logger = init_logger(__name__)

# Base URL from which FlashInfer's prebuilt cubins are fetched. Setting the
# FLASHINFER_CUBINS_REPOSITORY environment variable overrides it (for example
# with a local path when testing).
# Mirrors https://github.com/flashinfer-ai/flashinfer/blob/0c9a92c3d9a7e043ab6f3f7b2273269caf6ab044/flashinfer/jit/cubin_loader.py#L35  # noqa: E501
FLASHINFER_CUBINS_REPOSITORY = os.getenv(
    "FLASHINFER_CUBINS_REPOSITORY",
    "https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/",  # noqa: E501
)


@functools.cache
def has_flashinfer_cubin() -> bool:
    """Report whether prebuilt FlashInfer cubins can be used.

    Returns `True` when the `VLLM_HAS_FLASHINFER_CUBIN` flag is set, or when
    the `flashinfer_cubin` package can be located via `find_spec`. Otherwise
    a debug message is logged once and `False` is returned.
    """
    cubin_available = bool(envs.VLLM_HAS_FLASHINFER_CUBIN) or (
        importlib.util.find_spec("flashinfer_cubin") is not None
    )
    if not cubin_available:
        logger.debug_once("flashinfer-cubin package was not found")
    return cubin_available


@functools.cache
def has_flashinfer() -> bool:
    """Return `True` if flashinfer-python package is available.

    The package must be locatable. In addition, unless prebuilt cubins are
    available (`has_flashinfer_cubin()`), `nvcc` must be found on PATH so that
    FlashInfer kernels can be JIT-compiled. The result is cached.
    """
    # Use find_spec to check if the module exists without importing it.
    # This avoids potential CUDA initialization side effects.
    if importlib.util.find_spec("flashinfer") is None:
        logger.debug_once("FlashInfer unavailable since package was not found")
        return False
    # When not using FlashInfer cubins, nvcc is required to JIT compile
    # FlashInfer kernels.
    if not has_flashinfer_cubin() and shutil.which("nvcc") is None:
        logger.debug_once(
            "FlashInfer unavailable since nvcc was not found "
            "and not using pre-downloaded cubins"
        )
        return False
    return True


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Placeholder for unavailable FlashInfer backend."""
    raise RuntimeError(
        "FlashInfer backend is not available. Please install the package "
        "to enable FlashInfer kernels: "
73
74
        "https://github.com/flashinfer-ai/flashinfer"
    )
75
76
77
78
79
80
81
82
83
84
85


def _get_submodule(module_name: str) -> Any | None:
    """Safely import a submodule and return it, or None if not available."""
    try:
        return importlib.import_module(module_name)
    except (ImportError, ModuleNotFoundError):
        return None


# General lazy import wrapper
def _lazy_import_wrapper(
    module_name: str, attr_name: str, fallback_fn: Callable[..., Any] = _missing
) -> Callable[..., Any]:
    """Create a lazy import wrapper for a specific function.

    On its first call, the returned callable resolves
    ``<module_name>.<attr_name>``. The lookup is memoized with
    ``functools.cache``, including a negative result. The wrapper then
    forwards all arguments to the resolved function.

    ``fallback_fn`` is called with the same arguments instead when any of
    these holds:
    - FlashInfer is unavailable (``has_flashinfer()`` is false);
    - the module cannot be imported;
    - the attribute is missing.

    Args:
        module_name: Fully qualified module to import, e.g.
            ``"flashinfer.fused_moe"``.
        attr_name: Name of the callable to look up in that module.
        fallback_fn: Used when the implementation cannot be resolved.
            Defaults to `_missing`, which raises `RuntimeError`.

    Returns:
        A wrapper callable whose ``__name__`` is ``attr_name``.
    """

    @functools.cache
    def _get_impl():
        if not has_flashinfer():
            return None
        mod = _get_submodule(module_name)
        return getattr(mod, attr_name, None) if mod else None

    def wrapper(*args, **kwargs):
        impl = _get_impl()
        if impl is None:
            return fallback_fn(*args, **kwargs)
        return impl(*args, **kwargs)

    # Name the wrapper after its target so logs, tracebacks and profilers show
    # e.g. "cutlass_fused_moe" rather than a generic "wrapper".
    wrapper.__name__ = attr_name
    wrapper.__qualname__ = attr_name
    return wrapper


# Create lazy wrappers for each function. Each resolves its FlashInfer target
# on first call and raises via `_missing` if it cannot be resolved.
flashinfer_trtllm_bf16_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_bf16_moe"
)
flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"
)
flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"
)
flashinfer_cutlass_fused_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "cutlass_fused_moe"
)
flashinfer_cutedsl_grouped_gemm_nt_masked = _lazy_import_wrapper(
    "flashinfer.cute_dsl.blockscaled_gemm", "grouped_gemm_nt_masked"
)
flashinfer_fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
nvfp4_batched_quantize = _lazy_import_wrapper("flashinfer", "nvfp4_batched_quantize")
silu_and_mul_scaled_nvfp4_experts_quantize = _lazy_import_wrapper(
    "flashinfer", "silu_and_mul_scaled_nvfp4_experts_quantize"
)
scaled_fp4_grouped_quantize = _lazy_import_wrapper(
    "flashinfer", "scaled_fp4_grouped_quantize"
)
# NOTE: the vLLM-side name differs from the FlashInfer symbol it resolves to
# (`flashinfer.fp4_quantization.block_scale_interleave`).
nvfp4_block_scale_interleave = _lazy_import_wrapper(
    "flashinfer.fp4_quantization", "block_scale_interleave"
)
trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
    "flashinfer", "trtllm_fp4_block_scale_moe"
)

# Special case for autotune since it returns a context manager: when FlashInfer
# is unavailable it degrades to a no-op `contextlib.nullcontext()` instead of
# raising.
autotune = _lazy_import_wrapper(
    "flashinfer.autotuner",
    "autotune",
    fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(),
)
# NOTE(review): module-level flag that nothing in this file reads or writes
# after initialization; presumably toggled by callers while autotuning runs.
_is_fi_autotuning: bool = False


@functools.cache
def has_flashinfer_comm() -> bool:
    """Return `True` if FlashInfer comm module is available.

    Requires `has_flashinfer()` and a locatable ``flashinfer.comm`` module
    (checked with ``find_spec``, without importing it).
    """
    return has_flashinfer() and importlib.util.find_spec("flashinfer.comm") is not None


@functools.cache
def has_flashinfer_nvlink_two_sided() -> bool:
    """Return `True` if FlashInfer mnnvl all2all is available.

    Requires the FlashInfer comm module, and every (module, attribute) pair
    listed below must be importable.
    """
    if not has_flashinfer_comm():
        return False

    # Check if all required functions are available
    required_functions = [
        ("flashinfer.comm", "Mapping"),
        ("flashinfer.comm.mnnvl", "MnnvlMemory"),
        ("flashinfer.comm.trtllm_alltoall", "MnnvlMoe"),
        ("flashinfer.comm.trtllm_alltoall", "MoEAlltoallInfo"),
    ]

    for module_name, attr_name in required_functions:
        mod = _get_submodule(module_name)
        if not mod or not hasattr(mod, attr_name):
            return False
    return True


@functools.cache
def has_flashinfer_nvlink_one_sided() -> bool:
    """Report whether FlashInfer's ``trtllm_moe_alltoall`` comm module exists.

    The module is only looked up (via ``find_spec``, without importing it)
    once the FlashInfer comm package itself is known to be available.
    """
    return has_flashinfer_comm() and (
        importlib.util.find_spec("flashinfer.comm.trtllm_moe_alltoall") is not None
    )


@functools.cache
def has_flashinfer_moe() -> bool:
    """Return `True` if FlashInfer MoE module is available.

    Requires `has_flashinfer()` and a locatable ``flashinfer.fused_moe``
    module (checked with ``find_spec``, without importing it).
    """
    return (
        has_flashinfer()
        and importlib.util.find_spec("flashinfer.fused_moe") is not None
    )


@functools.cache
def has_flashinfer_cutedsl() -> bool:
    """Report whether FlashInfer's ``cute_dsl`` subpackage can be located.

    The lookup uses ``find_spec`` and is skipped entirely when FlashInfer
    itself is unavailable.
    """
    if not has_flashinfer():
        return False
    return importlib.util.find_spec("flashinfer.cute_dsl") is not None


@functools.cache
def has_flashinfer_trtllm_fused_moe() -> bool:
    """Return `True` if FlashInfer TRTLLM fused MoE is available.

    Requires the FlashInfer MoE module, and every TRTLLM MoE entry point
    listed below must exist on ``flashinfer.fused_moe``.
    """
    if not has_flashinfer_moe():
        return False
    required_functions = [
        ("flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"),
        ("flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"),
        ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
        ("flashinfer.fused_moe", "trtllm_mxint4_block_scale_moe"),
    ]
    for module_name, attr_name in required_functions:
        mod = _get_submodule(module_name)
        if not mod or not hasattr(mod, attr_name):
            return False
    return True


@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
    """Return `True` if FlashInfer CUTLASS fused MoE is available.

    Requires the FlashInfer MoE module, and every (module, attribute) pair
    listed below must be importable.
    """
    if not has_flashinfer_moe():
        return False

    # Check if all required functions are available
    required_functions = [
        ("flashinfer.fused_moe", "cutlass_fused_moe"),
        ("flashinfer", "fp4_quantize"),
        ("flashinfer", "nvfp4_block_scale_interleave"),
        ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
    ]

    for module_name, attr_name in required_functions:
        mod = _get_submodule(module_name)
        if not mod or not hasattr(mod, attr_name):
            return False
    return True


@functools.cache
def has_flashinfer_cutedsl_grouped_gemm_nt_masked() -> bool:
    """Return ``True`` if FlashInfer's CuTe-DSL masked grouped GEMM is available.

    Requires the FlashInfer ``cute_dsl`` subpackage, ``grouped_gemm_nt_masked``
    from ``flashinfer.cute_dsl.blockscaled_gemm``, and the NVFP4 quantization
    helpers ``scaled_fp4_grouped_quantize`` and
    ``silu_and_mul_scaled_nvfp4_experts_quantize`` from ``flashinfer``.
    """
    if not has_flashinfer_cutedsl():
        return False

    # Check if all required functions are available
    required_functions = [
        ("flashinfer.cute_dsl.blockscaled_gemm", "grouped_gemm_nt_masked"),
        ("flashinfer", "scaled_fp4_grouped_quantize"),
        ("flashinfer", "silu_and_mul_scaled_nvfp4_experts_quantize"),
    ]

    for module_name, attr_name in required_functions:
        mod = _get_submodule(module_name)
        if not mod or not hasattr(mod, attr_name):
            return False
    return True


@functools.cache
def has_nvidia_artifactory() -> bool:
    """Return `True` if NVIDIA's artifactory is accessible.

    This checks connectivity to the kernel inference library artifactory,
    which is required for downloading certain cubin kernels like TRTLLM FMHA.

    If pre-downloaded cubins are available (`has_flashinfer_cubin()`), no
    request is made and `True` is returned. Otherwise a GET with a 5 second
    timeout is sent to `FLASHINFER_CUBINS_REPOSITORY`, and only HTTP 200
    counts as accessible. The result is cached for the process.
    """
    # If we have pre-downloaded cubins, we can assume the cubins are available.
    if has_flashinfer_cubin():
        return True

    try:
        # Use a short timeout to avoid blocking for too long
        response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5)
        accessible = response.status_code == 200
        if accessible:
            logger.debug_once("NVIDIA artifactory is accessible")
        else:
            logger.warning_once(
                "NVIDIA artifactory returned failed status code: %d",
                response.status_code,
            )
        return accessible
    except Exception as e:
        # Deliberate best effort: any failure (DNS, timeout, TLS, ...) is
        # reported once and treated as "not accessible".
        logger.warning_once("Failed to connect to NVIDIA artifactory: %s", e)
        return False


@functools.cache
def supports_trtllm_attention() -> bool:
    """Return `True` if TRTLLM attention is supported.

    TRTLLM attention is supported if the platform is SM100,
    NVIDIA artifactory is accessible, and batch-invariant mode is not enabled.
    The result is cached.
    """
    # Batch-invariant mode disables TRTLLM attention
    if vllm_is_batch_invariant():
        return False

    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
    return (
        current_platform.is_device_capability_family(100) and has_nvidia_artifactory()
    )


def force_use_trtllm_attention() -> bool | None:
    """Return the user's explicit TRTLLM attention setting, if any.

    This function should only be called during initialization stage when vllm
    config is set.

    Returns:
        `None` if --attention-config.use_trtllm_attention is not set,
        `True` if TRTLLM attention is forced to be used,
        `False` if TRTLLM attention is forced to be not used.
    """
    # Imported lazily, presumably to avoid an import cycle with vllm.config.
    from vllm.config import get_current_vllm_config

    vllm_config = get_current_vllm_config()
    return vllm_config.attention_config.use_trtllm_attention


def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
    """Check if the current configuration supports TRTLLM attention.

    Returns `False` in any of these cases:
    - TRTLLM attention was explicitly disabled via
      `--attention-config.use_trtllm_attention`;
    - the platform does not support it (`supports_trtllm_attention()`);
    - ``num_qo_heads`` is not a multiple of ``num_kv_heads``.
    """
    if force_use_trtllm_attention() is False:
        return False
    has_trtllm = supports_trtllm_attention()
    return has_trtllm and (num_qo_heads % num_kv_heads == 0)


def use_trtllm_attention(
    num_qo_heads: int,
    num_kv_heads: int,
    num_tokens: int,
    max_seq_len: int,
    dcp_world_size: int,
    kv_cache_dtype: str,
    q_dtype: torch.dtype,
    is_prefill: bool,
    # None means auto-detection, True means force on, False means force off
    force_use_trtllm: bool | None = None,
    has_sinks: bool = False,
    has_spec: bool = False,
) -> bool:
    """Return `True` if TRTLLM attention is used.

    Checks are applied in this order:
    1. An explicit falsy ``force_use_trtllm`` (forced off) disables it.
    2. DCP (``dcp_world_size > 1``) disables it, with a warning.
    3. An unsupported platform, or ``num_qo_heads`` that is not a multiple of
       ``num_kv_heads``, disables it. A warning is logged only if it was
       forced on.
    4. It is required, and therefore used, for any of:
       - speculative decoding on a decode step;
       - an FP8 query;
       - attention sinks.
    5. With ``force_use_trtllm is None`` it is auto-detected:
       - prefill uses it when ``kv_cache_dtype == "auto"``;
       - decode uses it when ``kv_cache_dtype == "auto"`` and
         ``num_tokens <= 256``.
    6. Otherwise (forced on) it is used.

    Note: ``max_seq_len`` is accepted but not currently consulted.
    """
    # CLI argument is set to 0 - respect it
    if force_use_trtllm is not None and not force_use_trtllm:
        return False

    # Decode context parallel is not supported
    if dcp_world_size > 1:
        logger.warning_once(
            "Trtllm does not support returning LSE and as a result "
            "does not support DCP, reverting to FlashInfer"
        )
        return False

    # The platform is not supported
    if not supports_trtllm_attention():
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported on this platform, "
                "but --attention-config.use_trtllm_attention is set to 1"
            )
        return False

    # The combination of query and key heads is not supported
    if num_qo_heads % num_kv_heads != 0:
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported for this combination of "
                "query and key heads, but --attention-config.use_trtllm_attention is "
                "set to 1"
            )
        return False

    if has_spec and not is_prefill:
        # Speculative decoding requires TRTLLM attention for decodes
        logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
        return True

    # Must use TRTLLM attention if query is FP8 quantized
    if q_dtype == current_platform.fp8_dtype():
        logger.info_once("Using TRTLLM attention (query is quantized).")
        return True

    # If sinks are being used, we must use TRTLLM attention as it's
    # the only backend that supports them
    if has_sinks:
        logger.info_once("Using TRTLLM attention (required for attention sinks).")
        return True

    if force_use_trtllm is None:
        # CLI argument not set - use auto-detection
        if is_prefill:
            # Prefill auto-detection
            use_trtllm = kv_cache_dtype == "auto"
            if use_trtllm:
                logger.warning_once("Using TRTLLM prefill attention (auto-detected).")
        else:
            # Decode auto-detection
            use_trtllm = num_tokens <= 256 and kv_cache_dtype == "auto"
            if use_trtllm:
                logger.warning_once("Using TRTLLM decode attention (auto-detected).")
        return use_trtllm

    # CLI argument is set to 1 - respect it
    logger.info_once(
        "Using TRTLLM attention (--attention-config.use_trtllm_attention is set to 1)"
    )
    return True


if has_flashinfer():
    # The custom ops below are only defined and registered when FlashInfer is
    # available. Otherwise names such as `flashinfer_mm_fp4`, `bmm_fp8`,
    # `flashinfer_nvfp4_quantize` and `mm_mxfp8` do not exist in this module.
    from vllm.utils.torch_utils import direct_register_custom_op

    def _flashinfer_concat_mla_k(
        k: torch.Tensor,
        k_nope: torch.Tensor,
        k_pe: torch.Tensor,
    ) -> None:
        """Custom op wrapper for flashinfer's concat_mla_k.

        This is an in-place operation that concatenates k_nope and k_pe into k.

        The kernel is optimized for DeepSeek V3 dimensions:
        - num_heads=128
        - nope_dim=128
        - rope_dim=64

        Key optimizations:
        - Warp-based processing with software pipelining
        - Vectorized memory access (int2 for nope, int for rope)
        - L2 prefetching for next row while processing current
        - Register reuse for rope values across all heads

        Args:
            k: Output tensor, shape [num_tokens, num_heads, nope_dim + rope_dim].
                Modified in-place.
            k_nope: The nope part of k, shape [num_tokens, num_heads, nope_dim].
            k_pe: The rope part of k (shared), shape [num_tokens, 1, rope_dim].
                  This is broadcast to all heads.
        """
        from flashinfer.concat_ops import concat_mla_k

        concat_mla_k(k, k_nope, k_pe)

    def _flashinfer_concat_mla_k_fake(
        k: torch.Tensor,
        k_nope: torch.Tensor,
        k_pe: torch.Tensor,
    ) -> None:
        # The real op only mutates `k` in place and returns nothing, so the
        # fake implementation has nothing to produce.
        return

    # Register flashinfer concat_mla_k custom op
    direct_register_custom_op(
        op_name="flashinfer_concat_mla_k",
        op_func=_flashinfer_concat_mla_k,
        mutates_args=["k"],  # k tensor is modified in-place
        fake_impl=_flashinfer_concat_mla_k_fake,
    )

    @torch.library.custom_op(
        "vllm::flashinfer_mm_fp4",
        mutates_args=[],
        device_types="cuda",
    )
    def flashinfer_mm_fp4(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        use_8x4_sf_layout: bool,
        backend: str,
    ) -> torch.Tensor:
        """Custom op around ``flashinfer.mm_fp4`` with ``block_size=16``."""
        from flashinfer import mm_fp4 as flashinfer_mm_fp4_

        return flashinfer_mm_fp4_(
            A,
            B,
            A_scale,
            B_scale,
            g_scale,
            dtype,
            block_size=16,
            use_8x4_sf_layout=use_8x4_sf_layout,
            backend=backend,
        )

    @torch.library.register_fake(
        "vllm::flashinfer_mm_fp4",
    )
    def flashinfer_mm_fp4_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        use_8x4_sf_layout: bool,
        backend: str,
    ) -> torch.Tensor:
        # Output is [A.shape[0], B.shape[1]] in the requested dtype.
        return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device)

    @torch.library.custom_op(
        "vllm::bmm_fp8",
        mutates_args=[],
        device_types="cuda",
    )
    def bmm_fp8(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        """Custom op around ``flashinfer.bmm_fp8`` (batched FP8 matmul)."""
        from flashinfer import bmm_fp8 as bmm_fp8_

        # NOTE(review): the positional `None` is presumably FlashInfer's
        # optional output tensor, letting it allocate the result — confirm.
        return bmm_fp8_(A, B, A_scale, B_scale, dtype, None, backend)

    @torch.library.register_fake(
        "vllm::bmm_fp8",
    )
    def bmm_fp8_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        # Output shape is (A.shape[0], A.shape[1], B.shape[2]).
        return torch.empty(
            A.shape[0], A.shape[1], B.shape[2], dtype=dtype, device=A.device
        )

    @torch.library.custom_op(
        "vllm::flashinfer_nvfp4_quantize",
        mutates_args=[],
        device_types="cuda",
    )
    def flashinfer_nvfp4_quantize(
        a: torch.Tensor, a_global_sf: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Quantize ``a`` to NVFP4 with scale factors in the 8x4 layout."""
        from flashinfer import SfLayout
        from flashinfer import nvfp4_quantize as nvfp4_quantize_

        return nvfp4_quantize_(
            a, a_global_sf, sfLayout=SfLayout.layout_8x4, do_shuffle=False
        )

    @torch.library.register_fake(
        "vllm::flashinfer_nvfp4_quantize",
    )
    def flashinfer_nvfp4_quantize_fake(
        a: torch.Tensor, a_global_sf: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        m, n = a.shape

        def round_up(x: int, y: int) -> int:
            return (x + y - 1) // y * y

        # Quantized values occupy n // 2 uint8 columns. There is one scale per
        # 16 input columns, and the scale tensor is padded to multiples of
        # 8 rows x 4 columns to match the 8x4 scale-factor layout.
        rounded_m = round_up(m, 8)
        scale_n = n // 16
        rounded_n = round_up(scale_n, 4)

        return (
            torch.empty(m, n // 2, dtype=torch.uint8, device=a.device),
            torch.empty(rounded_m, rounded_n, dtype=torch.uint8, device=a.device),
        )

    @torch.library.custom_op(
        "vllm::mm_mxfp8",
        mutates_args=[],
        device_types="cuda",
    )
    def mm_mxfp8(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        out_dtype: torch.dtype,
        backend: str = "cutlass",
    ) -> torch.Tensor:
        """Custom op around ``flashinfer.mm_mxfp8``; FlashInfer allocates out."""
        from flashinfer import mm_mxfp8 as mm_mxfp8_

        return mm_mxfp8_(
            A,
            B,
            A_scale,
            B_scale,
            out=None,
            out_dtype=out_dtype,
            backend=backend,
        )

    @torch.library.register_fake(
        "vllm::mm_mxfp8",
    )
    def mm_mxfp8_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        out_dtype: torch.dtype,
        backend: str = "cutlass",
    ) -> torch.Tensor:
        # A is [m, k], B is [k, n] -> output [m, n]
        return torch.empty(A.shape[0], B.shape[1], dtype=out_dtype, device=A.device)


def flashinfer_mm_mxfp8(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str = "cutlass",
) -> torch.Tensor:
    """MXFP8 MM helper - mirrors flashinfer_scaled_fp4_mm API.

    Takes non-transposed weights and handles transpose internally.

    CRITICAL: mm_mxfp8 CUTLASS kernel requires SWIZZLED 1D scales for optimal
    performance and accuracy. Both input and weight scales should be in
    swizzled format from FlashInfer's mxfp8_quantize(is_sf_swizzled_layout=True).

    Requires FlashInfer: the ``mm_mxfp8`` custom op it calls is only defined
    when ``has_flashinfer()`` is true.

    Args:
        a: Activations, shape [M, K].
        b: Weights in their stored (non-transposed) layout, shape [N, K].
        block_scale_a: Swizzled block scales for ``a``.
        block_scale_b: 1D swizzled block scales for ``b``.
        out_dtype: dtype of the result.
        backend: FlashInfer backend name forwarded to ``mm_mxfp8``.

    Returns:
        The product ``a @ b.T`` of shape [M, N].

    Raises:
        ValueError: if ``block_scale_b`` is not 1D.
    """
    # a shape [M, K]
    # b shape [N, K] (non-transposed weights; transposed below)
    assert a.ndim == 2 and b.ndim == 2
    assert a.shape[1] == b.shape[1]  # K dimension must match

    if block_scale_b.ndim != 1:
        raise ValueError(
            "mm_mxfp8 expects 1D swizzled weight scales for CUTLASS; "
            f"got shape={tuple(block_scale_b.shape)}"
        )

    # Output tensor [M, N]
    return mm_mxfp8(
        a,
        b.t(),  # Transpose weight: [N, K] -> [K, N]
        block_scale_a,
        block_scale_b,
        out_dtype,
        backend=backend,
    )


def flashinfer_scaled_fp4_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    alpha: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str,
) -> torch.Tensor:
    """Run an FP4 GEMM through the ``vllm::flashinfer_mm_fp4`` custom op.

    Requires FlashInfer: ``flashinfer_mm_fp4`` is only defined when
    ``has_flashinfer()`` is true.

    Args:
        a: 2D input, contiguous along its last dimension.
        b: 2D weight with the same last-dimension size as ``a``, contiguous
            along its last dimension; it is transposed before the call.
        block_scale_a: 2D block scales for ``a``.
        block_scale_b: 2D block scales for ``b``; transposed alongside ``b``.
        alpha: Global scale, forwarded as ``g_scale``.
        out_dtype: dtype of the result.
        backend: FlashInfer backend name (e.g. "cutlass", "cudnn", "trtllm").
    """
    assert a.ndim == 2 and b.ndim == 2
    assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2
    assert a.stride(-1) == 1 and b.stride(-1) == 1
    assert a.shape[1] == b.shape[1]

    # The cutlass/cudnn backends receive the block scales reinterpreted as uint8.
    if backend in ("cutlass", "cudnn"):
        block_scale_a = block_scale_a.view(torch.uint8)
        block_scale_b = block_scale_b.view(torch.uint8)

    # trtllm switches to the 8x4 scale-factor layout for small M (<= 32 rows).
    # NOTE(review): the explicit True/False presumably forces a concrete Python
    # bool (e.g. when a.shape[0] is symbolic under torch.compile); keep it.
    use_8x4_sf_layout = True if backend == "trtllm" and a.shape[0] <= 32 else False  # noqa: SIM210

    return flashinfer_mm_fp4(
        a,
        b.t(),
        block_scale_a,
        block_scale_b.t(),
        alpha,
        out_dtype,
        use_8x4_sf_layout=use_8x4_sf_layout,
        backend=backend,
    )


def flashinfer_scaled_fp8_mm(
676
677
678
679
680
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
681
    bias: torch.Tensor | None = None,
682
) -> torch.Tensor:
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
    assert a.ndim == 2 and b.ndim == 2
    assert a.shape[1] == b.shape[0]
    assert scale_a.numel() == 1 and scale_b.numel() == 1
    assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn
    assert a.device.type == "cuda" and b.device.type == "cuda"
    assert scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32
    assert scale_a.device.type == "cuda" and scale_b.device.type == "cuda"

    output = bmm_fp8(
        a.unsqueeze(0),
        b.unsqueeze(0),
        scale_a,
        scale_b,
        out_dtype,
        "auto",
    ).view(a.shape[0], b.shape[1])

    if bias is not None:
        output = output + bias
    return output


705
706
707
708
709
710
def flashinfer_quant_nvfp4_8x4_sf_layout(
    a: torch.Tensor, a_global_sf: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``a`` to NVFP4 with scale factors in the 8x4 layout.

    Public alias for the ``vllm::flashinfer_nvfp4_quantize`` custom op, which
    is only defined when ``has_flashinfer()`` is true.

    Returns:
        ``(quantized, scale_factors)`` as produced by that op.
    """
    quantized, scale_factors = flashinfer_nvfp4_quantize(a, a_global_sf)
    return quantized, scale_factors


# Lazily resolved FlashInfer block-scale FP8 GEMM (SM90 kernel). Calls fall
# back to `_missing` (which raises) when FlashInfer or the symbol is absent.
flashinfer_fp8_blockscale_gemm = _lazy_import_wrapper(
    module_name="flashinfer.gemm",
    attr_name="fp8_blockscale_gemm_sm90",
)


@functools.cache
def has_flashinfer_fp8_blockscale_gemm() -> bool:
    """Report whether FlashInfer's SM90 block-scale FP8 GEMM can be used.

    Three conditions must all hold:
    - FlashInfer itself is available;
    - the current device matches capability 90;
    - ``flashinfer.gemm`` exposes ``fp8_blockscale_gemm_sm90``.
    """
    if not has_flashinfer():
        return False
    if not current_platform.is_device_capability(90):
        return False
    return hasattr(_get_submodule("flashinfer.gemm"), "fp8_blockscale_gemm_sm90")


@functools.cache
def is_flashinfer_fp8_blockscale_gemm_supported() -> bool:
    """Report whether the FlashInfer block-scale FP8 GEMM path is enabled.

    The ``VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER`` flag must be truthy. Only then
    is `has_flashinfer_fp8_blockscale_gemm()` consulted.
    """
    enabled_by_env = envs.VLLM_BLOCKSCALE_FP8_GEMM_FLASHINFER
    return enabled_by_env and has_flashinfer_fp8_blockscale_gemm()


def should_use_flashinfer_for_blockscale_fp8_gemm(
    is_flashinfer_supported: bool,
    output_dtype: torch.dtype,
    input: torch.Tensor,
    weight: torch.Tensor,
):
    if not is_flashinfer_supported:
        return False

    # Verify DeepGEMM N/K dims requirements
    # NOTE: Also synchronized with test_w8a8_block_fp8_deep_gemm_matmul
746
    # test inside kernels/quantization/test_block_fp8.py
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
    N_MULTIPLE = 64
    K_MULTIPLE = 128

    weight_dtype = weight.dtype
    input_dtype = input.dtype

    should_use_flashinfer = (
        output_dtype == torch.bfloat16
        and input_dtype == torch.bfloat16
        and weight_dtype == torch.float8_e4m3fn
        and weight.shape[0] % N_MULTIPLE == 0
        and weight.shape[1] % K_MULTIPLE == 0
    )

    return should_use_flashinfer


764
765
# Public API of this module. Only names that are always defined are listed;
# the custom ops registered under `if has_flashinfer():` are omitted because
# they do not exist when FlashInfer is unavailable.
__all__ = [
    "has_flashinfer",
    "has_flashinfer_cubin",
    "flashinfer_trtllm_bf16_moe",
    "flashinfer_trtllm_fp8_block_scale_moe",
    "flashinfer_trtllm_fp8_per_tensor_scale_moe",
    "flashinfer_cutlass_fused_moe",
    "flashinfer_cutedsl_grouped_gemm_nt_masked",
    "flashinfer_fp4_quantize",
    "nvfp4_batched_quantize",
    "silu_and_mul_scaled_nvfp4_experts_quantize",
    "scaled_fp4_grouped_quantize",
    "nvfp4_block_scale_interleave",
    "trtllm_fp4_block_scale_moe",
    "autotune",
    "has_flashinfer_moe",
    "has_flashinfer_comm",
    "has_flashinfer_nvlink_two_sided",
    "has_flashinfer_nvlink_one_sided",
    "has_flashinfer_cutedsl",
    "has_flashinfer_trtllm_fused_moe",
    "has_flashinfer_cutlass_fused_moe",
    "has_flashinfer_cutedsl_grouped_gemm_nt_masked",
    "has_flashinfer_fp8_blockscale_gemm",
    "has_nvidia_artifactory",
    "supports_trtllm_attention",
    "force_use_trtllm_attention",
    "can_use_trtllm_attention",
    "use_trtllm_attention",
    "flashinfer_mm_mxfp8",
    "flashinfer_scaled_fp4_mm",
    "flashinfer_scaled_fp8_mm",
    "flashinfer_quant_nvfp4_8x4_sf_layout",
    "flashinfer_fp8_blockscale_gemm",
    "should_use_flashinfer_for_blockscale_fp8_gemm",
    "is_flashinfer_fp8_blockscale_gemm_supported",
]