deep_gemm.py 19.4 KB
Newer Older
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for DeepGEMM API changes.

Users of vLLM should always import **only** these wrappers.
"""
7

8
9
import functools
import importlib
10
import os
11
from collections.abc import Callable
12
from enum import Enum
13
from typing import Any, NoReturn
14
15
16
17

import torch

import vllm.envs as envs
18
from vllm.logger import logger
19
20
21
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    get_fp8_min_max,
)
22
from vllm.platforms import current_platform
23
from vllm.utils.import_utils import has_deep_gemm
24
from vllm.utils.math_utils import cdiv
25

26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
_DEEPGEMM_BLACKWELL_EXCLUDED_MODEL_TYPES: set[str] = {
    "qwen3_5_text",
    "qwen3_5_moe_text",
}


def should_auto_disable_deep_gemm(model_type: str | None) -> bool:
    """Check if DeepGemm should be auto-disabled for this model on Blackwell.

    Returns True if the model is known to have accuracy degradation with
    DeepGemm's E8M0 scale format on Blackwell GPUs (SM100+).
    """
    if model_type is None:
        return False
    if not current_platform.is_device_capability_family(100):
        return False
    return model_type in _DEEPGEMM_BLACKWELL_EXCLUDED_MODEL_TYPES

44

45
46
47
48
49
50
51
52
53
54
55
class DeepGemmQuantScaleFMT(Enum):
    """Layout of the activation scale factors handed to DeepGEMM kernels.

    The process-wide choice is computed once by ``init_oracle_cache`` and is
    then read back with ``from_oracle``.
    """

    # Plain float32 scales stored in a float32 tensor.
    FLOAT32 = 0
    # float32 scales rounded up to UE8M0 (a power of two), still stored in a
    # float32 tensor.
    FLOAT32_CEIL_UE8M0 = 1
    # float32 scales rounded up to UE8M0 and packed into an int32 tensor,
    # four scale values per int32 element.
    UE8M0 = 2

    @classmethod
    def init_oracle_cache(cls) -> None:
        """Compute the scale-format decision once and memoize it on the class.

        Calls after the first one do nothing. An E8M0 format is chosen only
        when ``VLLM_USE_DEEP_GEMM_E8M0`` is set, DeepGEMM is supported, and
        the ``fp8_gemm_nt`` kernel was resolved. In that case devices in
        capability family 100 get the packed ``UE8M0`` layout and every other
        device gets ``FLOAT32_CEIL_UE8M0``. Otherwise the format is
        ``FLOAT32``.
        """
        if getattr(cls, "_oracle_cache", None) is not None:
            return

        e8m0_allowed = (
            envs.VLLM_USE_DEEP_GEMM_E8M0
            and is_deep_gemm_supported()
            and (_fp8_gemm_nt_impl is not None)
        )
        if not e8m0_allowed:
            decision = cls.FLOAT32
        elif current_platform.is_device_capability_family(100):
            decision = cls.UE8M0
        else:
            decision = cls.FLOAT32_CEIL_UE8M0
        cls._oracle_cache = decision  # type: ignore

    @classmethod
    def from_oracle(cls) -> "DeepGemmQuantScaleFMT":
        """Return the format memoized by ``init_oracle_cache``.

        Raises ``AssertionError`` if the oracle has not been initialized.
        """
        decision = getattr(cls, "_oracle_cache", None)
        assert decision is not None, "DeepGemmQuantScaleFMT oracle cache not initialized"
        return decision

85

86
87
@functools.cache
def is_deep_gemm_supported() -> bool:
    """Report whether DeepGEMM can be used on the current platform.

    Three conditions must all hold:
    - ``VLLM_USE_DEEP_GEMM`` is enabled,
    - the ``deep_gemm`` package is available, and
    - the platform reports DeepGEMM support (currently Hopper and Blackwell
      GPUs).

    The result is cached for the lifetime of the process.
    """
    # The platform query always runs, before the env/package checks.
    arch_ok = current_platform.support_deep_gemm()
    env_enabled = envs.VLLM_USE_DEEP_GEMM
    return env_enabled and has_deep_gemm() and arch_ok
93
94


95
@functools.cache
def is_deep_gemm_e8m0_used() -> bool:
    """Return ``True`` if vLLM is configured to use DeepGEMM E8M0 scales.

    Three conditions are checked in order:
    - DeepGEMM must be supported (Hopper or Blackwell-class GPU),
    - the ``fp8_gemm_nt`` kernel must be resolved, and
    - ``VLLM_USE_DEEP_GEMM_E8M0`` must be set.

    Each outcome is logged once. The result is cached for the process.
    """
    if not is_deep_gemm_supported():
        logger.debug_once(
            "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system."
        )
        return False

    # Resolve deep_gemm symbols before inspecting _fp8_gemm_nt_impl.
    _lazy_init()
    if _fp8_gemm_nt_impl is None:
        logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
        return False

    enabled = bool(envs.VLLM_USE_DEEP_GEMM_E8M0)
    if enabled:
        outcome = "DeepGEMM E8M0 enabled on current platform."
    else:
        outcome = "DeepGEMM E8M0 disabled on current configuration."
    logger.info_once(outcome)
    return enabled
118
119
120
121
122


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Stand in for a DeepGEMM symbol that could not be resolved.

    Any arguments are accepted and ignored, so this can replace any kernel
    call. It always raises ``RuntimeError``.
    """
    message = (
        "DeepGEMM backend is not available or outdated. Please install or "
        "update the `deep_gemm` to a newer version to enable FP8 kernels."
    )
    raise RuntimeError(message)
