deep_gemm.py 19.7 KB
Newer Older
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for DeepGEMM API changes.

Users of vLLM should always import **only** these wrappers.
"""
7

8
9
import functools
import importlib
10
import os
11
from collections.abc import Callable
12
from enum import Enum
13
from typing import Any, NoReturn
14
15
16
17

import torch

import vllm.envs as envs
18
from vllm.logger import logger
19
20
21
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    get_fp8_min_max,
)
22
from vllm.platforms import current_platform
23
from vllm.utils.import_utils import has_deep_gemm
24
from vllm.utils.math_utils import cdiv
25

Vadim Gimpelson's avatar
merge  
Vadim Gimpelson committed
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# Model types for which `should_auto_disable_deep_gemm` reports that DeepGemm
# should be auto-disabled when running on a capability-family-100 device.
# Membership is tested against the `model_type` string passed to that function.
_DEEPGEMM_BLACKWELL_EXCLUDED_MODEL_TYPES: set[str] = {
    "qwen3_5_text",
    "qwen3_5_moe_text",
}


def should_auto_disable_deep_gemm(model_type: str | None) -> bool:
    """Check if DeepGemm should be auto-disabled for this model on Blackwell.

    Returns True if the model is known to have accuracy degradation with
    DeepGemm's E8M0 scale format on Blackwell GPUs (SM100+).
    """
    if model_type is None:
        return False
    if not current_platform.is_device_capability_family(100):
        return False
    return model_type in _DEEPGEMM_BLACKWELL_EXCLUDED_MODEL_TYPES

44

45
46
47
48
49
50
51
52
53
54
55
class DeepGemmQuantScaleFMT(Enum):
    """Scale-factor formats that can be selected for DeepGEMM quantization.

    The selected format is decided once by `init_oracle_cache` and stored on
    the class as `_oracle_cache`; `from_oracle` returns that stored decision.
    """

    # Float32 scales in Float32 tensor
    FLOAT32 = 0
    # Compute float32 scales and ceil the scales to UE8M0.
    # Keep the scales in Float32 tensor.
    FLOAT32_CEIL_UE8M0 = 1
    # Compute float32 scales and ceil the scales to UE8M0.
    # Pack the scales into a int32 tensor where each int32
    # element contains 4 scale values.
    UE8M0 = 2

    @classmethod
    def init_oracle_cache(cls) -> None:
        """Initialize the oracle decision and store it in the class cache.

        The call is a no-op when a decision has already been stored.
        Otherwise the decision is:

        * `FLOAT32` unless `envs.VLLM_USE_DEEP_GEMM_E8M0` is set, DeepGEMM is
          supported, and the `fp8_gemm_nt` symbol has been resolved;
        * `UE8M0` on device capability family 100;
        * `FLOAT32_CEIL_UE8M0` on any other supported device.
        """
        if getattr(cls, "_oracle_cache", None) is not None:
            return

        use_e8m0 = (
            envs.VLLM_USE_DEEP_GEMM_E8M0
            and is_deep_gemm_supported()
            and (_fp8_gemm_nt_impl is not None)
        )
        if not use_e8m0:
            cls._oracle_cache = cls.FLOAT32  # type: ignore
            return

        cls._oracle_cache = (  # type: ignore
            cls.UE8M0
            if current_platform.is_device_capability_family(100)
            else cls.FLOAT32_CEIL_UE8M0
        )

    @classmethod
    def from_oracle(cls) -> "DeepGemmQuantScaleFMT":
        """Return the pre-initialized oracle decision.

        Raises:
            AssertionError: If `init_oracle_cache` has not stored a decision.
        """
        cached = getattr(cls, "_oracle_cache", None)
        assert cached is not None, "DeepGemmQuantScaleFMT oracle cache not initialized"
        return cached

85

86
87
@functools.cache
def is_deep_gemm_supported() -> bool:
    """Return `True` if DeepGEMM is supported on the current platform.

    Currently, only Hopper and Blackwell GPUs are supported: the platform must
    be CUDA with device capability 90 or capability family 100. In addition,
    `envs.VLLM_USE_DEEP_GEMM` must be enabled and `has_deep_gemm()` must
    report that the package is available.

    The result is memoized by `functools.cache`, so the environment and the
    platform are consulted only on the first call.
    """
    is_supported_arch = current_platform.is_cuda() and (
        current_platform.is_device_capability(90)
        or current_platform.is_device_capability_family(100)
    )
    return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch
96
97


98
@functools.cache
def is_deep_gemm_e8m0_used() -> bool:
    """Return `True` if vLLM is configured to use DeepGEMM E8M0 scales.

    E8M0 is reported as used only when all of the following hold:

    * `is_deep_gemm_supported()` is true,
    * the `fp8_gemm_nt` symbol was resolved from `deep_gemm`,
    * `envs.VLLM_USE_DEEP_GEMM_E8M0` is enabled.

    The reason for the decision is logged once. The result is memoized by
    `functools.cache`.
    """
    if not is_deep_gemm_supported():
        logger.debug_once(
            "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system."
        )
        return False

    # Resolve the backend symbols before inspecting `_fp8_gemm_nt_impl`.
    _lazy_init()

    if _fp8_gemm_nt_impl is None:
        logger.info_once(
            "DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found", scope="local"
        )
        return False

    if envs.VLLM_USE_DEEP_GEMM_E8M0:
        logger.info_once("DeepGEMM E8M0 enabled on current platform.", scope="local")
        return True

    logger.info_once("DeepGEMM E8M0 disabled on current configuration.", scope="local")
    return False
123
124
125
126
127


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Placeholder for unavailable DeepGEMM backend.

    Accepts and ignores any arguments so it can stand in for any wrapped
    DeepGEMM entry point.

    Raises:
        RuntimeError: Always.
    """
    raise RuntimeError(
        "DeepGEMM backend is not available or outdated. Please install or "
        "update the `deep_gemm` to a newer version to enable FP8 kernels."
    )
