deep_gemm.py 14.1 KB
Newer Older
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for DeepGEMM API changes.

Users of vLLM should always import **only** these wrappers.
"""
7

8
9
import functools
import importlib
10
import os
11
from collections.abc import Callable
12
from enum import Enum
13
from typing import Any, NoReturn
14
15
16
17

import torch

import vllm.envs as envs
18
from vllm.logger import logger
19
20
21
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    get_fp8_min_max,
)
22
from vllm.platforms import current_platform
23
from vllm.utils.import_utils import has_deep_gemm
24
from vllm.utils.math_utils import cdiv
25
26


27
28
29
30
31
32
33
34
35
36
37
class DeepGemmQuantScaleFMT(Enum):
    """Scale-factor format used for DeepGEMM FP8 quantization.

    The active format is decided once per process by `init_oracle_cache`
    (stored on the class as `_oracle_cache`) and read back via `from_oracle`.
    """

    # Float32 scales in Float32 tensor
    FLOAT32 = 0
    # Compute float32 scales and ceil the scales to UE8M0.
    # Keep the scales in Float32 tensor.
    FLOAT32_CEIL_UE8M0 = 1
    # Compute float32 scales and ceil the scales to UE8M0.
    # Pack the scales into a int32 tensor where each int32
    # element contains 4 scale values.
    UE8M0 = 2

    @classmethod
    def init_oracle_cache(cls) -> None:
        """Decide the scale format once and remember it on the class.

        The decision is:
        * `FLOAT32` unless `envs.VLLM_USE_DEEP_GEMM_E8M0` is set, DeepGEMM is
          supported, and `fp8_gemm_nt` was resolved from deep_gemm;
        * otherwise `UE8M0` on the capability-100 device family, and
          `FLOAT32_CEIL_UE8M0` on any other device.

        Subsequent calls return immediately once a decision is cached.
        """
        if getattr(cls, "_oracle_cache", None) is not None:
            return

        e8m0_enabled = (
            envs.VLLM_USE_DEEP_GEMM_E8M0
            and is_deep_gemm_supported()
            and (_fp8_gemm_nt_impl is not None)
        )
        if e8m0_enabled:
            # Only the capability-100 family gets the packed int32 layout.
            if current_platform.is_device_capability_family(100):
                decision = cls.UE8M0
            else:
                decision = cls.FLOAT32_CEIL_UE8M0
        else:
            decision = cls.FLOAT32
        cls._oracle_cache = decision  # type: ignore

    @classmethod
    def from_oracle(cls) -> "DeepGemmQuantScaleFMT":
        """Return the decision cached by `init_oracle_cache`.

        Raises AssertionError if the cache has not been initialized yet.
        """
        oracle = getattr(cls, "_oracle_cache", None)
        assert oracle is not None, "DeepGemmQuantScaleFMT oracle cache not initialized"
        return oracle

67

68
69
@functools.cache
def is_deep_gemm_supported() -> bool:
    """Report whether DeepGEMM can be used in this process.

    All of the following must hold: the `VLLM_USE_DEEP_GEMM` env flag is set,
    `has_deep_gemm()` reports the package is available, and the platform is
    CUDA with device capability 90 (Hopper) or in the capability-100 family
    (Blackwell). The answer is computed once and cached.
    """
    on_cuda = current_platform.is_cuda()
    arch_ok = on_cuda and (
        current_platform.is_device_capability(90)
        or current_platform.is_device_capability_family(100)
    )
    return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and arch_ok
78
79


80
@functools.cache
def is_deep_gemm_e8m0_used() -> bool:
    """Return `True` if vLLM is configured to use DeepGEMM E8M0 scales.

    E8M0 is used only when all of the following hold:

    * `is_deep_gemm_supported()` is True (DeepGEMM enabled and available on a
      Hopper or Blackwell-class GPU);
    * the installed `deep_gemm` exposes `fp8_gemm_nt`;
    * `envs.VLLM_USE_DEEP_GEMM_E8M0` is set.

    The result is cached for the lifetime of the process, and each outcome is
    reported through the logger's `*_once` helpers.
    """
    if not is_deep_gemm_supported():
        logger.debug_once(
            "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system."
        )
        return False

    # Resolve deep_gemm symbols so `_fp8_gemm_nt_impl` reflects the install.
    _lazy_init()

    if _fp8_gemm_nt_impl is None:
        logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
        return False

    if envs.VLLM_USE_DEEP_GEMM_E8M0:
        logger.info_once("DeepGEMM E8M0 enabled on current platform.")
        return True

    logger.info_once("DeepGEMM E8M0 disabled on current configuration.")
    return False
103
104
105
106
107


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Stand-in for any DeepGEMM entry point that could not be resolved.

    Accepts and ignores arbitrary arguments so it can replace any wrapper
    call, and always raises RuntimeError.
    """
    message = (
        "DeepGEMM backend is not available or outdated. Please install or "
        "update the `deep_gemm` to a newer version to enable FP8 kernels."
    )
    raise RuntimeError(message)