126
127


128
_cublaslt_gemm_nt_impl: Callable[..., Any] | None = None
129
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
130
_fp8_einsum_impl: Callable[..., Any] | None = None
131
132
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
133
134
135
_grouped_fp4_impl: Callable[..., Any] | None = None
_fp8_fp4_mqa_logits_impl: Callable[..., Any] | None = None
_fp8_fp4_paged_mqa_logits_impl: Callable[..., Any] | None = None
136
_get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
137
_tf32_hc_prenorm_gemm_impl: Callable[..., Any] | None = None
138
_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
139
_get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None
140
_transform_sf_into_required_layout_impl: Callable[..., Any] | None = None
141
142


143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
def _import_deep_gemm():
    """Locate and import a ``deep_gemm`` module.

    Sources are tried in this order:
      1. an externally installed ``deep_gemm`` package, so users can pin a
         specific version;
      2. the copy vendored into the vLLM wheel as
         ``vllm.third_party.deep_gemm``.

    Returns the module, or ``None`` when neither source is usable. An
    ``ImportError`` from the external package falls through to the vendored
    copy. Any exception from the vendored copy is logged and yields ``None``.
    """
    # External package first: lets users pin their own deep_gemm version.
    try:
        dg_module = importlib.import_module("deep_gemm")
        logger.debug_once("Imported deep_gemm module from site-packages")
        return dg_module
    except ImportError:
        logger.debug_once(
            "deep_gemm not found in site-packages, "
            "trying vendored vllm.third_party.deep_gemm"
        )

    # Vendored fallback shipped inside the vLLM wheel.
    vendored_name = "vllm.third_party.deep_gemm"
    try:
        dg_module = importlib.import_module(vendored_name)
        logger.debug_once("Imported deep_gemm module from vllm.third_party.deep_gemm")
        return dg_module
    except ImportError:
        logger.debug_once("Vendored deep_gemm not found either")
    except Exception as err:
        # The vendored module can fail during its native init (e.g. a
        # RuntimeError when JIT include files are missing from an incomplete
        # wheel); degrade to "unavailable" instead of crashing.
        logger.warning_once("Failed to import vendored deep_gemm: %s", err)

    return None


