deep_gemm.py 11.1 KB
Newer Older
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for DeepGEMM API changes.

Users of vLLM should always import **only** these wrappers.
"""
7

8
9
10
11
from __future__ import annotations

import functools
import importlib
12
import os
13
from typing import Any, Callable, NoReturn
14
15
16
17

import torch

import vllm.envs as envs
18
from vllm.logger import logger
19
from vllm.platforms import current_platform
20
from vllm.utils import cdiv, has_deep_gemm
21
22


23
24
25
26
27
@functools.cache
def is_deep_gemm_supported() -> bool:
    """Report whether DeepGEMM can be used on this machine.

    Three conditions must all hold:

    * the ``VLLM_USE_DEEP_GEMM`` switch is on;
    * the ``deep_gemm`` package is available;
    * the device is CUDA with compute capability 9.0 (Hopper) or 10.0
      (Blackwell).

    The answer is computed once and cached by ``functools.cache``.
    """
    on_hopper_or_blackwell = current_platform.is_cuda() and any(
        current_platform.is_device_capability(cap) for cap in (90, 100)
    )
    return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and on_hopper_or_blackwell
33
34


35
@functools.cache
def is_deep_gemm_e8m0_used() -> bool:
    """Return ``True`` if vLLM is configured to use DeepGEMM's E8M0 scales.

    E8M0 scaling is used only when every check below passes, in this order:

    1. DeepGEMM is supported here (``is_deep_gemm_supported()``).
    2. The installed ``deep_gemm`` exposes ``fp8_gemm_nt``.
    3. FlashInfer FP8 MoE (``VLLM_USE_FLASHINFER_MOE_FP8``) is not enabled.
    4. ``VLLM_USE_DEEP_GEMM_E8M0`` is set.

    The result is cached. Each outcome is reported through the ``*_once``
    logger helpers.
    """
    if not is_deep_gemm_supported():
        logger.debug_once(
            "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system."
        )
        return False

    # Resolve deep_gemm symbols so the availability check below is meaningful.
    _lazy_init()

    if _fp8_gemm_nt_impl is None:
        logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
        return False

    if envs.VLLM_USE_FLASHINFER_MOE_FP8:
        logger.info_once("DeepGEMM E8M0 disabled: FlashInfer MOE is enabled.")
        return False

    if envs.VLLM_USE_DEEP_GEMM_E8M0:
        logger.info_once("DeepGEMM E8M0 enabled on current platform.")
        return True

    logger.info_once("DeepGEMM E8M0 disabled on current configuration.")
    return False
62
63
64
65
66


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Stand-in for any DeepGEMM kernel that could not be resolved.

    Accepts and ignores arbitrary arguments so it can replace any wrapped
    call. It always raises ``RuntimeError``.
    """
    message = (
        "DeepGEMM backend is not available or outdated. Please install or "
        "update the `deep_gemm` to a newer version to enable FP8 kernels."
    )
    raise RuntimeError(message)
70
71


72
73
74
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
75
76
77
_fp8_mqa_logits_impl: Callable[..., Any] | None = None
_fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None
_get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
78
_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
79
80
81
82


def _lazy_init() -> None:
    """Import ``deep_gemm`` on first use and bind its kernels to module globals.

    Each ``_*_impl`` global is set to the matching ``deep_gemm`` attribute, or
    to ``None`` when the installed version does not provide it. The public
    wrappers check for ``None`` and fall back to ``_missing``.

    Before the import, ``DG_JIT_CACHE_DIR`` is pointed at
    ``<VLLM_CACHE_ROOT>/deep_gemm`` unless the user already set it to a
    non-empty value.

    Returns immediately, without importing, if any symbol is already resolved
    or if ``has_deep_gemm()`` is false.
    """
    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
    global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl
    global _get_paged_mqa_logits_metadata_impl
    global _get_mn_major_tma_aligned_tensor_impl

    # Fast path: any resolved symbol means the import already happened. Check
    # every impl global, including the TMA-alignment helper. Otherwise a
    # deep_gemm build that exposes only that helper would be re-imported and
    # re-resolved on every call.
    if (
        _fp8_gemm_nt_impl is not None
        or _grouped_impl is not None
        or _grouped_masked_impl is not None
        or _fp8_mqa_logits_impl is not None
        or _fp8_paged_mqa_logits_impl is not None
        or _get_paged_mqa_logits_metadata_impl is not None
        or _get_mn_major_tma_aligned_tensor_impl is not None
    ):
        return

    if not has_deep_gemm():
        return

    # Set up the deep_gemm cache path before importing it; respect a
    # user-provided non-empty value.
    jit_cache_env = "DG_JIT_CACHE_DIR"
    if not os.environ.get(jit_cache_env):
        os.environ[jit_cache_env] = os.path.join(envs.VLLM_CACHE_ROOT, "deep_gemm")

    dg = importlib.import_module("deep_gemm")

    # Missing attributes resolve to None so older deep_gemm builds still load.
    _fp8_gemm_nt_impl = getattr(dg, "fp8_gemm_nt", None)
    _grouped_impl = getattr(dg, "m_grouped_fp8_gemm_nt_contiguous", None)
    _grouped_masked_impl = getattr(dg, "fp8_m_grouped_gemm_nt_masked", None)
    _fp8_mqa_logits_impl = getattr(dg, "fp8_mqa_logits", None)
    _fp8_paged_mqa_logits_impl = getattr(dg, "fp8_paged_mqa_logits", None)
    _get_paged_mqa_logits_metadata_impl = getattr(
        dg, "get_paged_mqa_logits_metadata", None
    )
    _get_mn_major_tma_aligned_tensor_impl = getattr(
        dg, "get_mn_major_tma_aligned_tensor", None
    )