111
112


113
114
115
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
116
117
118
_fp8_mqa_logits_impl: Callable[..., Any] | None = None
_fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None
_get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
119
_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
120
_get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None
121
_transform_sf_into_required_layout_impl: Callable[..., Any] | None = None
122
123
124
125


def _lazy_init() -> None:
    """Import `deep_gemm` and resolve its entry points on first use.

    Each `_*_impl` module global is bound to the matching attribute of the
    `deep_gemm` module, or to `None` when that attribute does not exist (for
    example on an older DeepGEMM release). Afterwards the
    `DeepGemmQuantScaleFMT` oracle cache is initialized.

    Returns early, doing nothing, when any symbol has already been resolved
    or when `has_deep_gemm()` reports the package is unavailable. Before the
    import, `DG_JIT_CACHE_DIR` is pointed at `<VLLM_CACHE_ROOT>/deep_gemm`
    unless it is already set to a non-empty value.
    """
    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
    global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl
    global _get_paged_mqa_logits_metadata_impl
    global _get_mn_major_tma_aligned_tensor_impl
    global _get_mk_alignment_for_contiguous_layout_impl
    global _transform_sf_into_required_layout_impl

    # Fast path: any resolved symbol means initialization already ran.
    # Every global assigned below is checked, so a deep_gemm build exposing
    # only a subset of symbols still short-circuits here.
    if any(
        impl is not None
        for impl in (
            _fp8_gemm_nt_impl,
            _grouped_impl,
            _grouped_masked_impl,
            _fp8_mqa_logits_impl,
            _fp8_paged_mqa_logits_impl,
            _get_paged_mqa_logits_metadata_impl,
            _get_mn_major_tma_aligned_tensor_impl,
            _get_mk_alignment_for_contiguous_layout_impl,
            _transform_sf_into_required_layout_impl,
        )
    ):
        return

    if not has_deep_gemm():
        return

    # Set up deep_gemm cache path (respect a non-empty user setting).
    DEEP_GEMM_JIT_CACHE_ENV_NAME = "DG_JIT_CACHE_DIR"
    if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME, None):
        os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join(
            envs.VLLM_CACHE_ROOT, "deep_gemm"
        )

    _dg = importlib.import_module("deep_gemm")

    # Symbols missing from the installed deep_gemm resolve to None; the
    # public wrappers turn that into a `_missing()` error at call time.
    _fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
    _grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
    _grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None)
    _fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None)
    _fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None)
    _get_paged_mqa_logits_metadata_impl = getattr(
        _dg, "get_paged_mqa_logits_metadata", None
    )
    _get_mn_major_tma_aligned_tensor_impl = getattr(
        _dg, "get_mn_major_tma_aligned_tensor", None
    )
    _get_mk_alignment_for_contiguous_layout_impl = getattr(
        _dg, "get_mk_alignment_for_contiguous_layout", None
    )
    _transform_sf_into_required_layout_impl = getattr(
        _dg, "transform_sf_into_required_layout", None
    )
    DeepGemmQuantScaleFMT.init_oracle_cache()