178
179
def _lazy_init() -> None:
    """Import deep_gemm and resolve its kernel symbols on first use.

    Each module-level ``_*_impl`` global is set to the matching ``deep_gemm``
    attribute, or ``None`` when the installed version does not expose it.
    Before the import, ``DG_JIT_CACHE_DIR`` is set to
    ``<VLLM_CACHE_ROOT>/deep_gemm`` unless the user already set it. After
    resolution the ``DeepGemmQuantScaleFMT`` oracle is initialized.

    Returns immediately once any symbol has been resolved. If DeepGEMM is
    unavailable (``has_deep_gemm()`` is false or the import yields ``None``),
    nothing is resolved and later calls try again.
    """
    global _cublaslt_gemm_nt_impl
    global _fp8_gemm_nt_impl, _fp8_einsum_impl
    global _grouped_impl, _grouped_masked_impl, _grouped_fp4_impl
    global _fp8_fp4_mqa_logits_impl, _fp8_fp4_paged_mqa_logits_impl
    global _get_paged_mqa_logits_metadata_impl
    global _tf32_hc_prenorm_gemm_impl
    global _get_mn_major_tma_aligned_tensor_impl
    global _get_mk_alignment_for_contiguous_layout_impl
    global _transform_sf_into_required_layout_impl

    # Fast path: a previous call already resolved at least one symbol. Every
    # global assigned below must appear here. Otherwise a deep_gemm build
    # that exposes only the omitted symbol would be re-resolved on every call.
    resolved = (
        _cublaslt_gemm_nt_impl,
        _fp8_gemm_nt_impl,
        _fp8_einsum_impl,
        _grouped_impl,
        _grouped_masked_impl,
        _grouped_fp4_impl,
        _fp8_fp4_mqa_logits_impl,
        _fp8_fp4_paged_mqa_logits_impl,
        _get_paged_mqa_logits_metadata_impl,
        _tf32_hc_prenorm_gemm_impl,
        _get_mn_major_tma_aligned_tensor_impl,
        _get_mk_alignment_for_contiguous_layout_impl,
        _transform_sf_into_required_layout_impl,
    )
    if any(impl is not None for impl in resolved):
        return

    if not has_deep_gemm():
        return

    # Put DeepGEMM's JIT cache under vLLM's cache root unless already set.
    DEEP_GEMM_JIT_CACHE_ENV_NAME = "DG_JIT_CACHE_DIR"
    if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME, None):
        os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join(
            envs.VLLM_CACHE_ROOT, "deep_gemm"
        )

    _dg = _import_deep_gemm()
    if _dg is None:
        return

    _cublaslt_gemm_nt_impl = getattr(_dg, "cublaslt_gemm_nt", None)
    _fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
    _fp8_einsum_impl = getattr(_dg, "fp8_einsum", None)
    _grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
    _grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None)
    _grouped_fp4_impl = getattr(_dg, "m_grouped_fp8_fp4_gemm_nt_contiguous", None)
    # DeepGEMM exposes fp8_fp4_*_mqa_logits as the canonical symbols that
    # handle both the FP8 and FP4 Q/K paths via a tuple-typed `q`.
    _fp8_fp4_mqa_logits_impl = getattr(_dg, "fp8_fp4_mqa_logits", None)
    _fp8_fp4_paged_mqa_logits_impl = getattr(_dg, "fp8_fp4_paged_mqa_logits", None)
    _get_paged_mqa_logits_metadata_impl = getattr(
        _dg, "get_paged_mqa_logits_metadata", None
    )
    _tf32_hc_prenorm_gemm_impl = getattr(_dg, "tf32_hc_prenorm_gemm", None)
    _get_mn_major_tma_aligned_tensor_impl = getattr(
        _dg, "get_mn_major_tma_aligned_tensor", None
    )
    _get_mk_alignment_for_contiguous_layout_impl = getattr(
        _dg, "get_mk_alignment_for_contiguous_layout", None
    )
    _transform_sf_into_required_layout_impl = getattr(
        _dg, "transform_sf_into_required_layout", None
    )
    DeepGemmQuantScaleFMT.init_oracle_cache()
244
245


246
247
def get_num_sms() -> int:
    """Return the SM count reported by ``deep_gemm.get_num_sms()``.

    Raises:
        RuntimeError: if no ``deep_gemm`` module can be imported.
    """
    _lazy_init()
    module = _import_deep_gemm()
    if module is None:
        raise RuntimeError("DeepGEMM is not available")
    reported = module.get_num_sms()
    return int(reported)


def set_num_sms(num_sms: int) -> None:
    """Forward ``num_sms`` to ``deep_gemm.set_num_sms``.

    Raises:
        RuntimeError: if no ``deep_gemm`` module can be imported.
    """
    _lazy_init()
    module = _import_deep_gemm()
    if module is None:
        raise RuntimeError("DeepGEMM is not available")
    module.set_num_sms(num_sms)
260
261