131
132


133
134
135
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
136
137
138
_fp8_mqa_logits_impl: Callable[..., Any] | None = None
_fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None
_get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
139
_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
140
_get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None
141
_transform_sf_into_required_layout_impl: Callable[..., Any] | None = None
142
143
144
145


def _lazy_init() -> None:
    """Import deep_gemm and resolve symbols on first use.

    Behavior:
        * Returns immediately when any backend symbol has already been
          resolved, or when `has_deep_gemm()` reports the package is missing.
        * Points the `DG_JIT_CACHE_DIR` environment variable at
          `<VLLM_CACHE_ROOT>/deep_gemm` unless it is already set to a
          non-empty value.
        * Imports `deep_gemm` and stores each entry point in its module-level
          handle; symbols the installed version lacks are stored as `None`.
        * Initializes the `DeepGemmQuantScaleFMT` oracle cache.
    """
    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
    global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl
    global _get_paged_mqa_logits_metadata_impl
    global _get_mn_major_tma_aligned_tensor_impl
    global _get_mk_alignment_for_contiguous_layout_impl
    global _transform_sf_into_required_layout_impl
    # fast path: any resolved symbol means initialization already happened
    if (
        _fp8_gemm_nt_impl is not None
        or _grouped_impl is not None
        or _grouped_masked_impl is not None
        or _fp8_mqa_logits_impl is not None
        or _fp8_paged_mqa_logits_impl is not None
        or _get_paged_mqa_logits_metadata_impl is not None
        or _get_mn_major_tma_aligned_tensor_impl is not None
        or _get_mk_alignment_for_contiguous_layout_impl is not None
        or _transform_sf_into_required_layout_impl is not None
    ):
        return

    if not has_deep_gemm():
        return

    # Set up deep_gemm cache path
    DEEP_GEMM_JIT_CACHE_ENV_NAME = "DG_JIT_CACHE_DIR"
    if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME):
        os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join(
            envs.VLLM_CACHE_ROOT, "deep_gemm"
        )

    _dg = importlib.import_module("deep_gemm")

    # `getattr(..., None)` tolerates deep_gemm versions that lack a symbol;
    # the public wrappers turn a `None` handle into a `_missing()` error.
    _fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
    _grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
    _grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None)
    _fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None)
    _fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None)
    _get_paged_mqa_logits_metadata_impl = getattr(
        _dg, "get_paged_mqa_logits_metadata", None
    )
    _get_mn_major_tma_aligned_tensor_impl = getattr(
        _dg, "get_mn_major_tma_aligned_tensor", None
    )
    _get_mk_alignment_for_contiguous_layout_impl = getattr(
        _dg, "get_mk_alignment_for_contiguous_layout", None
    )
    _transform_sf_into_required_layout_impl = getattr(
        _dg, "transform_sf_into_required_layout", None
    )
    DeepGemmQuantScaleFMT.init_oracle_cache()