175
176


177
178
179
180
181
182
def get_num_sms() -> int:
    """Return the SM count reported by `deep_gemm.get_num_sms()` as an int."""
    _lazy_init()
    deep_gemm_module = importlib.import_module("deep_gemm")
    sm_count = deep_gemm_module.get_num_sms()
    return int(sm_count)


183
184
185
186
187
188
189
190
191
@functools.cache
def get_mk_alignment_for_contiguous_layout() -> list[int]:
    """Return DeepGEMM's contiguous-layout alignment as `[m_align, k_align]`.

    DeepGEMM reports a single value, which is used for both entries. The
    result is cached; raises via `_missing()` if the symbol is unavailable.
    """
    _lazy_init()
    if _get_mk_alignment_for_contiguous_layout_impl is None:
        return _missing()
    alignment = _get_mk_alignment_for_contiguous_layout_impl()
    return [alignment] * 2


192
193
194
195
196
197
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
    """Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor.

    Forwards `x` unchanged; raises via `_missing()` if the symbol is absent.
    """
    _lazy_init()
    align_fn = _get_mn_major_tma_aligned_tensor_impl
    if align_fn is None:
        return _missing()
    return align_fn(x)
198
199
200


def fp8_gemm_nt(*args, **kwargs):
    """Call DeepGEMM's `fp8_gemm_nt`, setting `disable_ue8m0_cast`.

    An optional `is_deep_gemm_e8m0_used` keyword overrides the global
    `is_deep_gemm_e8m0_used()` decision. It is consumed here and not
    forwarded. All remaining arguments are passed through unchanged.
    """
    _lazy_init()
    if _fp8_gemm_nt_impl is None:
        return _missing(*args, **kwargs)

    use_ue8m0 = (
        kwargs.pop("is_deep_gemm_e8m0_used")
        if "is_deep_gemm_e8m0_used" in kwargs
        else is_deep_gemm_e8m0_used()
    )
    return _fp8_gemm_nt_impl(*args, disable_ue8m0_cast=not use_ue8m0, **kwargs)
210
211
212


def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    """Call DeepGEMM's `m_grouped_fp8_gemm_nt_contiguous`.

    Arguments are forwarded unchanged. `disable_ue8m0_cast` is derived from
    `is_deep_gemm_e8m0_used()`.
    """
    _lazy_init()
    if _grouped_impl is None:
        return _missing(*args, **kwargs)
    skip_ue8m0 = not is_deep_gemm_e8m0_used()
    return _grouped_impl(*args, disable_ue8m0_cast=skip_ue8m0, **kwargs)
219
220
221


def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    """Call DeepGEMM's `fp8_m_grouped_gemm_nt_masked`.

    Arguments are forwarded unchanged. `disable_ue8m0_cast` is derived from
    `is_deep_gemm_e8m0_used()`.
    """
    _lazy_init()
    if _grouped_masked_impl is None:
        return _missing(*args, **kwargs)
    skip_ue8m0 = not is_deep_gemm_e8m0_used()
    return _grouped_masked_impl(*args, disable_ue8m0_cast=skip_ue8m0, **kwargs)
228
229


230
231
232
233
234
235
236
237
238
def transform_sf_into_required_layout(*args, **kwargs):
    """Call DeepGEMM's `transform_sf_into_required_layout`.

    Arguments are forwarded unchanged. `disable_ue8m0_cast` is derived from
    `is_deep_gemm_e8m0_used()`.
    """
    _lazy_init()
    if _transform_sf_into_required_layout_impl is None:
        return _missing(*args, **kwargs)
    skip_ue8m0 = not is_deep_gemm_e8m0_used()
    return _transform_sf_into_required_layout_impl(
        *args, disable_ue8m0_cast=skip_ue8m0, **kwargs
    )