122
123


124
125
126
127
128
129
def get_num_sms() -> int:
    """Return ``deep_gemm.get_num_sms()`` converted to ``int``.

    ``_lazy_init`` runs first so the DeepGEMM cache directory is configured
    before ``deep_gemm`` is imported. Unlike the kernel wrappers, this does not
    route through ``_missing``. If ``deep_gemm`` is not installed, the error
    comes from ``importlib.import_module``.
    """
    _lazy_init()
    deep_gemm_module = importlib.import_module("deep_gemm")
    sm_count = deep_gemm_module.get_num_sms()
    return int(sm_count)


130
131
132
133
134
135
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
    """Forward ``x`` to DeepGEMM's ``get_mn_major_tma_aligned_tensor``.

    Raises ``RuntimeError`` (via ``_missing``) when the installed
    ``deep_gemm`` does not provide that function.
    """
    _lazy_init()
    helper = _get_mn_major_tma_aligned_tensor_impl
    if helper is not None:
        return helper(x)
    return _missing()
136
137
138


def fp8_gemm_nt(*args, **kwargs):
    """Call DeepGEMM's ``fp8_gemm_nt`` with ``disable_ue8m0_cast`` filled in.

    Positional and keyword arguments are forwarded unchanged, except for the
    optional vLLM-only keyword ``is_deep_gemm_e8m0_used``. That keyword is
    removed and chooses whether the UE8M0 cast is enabled. Without it, the
    process-wide ``is_deep_gemm_e8m0_used()`` setting applies.

    Raises ``RuntimeError`` (via ``_missing``) if the kernel is unavailable.
    """
    _lazy_init()
    if _fp8_gemm_nt_impl is None:
        return _missing(*args, **kwargs)

    override_key = "is_deep_gemm_e8m0_used"
    use_ue8m0 = (
        kwargs.pop(override_key)
        if override_key in kwargs
        else is_deep_gemm_e8m0_used()
    )
    return _fp8_gemm_nt_impl(*args, disable_ue8m0_cast=not use_ue8m0, **kwargs)
148
149
150


def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    """Call DeepGEMM's ``m_grouped_fp8_gemm_nt_contiguous``.

    Arguments are forwarded unchanged, except for the optional vLLM-only
    keyword ``is_deep_gemm_e8m0_used``. As in ``fp8_gemm_nt``, that keyword
    is stripped and selects ``disable_ue8m0_cast``. Without it, the
    process-wide ``is_deep_gemm_e8m0_used()`` setting applies.

    Raises ``RuntimeError`` (via ``_missing``) if the kernel is unavailable.
    """
    _lazy_init()
    if _grouped_impl is None:
        return _missing(*args, **kwargs)

    # Honor the same per-call override as fp8_gemm_nt instead of forwarding
    # the vLLM-only keyword to the DeepGEMM kernel.
    if "is_deep_gemm_e8m0_used" in kwargs:
        use_ue8m0 = kwargs.pop("is_deep_gemm_e8m0_used")
    else:
        use_ue8m0 = is_deep_gemm_e8m0_used()
    return _grouped_impl(*args, disable_ue8m0_cast=not use_ue8m0, **kwargs)
157
158
159