195
196


197
198
199
200
201
202
def get_num_sms() -> int:
    """Return the value of `deep_gemm.get_num_sms()` converted to `int`.

    The backend is lazily initialized first; the `deep_gemm` module is then
    imported directly, so an import error surfaces here if it is unavailable.
    """
    _lazy_init()
    deep_gemm_module = importlib.import_module("deep_gemm")
    sm_count = deep_gemm_module.get_num_sms()
    return int(sm_count)


203
204
205
206
207
208
209
210
211
@functools.cache
def get_mk_alignment_for_contiguous_layout() -> list[int]:
    """Return DeepGEMM's M/K alignment for the contiguous layout as a pair.

    The single value reported by the backend is duplicated into a two-element
    list. Because the function is memoized with `functools.cache`, every
    caller receives the same list object.

    Raises:
        RuntimeError: If the backend symbol could not be resolved.
    """
    _lazy_init()
    impl = _get_mk_alignment_for_contiguous_layout_impl
    if impl is None:
        return _missing()
    alignment = impl()
    return [alignment] * 2


212
213
214
215
216
217
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
    """Forward `x` to DeepGEMM's `get_mn_major_tma_aligned_tensor`.

    Raises:
        RuntimeError: If the backend symbol could not be resolved.
    """
    _lazy_init()
    impl = _get_mn_major_tma_aligned_tensor_impl
    return _missing() if impl is None else impl(x)
218
219
220


def fp8_gemm_nt(*args, **kwargs):
    """Call DeepGEMM's `fp8_gemm_nt` with the UE8M0 cast flag filled in.

    All positional and keyword arguments are forwarded to the backend, except
    for the optional keyword `is_deep_gemm_e8m0_used`: when supplied it is
    removed from `kwargs` and overrides the global `is_deep_gemm_e8m0_used()`
    decision. The backend receives `disable_ue8m0_cast=not use_ue8m0`.

    Raises:
        RuntimeError: If the backend symbol could not be resolved.
    """
    _lazy_init()
    if _fp8_gemm_nt_impl is None:
        return _missing(*args, **kwargs)
    # Only consult the global setting when the caller gave no override.
    if "is_deep_gemm_e8m0_used" in kwargs:
        use_ue8m0 = kwargs.pop("is_deep_gemm_e8m0_used")
    else:
        use_ue8m0 = is_deep_gemm_e8m0_used()
    return _fp8_gemm_nt_impl(*args, disable_ue8m0_cast=not use_ue8m0, **kwargs)
230
231
232


def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    """Call DeepGEMM's `m_grouped_fp8_gemm_nt_contiguous`.

    Arguments are forwarded unchanged, with `disable_ue8m0_cast` set to the
    negation of `is_deep_gemm_e8m0_used()`.

    Raises:
        RuntimeError: If the backend symbol could not be resolved.
    """
    _lazy_init()
    if _grouped_impl is None:
        return _missing(*args, **kwargs)
    return _grouped_impl(
        *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs
    )
239
240
241


def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    """Call DeepGEMM's `fp8_m_grouped_gemm_nt_masked`.

    Arguments are forwarded unchanged, with `disable_ue8m0_cast` set to the
    negation of `is_deep_gemm_e8m0_used()`.

    Raises:
        RuntimeError: If the backend symbol could not be resolved.
    """
    _lazy_init()
    if _grouped_masked_impl is None:
        return _missing(*args, **kwargs)
    return _grouped_masked_impl(
        *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs
    )
248
249


250
251
252
253
254
255
256
257
258
def transform_sf_into_required_layout(*args, **kwargs):
    """Call DeepGEMM's `transform_sf_into_required_layout`.

    Arguments are passed through as given; `disable_ue8m0_cast` is supplied
    as the negation of `is_deep_gemm_e8m0_used()`.

    Raises:
        RuntimeError: If the backend symbol could not be resolved.
    """
    _lazy_init()
    impl = _transform_sf_into_required_layout_impl
    if impl is None:
        return _missing(*args, **kwargs)
    ue8m0_cast_disabled = not is_deep_gemm_e8m0_used()
    return impl(*args, disable_ue8m0_cast=ue8m0_cast_disabled, **kwargs)