239
240
241
242
243
244
245
246
247
248
249
250
251
def fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Cast to
            `torch.float8_e4m3fn` by caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] with
            dtype `torch.float32`.
        weights: Weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.

    Raises:
        RuntimeError: via `_missing()` if DeepGEMM lacks `fp8_mqa_logits`.
    """
    _lazy_init()
    if _fp8_mqa_logits_impl is None:
        return _missing()
    return _fp8_mqa_logits_impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)


269
270
271
def get_paged_mqa_logits_metadata(
    context_lens: torch.Tensor, block_size: int, num_sms: int
) -> torch.Tensor:
    """Build the SM scheduling metadata used by paged MQA logits.

    Args:
        context_lens: Effective context length per batch element; shape [B],
            dtype int32.
        block_size: KV-cache block size in tokens (e.g., 64).
        num_sms: Number of SMs available (e.g., 132 for Hopper).

    Returns:
        Backend-specific tensor that `fp8_paged_mqa_logits` consumes to
        distribute work across SMs.
    """
    _lazy_init()
    metadata_fn = _get_paged_mqa_logits_metadata_impl
    if metadata_fn is None:
        return _missing()
    return metadata_fn(context_lens, block_size, num_sms)
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322


def fp8_paged_mqa_logits(
    q_fp8: torch.Tensor,
    kv_cache_fp8: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    schedule_metadata: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """Compute FP8 MQA logits over a paged KV-cache.

    Args:
        q_fp8: Query tensor of shape [B, next_n, H, D], already cast to
            `torch.float8_e4m3fn` by the caller.
        kv_cache_fp8: Paged KV-cache in packed FP8+scale layout, shape
            [num_blocks, block_size, 1, D+4], dtype `torch.uint8`; the last
            4 bytes of each (block, pos) hold the `float` dequant scale.
        weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
        context_lens: Effective context length per batch element; shape [B],
            dtype int32.
        block_tables: Logical-to-physical block map of shape [B, max_blocks],
            dtype int32.
        schedule_metadata: Output of `get_paged_mqa_logits_metadata`, used to
            spread work across SMs.
        max_model_len: Maximum sequence length; sizes the logits output.

    Returns:
        Logits tensor of shape [B * next_n, max_model_len], dtype
        `torch.float32`. The backend is always called with `clean_logits=True`.
    """
    _lazy_init()
    paged_logits_fn = _fp8_paged_mqa_logits_impl
    if paged_logits_fn is None:
        return _missing()
    return paged_logits_fn(
        q_fp8,
        kv_cache_fp8,
        weights,
        context_lens,
        block_tables,
        schedule_metadata,
        max_model_len,
        clean_logits=True,
    )
333
334


335
336
337
338
339
340
341
342
343
344
345
346
def _ceil_to_ue8m0(x: torch.Tensor):
    """Round `|x|` up to the next power of two, elementwise.

    Exact powers of two are unchanged, and zeros map to 0 (log2(0) = -inf).
    """
    exponent = torch.log2(x.abs()).ceil()
    return torch.pow(2.0, exponent)


def _align(x: int, y: int) -> int:
    """Round `x` up to the nearest multiple of `y`."""
    num_chunks = cdiv(x, y)
    return num_chunks * y


# Default [block_m, block_n] tile used by `per_block_cast_to_fp8`.
DEFAULT_BLOCK_SIZE: list[int] = [128, 128]


# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def per_block_cast_to_fp8(
    x: torch.Tensor, block_size: list[int] = DEFAULT_BLOCK_SIZE, use_ue8m0: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to FP8 with one scale per [block_m, block_n] tile.

    `x` is zero-padded up to whole tiles. Each tile's absolute max, computed
    in float32 and clamped to at least 1e-4, is divided by the FP8 max to
    form that tile's scale. With `use_ue8m0`, the scale is rounded up to a
    power of two. The padding is removed from the quantized output.

    Args:
        x: 2-D input tensor.
        block_size: `[block_m, block_n]` tile size.
        use_ue8m0: Round each scale up to a power of two.

    Returns:
        `(x_fp8, scales)`: `x_fp8` has `x`'s shape and dtype
        `current_platform.fp8_dtype()`; `scales` has shape
        `[ceil(m / block_m), ceil(n / block_n)]`.
    """
    fp8_dtype = current_platform.fp8_dtype()
    assert x.dim() == 2
    rows, cols = x.shape
    tile_m, tile_n = block_size

    padded = torch.zeros(
        (_align(rows, tile_m), _align(cols, tile_n)), dtype=x.dtype, device=x.device
    )
    padded[:rows, :cols] = x
    # View as [tiles_m, tile_m, tiles_n, tile_n] so dims 1 and 3 span a tile.
    tiles = padded.view(-1, tile_m, padded.size(1) // tile_n, tile_n)
    tile_amax = tiles.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)

    _, fp8_max = get_fp8_min_max()
    scales = tile_amax / fp8_max
    if use_ue8m0:
        scales = _ceil_to_ue8m0(scales)
    quantized = (tiles * (1.0 / scales)).to(fp8_dtype)

    x_fp8 = quantized.view_as(padded)[:rows, :cols].contiguous()
    return x_fp8, scales.view(tiles.size(0), tiles.size(2))