262
263
264
265
266
267
268
269
270
@functools.cache
def get_mk_alignment_for_contiguous_layout() -> list[int]:
    """Return DeepGEMM's M/K alignment for the contiguous layout.

    DeepGEMM reports a single alignment value, which is duplicated into a
    two-element ``[m_align, k_align]`` list. The result is cached, so every
    call returns the same list object and callers must not mutate it.
    Raises RuntimeError (via ``_missing``) if the symbol is unavailable.
    """
    _lazy_init()
    query = _get_mk_alignment_for_contiguous_layout_impl
    if query is None:
        return _missing()
    alignment = query()
    return [alignment] * 2


271
272
273
274
275
276
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` as produced by DeepGEMM's ``get_mn_major_tma_aligned_tensor``.

    Raises RuntimeError (via ``_missing``) if the symbol is unavailable.
    """
    _lazy_init()
    kernel = _get_mn_major_tma_aligned_tensor_impl
    return _missing() if kernel is None else kernel(x)
277
278


279
280
281
282
283
284
285
def cublaslt_gemm_nt(*args, **kwargs):
    """Forward all arguments unchanged to DeepGEMM's ``cublaslt_gemm_nt``.

    Raises RuntimeError (via ``_missing``) if the symbol is unavailable.
    """
    _lazy_init()
    kernel = _cublaslt_gemm_nt_impl
    if kernel is not None:
        return kernel(*args, **kwargs)
    return _missing(*args, **kwargs)


286
def fp8_gemm_nt(*args, **kwargs):
    """Forward to DeepGEMM's ``fp8_gemm_nt`` with the UE8M0 cast policy applied.

    Callers may pass ``is_deep_gemm_e8m0_used=...`` to override the global
    ``is_deep_gemm_e8m0_used()`` decision. That keyword is consumed here and
    translated into DeepGEMM's ``disable_ue8m0_cast`` argument. All other
    arguments are passed through unchanged. Raises RuntimeError (via
    ``_missing``) if the kernel is unavailable.
    """
    _lazy_init()
    if _fp8_gemm_nt_impl is None:
        return _missing(*args, **kwargs)

    override_key = "is_deep_gemm_e8m0_used"
    if override_key in kwargs:
        # kwargs is this call's own dict, so popping does not affect callers.
        use_ue8m0 = kwargs.pop(override_key)
    else:
        use_ue8m0 = is_deep_gemm_e8m0_used()
    return _fp8_gemm_nt_impl(*args, disable_ue8m0_cast=not use_ue8m0, **kwargs)
296
297


298
299
300
301
302
303
304
def fp8_einsum(*args, **kwargs):
    """Forward all arguments unchanged to DeepGEMM's ``fp8_einsum``.

    Raises RuntimeError (via ``_missing``) if the symbol is unavailable.
    """
    _lazy_init()
    kernel = _fp8_einsum_impl
    if kernel is not None:
        return kernel(*args, **kwargs)
    return _missing(*args, **kwargs)


305
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    """Forward to DeepGEMM's ``m_grouped_fp8_gemm_nt_contiguous``.

    ``disable_ue8m0_cast`` is set from the global ``is_deep_gemm_e8m0_used()``
    decision. Raises RuntimeError (via ``_missing``) if unavailable.
    """
    _lazy_init()
    if _grouped_impl is None:
        return _missing(*args, **kwargs)
    disable_cast = not is_deep_gemm_e8m0_used()
    return _grouped_impl(*args, disable_ue8m0_cast=disable_cast, **kwargs)
312
313


314
315
316
317
318
319
320
321
322
def m_grouped_fp8_fp4_gemm_nt_contiguous(*args, **kwargs):
    """Forward to DeepGEMM's ``m_grouped_fp8_fp4_gemm_nt_contiguous``.

    ``disable_ue8m0_cast`` is set from the global ``is_deep_gemm_e8m0_used()``
    decision. Raises RuntimeError (via ``_missing``) if unavailable.
    """
    _lazy_init()
    if _grouped_fp4_impl is None:
        return _missing(*args, **kwargs)
    disable_cast = not is_deep_gemm_e8m0_used()
    return _grouped_fp4_impl(*args, disable_ue8m0_cast=disable_cast, **kwargs)


323
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    """Forward to DeepGEMM's ``fp8_m_grouped_gemm_nt_masked``.

    ``disable_ue8m0_cast`` is set from the global ``is_deep_gemm_e8m0_used()``
    decision. Raises RuntimeError (via ``_missing``) if unavailable.
    """
    _lazy_init()
    if _grouped_masked_impl is None:
        return _missing(*args, **kwargs)
    disable_cast = not is_deep_gemm_e8m0_used()
    return _grouped_masked_impl(*args, disable_ue8m0_cast=disable_cast, **kwargs)
330
331


332
333
334
335
336
337
338
339
340
def transform_sf_into_required_layout(*args, **kwargs):
    """Forward to DeepGEMM's ``transform_sf_into_required_layout``.

    ``disable_ue8m0_cast`` is set from the global ``is_deep_gemm_e8m0_used()``
    decision. Raises RuntimeError (via ``_missing``) if unavailable.
    """
    _lazy_init()
    if _transform_sf_into_required_layout_impl is None:
        return _missing(*args, **kwargs)
    disable_cast = not is_deep_gemm_e8m0_used()
    return _transform_sf_into_required_layout_impl(
        *args, disable_ue8m0_cast=disable_cast, **kwargs
    )


341
342
def fp8_fp4_mqa_logits(
    q: tuple[torch.Tensor, torch.Tensor | None],
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
    clean_logits: bool,
) -> torch.Tensor:
    """Compute MQA logits for a single sequence without KV paging.

    One entry point covers both FP8 and FP4 queries: DeepGEMM receives
    ``q = (values, scales_or_None)``. For FP8 Q the scale slot is ``None``
    (the per-token scale is folded into ``weights``). For MXFP4 Q it holds a
    packed block-scale tensor.

    Args:
        q: ``(q_values, q_scale)``. FP8 path: ``q_values`` is ``[M, H, D]``
            float8_e4m3fn and ``q_scale`` is None. FP4 path: ``q_values`` is
            packed uint8 and ``q_scale`` is the companion block-scale tensor.
        kv: ``(k_packed, k_scales)``. FP8 layout is ``[N, D]`` float8_e4m3fn
            plus fp32 scales ``[N]``; FP4 layout is packed uint8.
        weights: Weights of shape ``[M, H]``, dtype ``torch.float32``.
        cu_seqlen_ks: Inclusive start index of the valid K range for each
            query position, shape ``[M]``, dtype int32.
        cu_seqlen_ke: Exclusive end index of the valid K range for each
            query position, shape ``[M]``, dtype int32.
        clean_logits: Whether to fill the unfilled logits with ``-inf``.

    Returns:
        Logits tensor of shape ``[M, N]``, dtype ``torch.float32``.

    Raises:
        RuntimeError: if deep_gemm does not provide ``fp8_fp4_mqa_logits``.
    """
    _lazy_init()
    kernel = _fp8_fp4_mqa_logits_impl
    if kernel is None:
        return _missing()
    return kernel(
        q,
        kv,
        weights,
        cu_seqlen_ks,
        cu_seqlen_ke,
        clean_logits=clean_logits,
    )
384
385


386
387
388
def get_paged_mqa_logits_metadata(
    context_lens: torch.Tensor, block_size: int, num_sms: int
) -> torch.Tensor:
    """Build the scheduling metadata consumed by paged MQA logits.

    Args:
        context_lens: Effective context length per batch element, shape
            ``[B]``, dtype int32.
        block_size: KV-cache block size in tokens (e.g. 64).
        num_sms: Number of SMs available (132 on Hopper).

    Returns:
        A backend-specific tensor that ``fp8_fp4_paged_mqa_logits`` uses to
        schedule work across SMs.

    Raises:
        RuntimeError: if deep_gemm does not provide
            ``get_paged_mqa_logits_metadata``.
    """
    _lazy_init()
    builder = _get_paged_mqa_logits_metadata_impl
    if builder is None:
        return _missing()
    return builder(context_lens, block_size, num_sms)
405
406


407
408
409
def fp8_fp4_paged_mqa_logits(
    q: tuple[torch.Tensor, torch.Tensor | None],
    kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    schedule_metadata: torch.Tensor,
    max_model_len: int,
    clean_logits: bool,
) -> torch.Tensor:
    """Compute MQA logits against a paged KV-cache.

    One entry point covers both FP8 and FP4 queries: DeepGEMM receives
    ``q = (values, scales_or_None)``. Pass ``(q_tensor, None)`` for FP8 and
    ``(q_values, q_scale)`` for MXFP4.

    Args:
        q: ``(q_values, q_scale)``. FP8 path: ``q_values`` is
            ``[B, next_n, H, D]`` float8_e4m3fn and ``q_scale`` is None. FP4
            path: ``q_values`` is packed uint8 and ``q_scale`` is the
            companion block-scale tensor.
        kv_cache: Paged KV-cache. The FP8 layout is
            ``[num_blocks, block_size, 1, D+4]``, dtype ``torch.uint8``,
            where the last 4 bytes of each (block, pos) hold the float
            dequant scale.
        weights: Tensor of shape ``[B * next_n, H]``, dtype ``torch.float32``.
        context_lens: Effective context length per batch element, shape
            ``[B]``, dtype int32.
        block_tables: Logical-to-physical block map for the paged cache,
            shape ``[B, max_blocks]``, dtype int32.
        schedule_metadata: Output of ``get_paged_mqa_logits_metadata``; used
            to distribute work across SMs.
        max_model_len: Maximum sequence length; sizes the logits output.
        clean_logits: Whether to fill the unfilled logits with ``-inf``.

    Returns:
        Logits tensor of shape ``[B * next_n, max_model_len]``, dtype
        ``torch.float32``.

    Raises:
        RuntimeError: if deep_gemm does not provide
            ``fp8_fp4_paged_mqa_logits``.
    """
    _lazy_init()
    kernel = _fp8_fp4_paged_mqa_logits_impl
    if kernel is None:
        return _missing()
    return kernel(
        q,
        kv_cache,
        weights,
        context_lens,
        block_tables,
        schedule_metadata,
        max_model_len,
        clean_logits=clean_logits,
    )
458
459


460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
def tf32_hc_prenorm_gemm(
    x: torch.Tensor,
    fn: torch.Tensor,
    out: torch.Tensor,
    sqrsum: torch.Tensor,
    num_split: int,
) -> torch.Tensor:
    """Forward to DeepGEMM's ``tf32_hc_prenorm_gemm``.

    The kernel computes:
        out = x.float() @ fn.T
        sqrsum = x.float().square().sum(-1)

    ``out`` and ``sqrsum`` are caller-provided buffers (presumably written in
    place — confirm against DeepGEMM). See the caller for shape requirements.
    Raises RuntimeError (via ``_missing``) if the symbol is unavailable.
    """
    _lazy_init()
    kernel = _tf32_hc_prenorm_gemm_impl
    if kernel is None:
        return _missing()
    return kernel(x, fn, out, sqrsum, num_split)


486
487
488
489
490
491
492
493
def _ceil_to_ue8m0(x: torch.Tensor):
    """Round each ``|x|`` up to the nearest power of two (UE8M0 scale)."""
    exponent = torch.ceil(torch.log2(x.abs()))
    return torch.pow(2.0, exponent)


def _align(x: int, y: int) -> int:
    """Round ``x`` up to the next multiple of ``y``."""
    return y * cdiv(x, y)


494
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/v2.1.1/csrc/utils/math.hpp#L19
def get_tma_aligned_size(x: int, element_size: int) -> int:
    """Round ``x`` up to a multiple of ``16 // element_size``.

    Assuming ``element_size`` is in bytes, the result spans a multiple of
    16 bytes.
    """
    elems_per_unit = 16 // element_size
    return _align(x, elems_per_unit)


499
500
501
502
DEFAULT_BLOCK_SIZE = [128, 128]


# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def per_block_cast_to_fp8(
    x: torch.Tensor, block_size: list[int] = DEFAULT_BLOCK_SIZE, use_ue8m0: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to FP8 with one scale per ``block_size`` tile.

    ``x`` is zero-padded up to a multiple of ``[block_m, block_n]``. Each
    tile's scale is ``amax / fp8_max``, where amax is clamped to at least
    1e-4. When ``use_ue8m0`` is set, the scale is rounded up to a power of
    two. The tile is then multiplied by ``1 / scale`` and cast to the
    platform FP8 dtype.

    Returns:
        ``(x_fp8, scales)``, where ``x_fp8`` has ``x``'s ``[m, n]`` shape and
        ``scales`` is float32 of shape ``[ceil(m / block_m), ceil(n / block_n)]``.
    """
    fp8_dtype = current_platform.fp8_dtype()

    assert x.dim() == 2
    rows, cols = x.shape
    block_m, block_n = block_size

    # Zero-pad so both dimensions split evenly into blocks.
    padded = torch.zeros(
        (_align(rows, block_m), _align(cols, block_n)), dtype=x.dtype, device=x.device
    )
    padded[:rows, :cols] = x

    # View as [m_blocks, block_m, n_blocks, block_n] and reduce within tiles.
    n_blocks = padded.size(1) // block_n
    tiles = padded.view(-1, block_m, n_blocks, block_n)
    amax = tiles.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)

    _, fp8_max = get_fp8_min_max()
    scale = amax / fp8_max
    if use_ue8m0:
        scale = _ceil_to_ue8m0(scale)

    quantized = (tiles * (1.0 / scale)).to(fp8_dtype)
    x_fp8 = quantized.view_as(padded)[:rows, :cols].contiguous()
    scales = scale.view(tiles.size(0), tiles.size(2))
    return x_fp8, scales
524
525
526
527
528
529


def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a global difference metric for unit tests.

    DeepGEMM kernels on Blackwell/B200 currently show noticeable per-element
    error, which makes `torch.testing.assert_close` fail. Instead of checking
    every element, this computes a cosine-style similarity over the whole
    tensor in float64,

        sim = 2 * sum(x * y) / (sum(x * x) + sum(y * y)),

    and reports `1 - sim`. The result is 0 exactly when the tensors are
    identical. Two all-zero tensors count as identical and give 0 rather
    than a NaN from a 0/0 division. NaN/inf inputs still produce NaN. Once
    kernel accuracy improves this helper can be removed.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    # Only an exact zero denominator (both inputs all-zero) is special-cased.
    # A NaN denominator fails `== 0`, so NaN still propagates. torch.where
    # keeps the result a 0-dim float64 tensor on the inputs' device, like the
    # unguarded `1 - sim`.
    return torch.where(denominator == 0, torch.zeros_like(denominator), 1 - sim)


542
def should_use_deepgemm_for_fp8_linear(
543
    output_dtype: torch.dtype,
544
    weight_shape: tuple[int, int],
545
    supports_deep_gemm: bool | None = None,
546
):
547
548
    if supports_deep_gemm is None:
        supports_deep_gemm = is_deep_gemm_supported()
549
550
551

    # Verify DeepGEMM N/K dims requirements
    # NOTE: Also synchronized with test_w8a8_block_fp8_deep_gemm_matmul
552
    # test inside kernels/quantization/test_block_fp8.py
553
554
555
    N_MULTIPLE = 64
    K_MULTIPLE = 128

556
557
558
    return (
        supports_deep_gemm
        and output_dtype == torch.bfloat16
559
560
        and weight_shape[0] % N_MULTIPLE == 0
        and weight_shape[1] % K_MULTIPLE == 0
561
    )
562
563


564
565
# Public API of this wrapper module. Per the module docstring, vLLM code
# should import only these names; every public wrapper defined here is listed.
__all__ = [
    "calc_diff",
    "DeepGemmQuantScaleFMT",
    "cublaslt_gemm_nt",
    "fp8_gemm_nt",
    "fp8_einsum",
    "m_grouped_fp8_gemm_nt_contiguous",
    "m_grouped_fp8_fp4_gemm_nt_contiguous",
    "fp8_m_grouped_gemm_nt_masked",
    "transform_sf_into_required_layout",
    "fp8_fp4_mqa_logits",
    "fp8_fp4_paged_mqa_logits",
    "get_paged_mqa_logits_metadata",
    "tf32_hc_prenorm_gemm",
    "per_block_cast_to_fp8",
    "get_tma_aligned_size",
    "is_deep_gemm_e8m0_used",
    "is_deep_gemm_supported",
    "should_auto_disable_deep_gemm",
    "get_num_sms",
    "set_num_sms",
    "should_use_deepgemm_for_fp8_linear",
    "get_col_major_tma_aligned_tensor",
    "get_mk_alignment_for_contiguous_layout",
]