259
260
261
262
263
264
def fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
    clean_logits: bool,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Casted to
            `torch.float8_e4m3fn` by caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N]
            with dtype `torch.float32`.
        weights: weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.
        clean_logits: Whether to clean the unfilled logits into `-inf`.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.

    Raises:
        RuntimeError: If the backend symbol could not be resolved.
    """
    _lazy_init()
    if _fp8_mqa_logits_impl is None:
        return _missing()
    return _fp8_mqa_logits_impl(
        q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=clean_logits
    )
291
292


293
294
295
def get_paged_mqa_logits_metadata(
    context_lens: torch.Tensor, block_size: int, num_sms: int
) -> torch.Tensor:
    """Build scheduling metadata for paged MQA logits.

    Args:
        context_lens: Tensor of shape [B], dtype int32; effective context length
            per batch element.
        block_size: KV-cache block size in tokens (e.g., 64).
        num_sms: Number of SMs available. 132 for Hopper

    Returns:
        Backend-specific tensor consumed by `fp8_paged_mqa_logits` to
        schedule work across SMs.

    Raises:
        RuntimeError: If the backend symbol could not be resolved.
    """
    _lazy_init()
    if _get_paged_mqa_logits_metadata_impl is None:
        return _missing()
    return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, num_sms)
312
313
314
315
316
317
318
319
320
321


def fp8_paged_mqa_logits(
    q_fp8: torch.Tensor,
    kv_cache_fp8: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    schedule_metadata: torch.Tensor,
    max_model_len: int,
    clean_logits: bool,
) -> torch.Tensor:
    """Compute FP8 MQA logits using paged KV-cache.

    Args:
        q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to
            `torch.float8_e4m3fn` by caller.
        kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
            [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
            4 bytes per (block,pos) store the `float` dequant scale.
        weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
        context_lens: Tensor of shape [B], dtype int32; effective context length
            for each batch element.
        block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
            block indices to physical blocks in the paged cache.
        schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
            used to distribute work across SMs.
        max_model_len: Maximum sequence length used to size the logits output.
        clean_logits: Whether to clean the unfilled logits into `-inf`.

    Returns:
        Logits tensor of shape [B * next_n, max_model_len], dtype
        `torch.float32`.

    Raises:
        RuntimeError: If the backend symbol could not be resolved.
    """
    _lazy_init()
    if _fp8_paged_mqa_logits_impl is None:
        return _missing()
    return _fp8_paged_mqa_logits_impl(
        q_fp8,
        kv_cache_fp8,
        weights,
        context_lens,
        block_tables,
        schedule_metadata,
        max_model_len,
        clean_logits=clean_logits,
    )
359
360


361
362
363
364
365
366
367
368
def _ceil_to_ue8m0(x: torch.Tensor):
    """Round each element's magnitude up to the nearest power of two.

    The sign of the input is discarded because the absolute value is taken
    before the base-2 logarithm.
    """
    exponent = torch.ceil(torch.log2(x.abs()))
    return torch.pow(2.0, exponent)


def _align(x: int, y: int) -> int:
    """Return `cdiv(x, y) * y`.

    Assuming `cdiv` is ceiling division, this is `x` rounded up to the next
    multiple of `y`.
    """
    num_chunks = cdiv(x, y)
    return num_chunks * y


369
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/v2.1.1/csrc/utils/math.hpp#L19
def get_tma_aligned_size(x: int, element_size: int) -> int:
    """Align `x` to a multiple of `16 // element_size` via `_align`.

    Args:
        x: Size to align, counted in elements.
        element_size: Size of one element (presumably in bytes, so that the
            result spans a multiple of 16 bytes — confirm against DeepGEMM).

    Returns:
        The aligned size.
    """
    return _align(x, 16 // element_size)


374
375
376
377
# Default `[block_m, block_n]` quantization block shape; used as the default
# `block_size` argument of `per_block_cast_to_fp8`.
DEFAULT_BLOCK_SIZE = [128, 128]


# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def per_block_cast_to_fp8(
    x: torch.Tensor, block_size: list[int] = DEFAULT_BLOCK_SIZE, use_ue8m0: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to FP8 with one scale per block.

    Args:
        x: Input tensor; must be 2-D, shape [m, n].
        block_size: `[block_m, block_n]` block shape. The list is only read,
            never mutated.
        use_ue8m0: When true, each scale is rounded up to a power of two via
            `_ceil_to_ue8m0`.

    Returns:
        A tuple `(x_fp8, scales)`: `x_fp8` is a contiguous tensor of shape
        [m, n] in the platform's FP8 dtype, and `scales` holds one float32
        scale per block with shape
        [padded_m // block_m, padded_n // block_n], where the padded sizes are
        `m` and `n` aligned up to the block shape.
    """
    fp8_dtype = current_platform.fp8_dtype()
    assert x.dim() == 2
    m, n = x.shape
    block_m, block_n = block_size
    # Zero-pad so both dimensions split evenly into blocks.
    x_padded = torch.zeros(
        (_align(m, block_m), _align(n, block_n)), dtype=x.dtype, device=x.device
    )
    x_padded[:m, :n] = x
    # View as [row_blocks, block_m, col_blocks, block_n] to reduce per block.
    x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
    # The clamp keeps the scale strictly positive for all-zero blocks.
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    _, fp8_max = get_fp8_min_max()
    sf = x_amax / fp8_max
    sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
    x_scaled = (x_view * (1.0 / sf)).to(fp8_dtype)
    # Drop the padding again before returning the quantized values.
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
        x_view.size(0), x_view.size(2)
    )