368
369
370
371
372
373


def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a global difference metric for unit tests.

    DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element
    error, causing `torch.testing.assert_close` to fail. Instead of checking
    every element, we compute a cosine-style similarity over the whole tensor,
    `sim = 2 * <x, y> / (|x|^2 + |y|^2)`, and report `1 - sim`. Once kernel
    accuracy improves this helper can be removed.

    Both inputs are promoted to float64. When both tensors are entirely zero
    (or empty), they are treated as identical and the result is 0 instead of
    NaN. NaNs in the inputs still propagate to the result.

    Returns:
        A 0-dim float64 tensor; 0 means the inputs are identical.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    numerator = 2 * (x * y).sum()
    # A zero denominator only happens when x and y are both all-zero (or
    # empty), i.e. identical. Test for exact zero, not `> 0`, so that a NaN
    # denominator still yields NaN instead of being masked as "identical".
    sim = torch.where(
        denominator == 0, torch.ones_like(denominator), numerator / denominator
    )
    return 1 - sim


386
def should_use_deepgemm_for_fp8_linear(
387
388
    output_dtype: torch.dtype,
    weight: torch.Tensor,
389
    supports_deep_gemm: bool | None = None,
390
):
391
392
    if supports_deep_gemm is None:
        supports_deep_gemm = is_deep_gemm_supported()
393
394
395

    # Verify DeepGEMM N/K dims requirements
    # NOTE: Also synchronized with test_w8a8_block_fp8_deep_gemm_matmul
396
    # test inside kernels/quantization/test_block_fp8.py
397
398
399
    N_MULTIPLE = 64
    K_MULTIPLE = 128

400
401
402
    return (
        supports_deep_gemm
        and output_dtype == torch.bfloat16
403
404
        and weight.shape[0] % N_MULTIPLE == 0
        and weight.shape[1] % K_MULTIPLE == 0
405
    )
406
407


408
409
# Public wrappers exported by this module (see module docstring: vLLM code
# should import only these).
__all__ = [
    "calc_diff",
    "DeepGemmQuantScaleFMT",
    "fp8_gemm_nt",
    "m_grouped_fp8_gemm_nt_contiguous",
    "fp8_m_grouped_gemm_nt_masked",
    "transform_sf_into_required_layout",
    "fp8_mqa_logits",
    "fp8_paged_mqa_logits",
    "get_paged_mqa_logits_metadata",
    "per_block_cast_to_fp8",
    "is_deep_gemm_e8m0_used",
    "is_deep_gemm_supported",
    "get_num_sms",
    "should_use_deepgemm_for_fp8_linear",
    "get_col_major_tma_aligned_tensor",
    "get_mk_alignment_for_contiguous_layout",
]