def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    """Call DeepGEMM's ``fp8_m_grouped_gemm_nt_masked``.

    Arguments are forwarded unchanged, except for the optional vLLM-only
    keyword ``is_deep_gemm_e8m0_used``. As in ``fp8_gemm_nt``, that keyword
    is stripped and selects ``disable_ue8m0_cast``. Without it, the
    process-wide ``is_deep_gemm_e8m0_used()`` setting applies.

    Raises ``RuntimeError`` (via ``_missing``) if the kernel is unavailable.
    """
    _lazy_init()
    if _grouped_masked_impl is None:
        return _missing(*args, **kwargs)

    # Honor the same per-call override as fp8_gemm_nt instead of forwarding
    # the vLLM-only keyword to the DeepGEMM kernel.
    if "is_deep_gemm_e8m0_used" in kwargs:
        use_ue8m0 = kwargs.pop("is_deep_gemm_e8m0_used")
    else:
        use_ue8m0 = is_deep_gemm_e8m0_used()
    return _grouped_masked_impl(*args, disable_ue8m0_cast=not use_ue8m0, **kwargs)
166
167


168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
def fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """FP8 MQA logits for a single sequence whose KV is not paged.

    Thin wrapper over DeepGEMM's ``fp8_mqa_logits``. Raises ``RuntimeError``
    (via ``_missing``) if the installed ``deep_gemm`` lacks it.

    Args:
        q: Queries of shape [M, H, D], already cast to
            ``torch.float8_e4m3fn`` by the caller.
        kv: Pair ``(k_fp8, k_scales)``. ``k_fp8`` is [N, D] in
            ``torch.float8_e4m3fn``; ``k_scales`` is [N] (or [N, 1]) in
            ``torch.float32``.
        weights: Shape [M, H], ``torch.float32``.
        cu_seqlen_ks: Inclusive start index of the valid K range for each
            query position; shape [M], int32.
        cu_seqlen_ke: Exclusive end index of the valid K range for each
            query position; shape [M], int32.

    Returns:
        Logits of shape [M, N], ``torch.float32``.
    """
    _lazy_init()
    kernel = _fp8_mqa_logits_impl
    if kernel is not None:
        return kernel(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
    return _missing()


198
199
200
def get_paged_mqa_logits_metadata(
    context_lens: torch.Tensor, block_size: int, num_sms: int
) -> torch.Tensor:
    """Build the scheduling metadata consumed by ``fp8_paged_mqa_logits``.

    Thin wrapper over DeepGEMM's ``get_paged_mqa_logits_metadata``. Raises
    ``RuntimeError`` (via ``_missing``) if it is unavailable.

    Args:
        context_lens: Effective context length of each batch element;
            shape [B], int32.
        block_size: Number of tokens per KV-cache block (e.g. 64).
        num_sms: Number of SMs to spread the work over (e.g. 132 on Hopper).

    Returns:
        Backend-specific tensor to pass as ``schedule_metadata`` to
        ``fp8_paged_mqa_logits``.
    """
    _lazy_init()
    kernel = _get_paged_mqa_logits_metadata_impl
    if kernel is None:
        return _missing()
    return kernel(context_lens, block_size, num_sms)
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251


def fp8_paged_mqa_logits(
    q_fp8: torch.Tensor,
    kv_cache_fp8: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    schedule_metadata: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """FP8 MQA logits computed against a paged KV cache.

    Thin wrapper over DeepGEMM's ``fp8_paged_mqa_logits``, always called with
    ``clean_logits=True``. Raises ``RuntimeError`` (via ``_missing``) if the
    kernel is unavailable.

    Args:
        q_fp8: Queries of shape [B, next_n, H, D], already cast to
            ``torch.float8_e4m3fn`` by the caller.
        kv_cache_fp8: Paged cache with packed FP8 values plus scale, shape
            [num_blocks, block_size, 1, D+4], dtype ``torch.uint8``. The
            trailing 4 bytes of each (block, pos) entry hold the ``float``
            dequantization scale.
        weights: Shape [B * next_n, H], ``torch.float32``.
        context_lens: Effective context length per batch element; shape [B],
            int32.
        block_tables: Logical-to-physical block map; shape [B, max_blocks],
            int32.
        schedule_metadata: Output of ``get_paged_mqa_logits_metadata``.
        max_model_len: Maximum sequence length; sets the logits width.

    Returns:
        Logits of shape [B * next_n, max_model_len], ``torch.float32``.
    """
    _lazy_init()
    kernel = _fp8_paged_mqa_logits_impl
    if kernel is None:
        return _missing()
    # NOTE(review): clean_logits presumably masks entries past each sequence's
    # context length; confirm against DeepGEMM's documentation.
    return kernel(
        q_fp8,
        kv_cache_fp8,
        weights,
        context_lens,
        block_tables,
        schedule_metadata,
        max_model_len,
        clean_logits=True,
    )
262
263


264
265
266
267
268
269
270
271
272
273
274
275
def _ceil_to_ue8m0(x: torch.Tensor):
    """Round each ``|x|`` up to the nearest power of two.

    Zeros map to 0, because ``log2(0)`` is ``-inf`` and ``2 ** -inf`` is 0.
    """
    exponent = torch.log2(x.abs()).ceil()
    return torch.pow(2.0, exponent)


def _align(x: int, y: int) -> int:
    """Round ``x`` up to the next multiple of ``y``; ``x`` itself if already aligned."""
    num_chunks = cdiv(x, y)
    return num_chunks * y


# Default ``[block_m, block_n]`` tile for ``per_block_cast_to_fp8``. It is
# shared as a default argument, so it must only be read, never mutated.
DEFAULT_BLOCK_SIZE = [128, 128]


# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def per_block_cast_to_fp8(
    x: torch.Tensor, block_size: list[int] = DEFAULT_BLOCK_SIZE, use_ue8m0: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to ``torch.float8_e4m3fn`` with one scale per tile.

    ``x`` is zero-padded up to a multiple of ``block_size`` in both dims and
    split into ``block_m x block_n`` tiles. Each tile is scaled by
    ``amax / 448.0`` (448 is the largest finite e4m3fn value), with ``amax``
    clamped to at least ``1e-4``. With ``use_ue8m0`` set, the scale is rounded
    up to a power of two.

    Returns:
        ``(x_fp8, scales)``. ``x_fp8`` has the original [m, n] shape.
        ``scales`` has shape [ceil(m / block_m), ceil(n / block_n)].
    """
    assert x.dim() == 2
    rows, cols = x.shape
    tile_m, tile_n = block_size

    padded_rows = _align(rows, tile_m)
    padded_cols = _align(cols, tile_n)
    padded = torch.zeros((padded_rows, padded_cols), dtype=x.dtype, device=x.device)
    padded[:rows, :cols] = x

    # [row_tiles, tile_m, col_tiles, tile_n]: reduce over dims 1 and 3 per tile.
    tiles = padded.view(-1, tile_m, padded.size(1) // tile_n, tile_n)
    tile_amax = tiles.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    scale = tile_amax / 448.0
    if use_ue8m0:
        scale = _ceil_to_ue8m0(scale)

    quantized = (tiles * (1.0 / scale)).to(torch.float8_e4m3fn)
    x_fp8 = quantized.view_as(padded)[:rows, :cols].contiguous()
    scales = scale.view(tiles.size(0), tiles.size(2))
    return x_fp8, scales
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312


def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a global difference metric for unit tests.

    DeepGEMM kernels on Blackwell/B200 currently show noticeable per-element
    error, which makes ``torch.testing.assert_close`` fail. Instead of
    checking every element, this computes the similarity
    ``2 * <x, y> / (|x|^2 + |y|^2)`` over the whole tensor, in float64, and
    returns ``1 - sim``. It is 0 for identical inputs. Once kernel accuracy
    improves, this helper can be removed.

    Returns:
        A 0-dim ``torch.float64`` tensor.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:
        # Both tensors are entirely zero and therefore identical. Return 0
        # instead of the NaN that 0 / 0 would produce.
        return torch.zeros_like(denominator)
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim


313
def should_use_deepgemm_for_fp8_linear(
314
315
    output_dtype: torch.dtype,
    weight: torch.Tensor,
316
    supports_deep_gemm: bool | None = None,
317
):
318
319
    if supports_deep_gemm is None:
        supports_deep_gemm = is_deep_gemm_supported()
320
321
322
323
324
325
    return (
        supports_deep_gemm
        and output_dtype == torch.bfloat16
        and weight.shape[0] % 128 == 0
        and weight.shape[1] % 128 == 0
    )
326
327


328
329
330
331
332
__all__ = [
    # Test helper.
    "calc_diff",
    # FP8 GEMM wrappers.
    "fp8_gemm_nt",
    "m_grouped_fp8_gemm_nt_contiguous",
    "fp8_m_grouped_gemm_nt_masked",
    # FP8 MQA logits wrappers.
    "fp8_mqa_logits",
    "fp8_paged_mqa_logits",
    "get_paged_mqa_logits_metadata",
    # Quantization helper.
    "per_block_cast_to_fp8",
    # Capability/configuration queries and misc helpers.
    "is_deep_gemm_e8m0_used",
    "is_deep_gemm_supported",
    "get_num_sms",
    "should_use_deepgemm_for_fp8_linear",
    "get_col_major_tma_aligned_tensor",
]