399
400
401
402
403
404


def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a global difference metric for unit tests.

    DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element
    error, causing `torch.testing.assert_close` to fail.  Instead of checking
    every element, we compute a cosine-style similarity over the whole tensor
    and report `1 - sim`.  Once kernel accuracy improves this helper can be
    removed.

    Args:
        x: First tensor.
        y: Second tensor; multiplied and added elementwise with `x`.

    Returns:
        A scalar float64 tensor equal to
        `1 - 2 * sum(x * y) / sum(x * x + y * y)`. It is 0 for identical
        inputs and NaN when both inputs are all zeros (zero denominator).
    """
    # Accumulate in float64 so the reduction itself adds negligible error.
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim


417
def should_use_deepgemm_for_fp8_linear(
418
419
    output_dtype: torch.dtype,
    weight: torch.Tensor,
420
    supports_deep_gemm: bool | None = None,
421
):
422
423
    if supports_deep_gemm is None:
        supports_deep_gemm = is_deep_gemm_supported()
424
425
426

    # Verify DeepGEMM N/K dims requirements
    # NOTE: Also synchronized with test_w8a8_block_fp8_deep_gemm_matmul
427
    # test inside kernels/quantization/test_block_fp8.py
428
429
430
    N_MULTIPLE = 64
    K_MULTIPLE = 128

431
432
433
    return (
        supports_deep_gemm
        and output_dtype == torch.bfloat16
434
435
        and weight.shape[0] % N_MULTIPLE == 0
        and weight.shape[1] % K_MULTIPLE == 0
436
    )
437
438


439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
def fp8_mqa_logits_torch(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging (CUDA fallback).

    This is a pure PyTorch fallback for CUDA when DeepGEMM is not available.

    Args:
        q: Query tensor of shape [M, H, D]. Casted to
            `torch.float8_e4m3fn` by caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`. Positions
        outside `[cu_seqlen_ks[m], cu_seqlen_ke[m])` are filled with `-inf`.
    """
    kv_fp8, scale = kv
    seq_len_kv = kv_fp8.shape[0]
    k = kv_fp8.to(torch.bfloat16)
    q = q.to(torch.bfloat16)

    # Flatten the per-key scales so that both the [N] and the [N, 1] layouts
    # broadcast along the key axis (last dim) of the [H, M, N] score tensor.
    # Without this, an [N, 1] scale would line up with the query axis.
    scale = scale.reshape(-1)

    # Key positions as a [1, N] row, compared against per-query [M, 1] bounds.
    kv_positions = torch.arange(0, seq_len_kv, device=q.device)[None, :]
    mask_lo = kv_positions >= cu_seqlen_ks[:, None]
    mask_hi = kv_positions < cu_seqlen_ke[:, None]
    mask = mask_lo & mask_hi

    score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
    # Weight each head's ReLU'd score, then reduce over heads -> [M, N].
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float("-inf"))

    return logits


def fp8_paged_mqa_logits_torch(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """Compute FP8 MQA logits using paged KV-cache (CUDA fallback).

    This is a pure PyTorch fallback for CUDA when DeepGEMM is not available.
    The last dimension of `kv_cache` is split at `D` (the query head dim):
    the first `D` bytes are the FP8 values and the remaining bytes are
    reinterpreted as the float dequant scale (e.g. 132 = 128 + 4).

    Args:
        q: Query tensor of shape [B, next_n, H, D].
        kv_cache: Paged KV-cache in packed FP8+scale layout with shape
            [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
            4 bytes per (block,pos) store the `float` dequant scale.
        weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
        context_lens: Tensor of shape [B], dtype int32; effective context length
            for each batch element.
        block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
            block indices to physical blocks in the paged cache.
        max_model_len: Maximum sequence length used to size the logits output.

    Returns:
        Logits tensor of shape [B * next_n, max_model_len], dtype
        `torch.float32`. Entries that are never written stay at `-inf`.
    """
    fp8_dtype = current_platform.fp8_dtype()
    batch_size, next_n, heads, dim = q.size()
    # Split the packed last dimension into FP8 payload and raw scale bytes.
    kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:]
    # Reinterpret the trailing bytes as float32 (needs contiguous storage).
    scale = scale.contiguous().view(torch.float)
    q = q.float()
    # Dequantize: reinterpret the payload bytes as FP8, widen, apply the scale.
    kv_cache = kv_cache.view(fp8_dtype).float() * scale
    num_blocks, block_size, _, dim = kv_cache.size()
    # Start from -inf everywhere; only visited blocks are overwritten below.
    logits = torch.full(
        [batch_size * next_n, max_model_len],
        float("-inf"),
        device=q.device,
        dtype=torch.float32,
    )
    for i in range(batch_size):
        context_len = context_lens[i].item()
        # Absolute positions of the `next_n` query tokens: the last `next_n`
        # positions of this sequence's context.
        q_offsets = torch.arange(context_len - next_n, context_len, device=q.device)
        # Per-head weights for this batch element, laid out as [H, next_n].
        weight_slice = (
            weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
        )
        for block_idx in range(cdiv(context_len, block_size)):
            # Map the logical block index to its physical block in the cache.
            block_id = block_tables[i][block_idx]
            qx, kx = q[i], kv_cache[block_id]
            # Absolute key positions covered by this block.
            k_offsets = torch.arange(
                block_idx * block_size, (block_idx + 1) * block_size, device=q.device
            )
            # A key is valid if it lies inside the context and is not ahead of
            # the query position (causal); shape [next_n, block_size].
            mask = (k_offsets[None, :] < context_len) & (
                k_offsets[None, :] <= q_offsets[:, None]
            )
            # q·k scores per head; with the documented cache layout this is
            # [H, next_n, block_size]. Invalid positions become -inf.
            s = torch.where(
                mask[None, :, :],
                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
                    logits.dtype
                ),
                float("-inf"),
            )
            # ReLU maps the -inf placeholders to 0 before the weighted head sum.
            s = torch.relu(s) * weight_slice[..., None]
            s = s.sum(dim=0)
            # Re-apply the causal mask so positions ahead of the query are
            # stored as -inf rather than the 0 produced by the ReLU.
            logits[
                i * next_n : (i + 1) * next_n,
                block_idx * block_size : (block_idx + 1) * block_size,
            ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
    return logits


558
559
# Names exported by `from ... import *`.
__all__ = [
    "calc_diff",
    "DeepGemmQuantScaleFMT",
    "fp8_gemm_nt",
    "m_grouped_fp8_gemm_nt_contiguous",
    "fp8_m_grouped_gemm_nt_masked",
    "fp8_mqa_logits",
    "fp8_mqa_logits_torch",
    "fp8_paged_mqa_logits",
    "fp8_paged_mqa_logits_torch",
    "get_paged_mqa_logits_metadata",
    "per_block_cast_to_fp8",
    "is_deep_gemm_e8m0_used",
    "is_deep_gemm_supported",
    "get_num_sms",
    "should_use_deepgemm_for_fp8_linear",
    "get_col_major_tma_aligned_tensor",
    "get_mk_alignment_for_contiguous_layout",
]