# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for DeepGEMM API changes.

Users of vLLM should always import **only** these wrappers.
"""
from __future__ import annotations

import functools
import importlib
from typing import Any, Callable, NoReturn

import torch

import vllm.envs as envs
16
from vllm.platforms import current_platform
17
from vllm.utils import cdiv, has_deep_gemm
18
19


20
21
22
23
24
25
26
27
28
29
30
@functools.cache
def is_deep_gemm_supported() -> bool:
    """Report whether DeepGEMM can be used on the current platform.

    The platform must be CUDA with device capability 90 or 100 (the Hopper
    and Blackwell GPUs), and ``has_deep_gemm()`` must be true.  The result
    is memoized by ``functools.cache``.
    """
    # Evaluate the architecture check first, then the package check.
    hopper_or_blackwell = current_platform.is_cuda() and any(
        current_platform.is_device_capability(capability)
        for capability in (90, 100))
    return has_deep_gemm() and hopper_or_blackwell


31
32
33
34
35
@functools.cache
def is_blackwell_deep_gemm_used() -> bool:
    """Return ``True`` if vLLM is configured to use DeepGEMM on a
    Blackwell-class GPU.

    All of the following must hold:

    * ``envs.VLLM_USE_DEEP_GEMM`` is truthy and ``has_deep_gemm()`` is true;
    * after ``_lazy_init()``, the FP8 GEMM entry point was resolved;
    * the platform is CUDA with device capability 100.

    The result is memoized by ``functools.cache``, so the checks run once.
    """
    if not (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm()):
        return False

    # Resolve the backend symbols before inspecting them.
    _lazy_init()
    if _fp8_gemm_nt_impl is None:
        return False

    return (current_platform.is_cuda()
            and current_platform.is_device_capability(100))
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Stand-in called when no DeepGEMM implementation is available.

    Accepts and ignores any arguments.

    Raises:
        RuntimeError: always.
    """
    message = ("DeepGEMM backend is not available. "
               "Please install the `deep_gemm` "
               "package to enable FP8 kernels.")
    raise RuntimeError(message)


def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
    """Look up ``new`` on ``module``, falling back to ``old``.

    Returns the attribute for the first of the two names that ``module``
    has, or ``None`` when it has neither.
    """
    # The order matters: the new name wins when both are present.
    for candidate in (new, old):
        if hasattr(module, candidate):
            return getattr(module, candidate)
    return None


63
64
65
66
67
68
69
# Resolved DeepGEMM entry points.  Each one stays ``None`` until
# ``_lazy_init`` assigns it; the public wrappers below dispatch through them.
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None


def _lazy_init() -> None:
    """Import ``deep_gemm`` and resolve its entry points on first use.

    Fills the module-level ``_fp8_gemm_nt_impl``, ``_grouped_impl`` and
    ``_grouped_masked_impl`` handles.  For each handle the first symbol name
    passed to ``_resolve_symbol`` is preferred and the second is the
    fallback.  Does nothing when any handle is already set or when
    ``has_deep_gemm()`` is false.
    """
    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl

    # fast path: a resolved handle means the import below already ran
    if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
            or _grouped_masked_impl is not None):
        return

    if not has_deep_gemm():
        return

    _dg = importlib.import_module("deep_gemm")

    _fp8_gemm_nt_impl = _resolve_symbol(_dg, "fp8_gemm_nt",
                                        "gemm_fp8_fp8_bf16_nt")
    _grouped_impl = _resolve_symbol(
        _dg, "m_grouped_fp8_gemm_nt_contiguous",
        "m_grouped_gemm_fp8_fp8_bf16_nt_contiguous")
    _grouped_masked_impl = _resolve_symbol(
        _dg, "fp8_m_grouped_gemm_nt_masked",
        "m_grouped_gemm_fp8_fp8_bf16_nt_masked")
90
91
92


def fp8_gemm_nt(*args, **kwargs):
    """Forward the call to the resolved DeepGEMM FP8 GEMM entry point.

    All arguments are passed through unchanged and the backend's return
    value is returned as is.

    Raises:
        RuntimeError: if no implementation was resolved (via ``_missing``).
    """
    _lazy_init()
    if _fp8_gemm_nt_impl is None:
        return _missing(*args, **kwargs)
    return _fp8_gemm_nt_impl(*args, **kwargs)


def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    """Forward the call to the resolved DeepGEMM contiguous grouped GEMM.

    All arguments are passed through unchanged and the backend's return
    value is returned as is.

    Raises:
        RuntimeError: if no implementation was resolved (via ``_missing``).
    """
    _lazy_init()
    if _grouped_impl is None:
        return _missing(*args, **kwargs)
    return _grouped_impl(*args, **kwargs)


def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    """Forward the call to the resolved DeepGEMM masked grouped GEMM.

    All arguments are passed through unchanged and the backend's return
    value is returned as is.

    Raises:
        RuntimeError: if no implementation was resolved (via ``_missing``).
    """
    _lazy_init()
    if _grouped_masked_impl is None:
        return _missing(*args, **kwargs)
    return _grouped_masked_impl(*args, **kwargs)


113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
def _ceil_to_ue8m0(x: torch.Tensor):
    """Round the magnitude of each element of ``x`` up to a power of two.

    Computes ``2 ** ceil(log2(|x|))`` element-wise; the sign of ``x`` is
    discarded.
    """
    exponent = torch.log2(x.abs()).ceil()
    return torch.pow(2.0, exponent)


def _align(x: int, y: int) -> int:
    """Return ``cdiv(x, y) * y``.

    NOTE(review): presumably ``cdiv`` is ceiling division, which makes this
    the smallest multiple of ``y`` that is >= ``x`` -- confirm against
    ``vllm.utils.cdiv``.
    """
    num_chunks = cdiv(x, y)
    return num_chunks * y


# Default ``[block_m, block_n]`` shape used by ``per_block_cast_to_fp8``.
DEFAULT_BLOCK_SIZE = [128, 128]


# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38
# TODO(wentao): optimize this function, using triton or cuda kernel
def per_block_cast_to_fp8(
        x: torch.Tensor,
        block_size: list[int] = DEFAULT_BLOCK_SIZE,
        use_ue8m0: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to ``float8_e4m3fn`` with one scale per block.

    Args:
        x: 2-D input tensor (enforced by an assertion).
        block_size: ``[block_m, block_n]`` shape of each scaling block.
        use_ue8m0: when true, each scale is passed through
            ``_ceil_to_ue8m0`` before it is applied.

    Returns:
        A tuple ``(values, scales)``.  ``values`` is a contiguous
        ``float8_e4m3fn`` tensor with the same shape as ``x``.  ``scales``
        has shape ``(row_blocks, col_blocks)`` and holds the factor each
        block was divided by.
    """
    assert x.dim() == 2
    rows, cols = x.shape
    blk_rows, blk_cols = block_size

    # Zero-pad so both dimensions are whole multiples of the block shape.
    padded_shape = (_align(rows, blk_rows), _align(cols, blk_cols))
    padded = torch.zeros(padded_shape, dtype=x.dtype, device=x.device)
    padded[:rows, :cols] = x

    # Tiled layout: (row_blocks, blk_rows, col_blocks, blk_cols).
    tiles = padded.view(-1, blk_rows, padded.size(1) // blk_cols, blk_cols)

    # Per-block absolute maximum, floored at 1e-4 so the scale is never zero.
    tile_amax = tiles.abs().float().amax(dim=(1, 3), keepdim=True)
    # 448.0 is presumably the largest finite float8_e4m3fn value -- confirm.
    scale = tile_amax.clamp(1e-4) / 448.0
    if use_ue8m0:
        scale = _ceil_to_ue8m0(scale)

    quantized = (tiles * (1.0 / scale)).to(torch.float8_e4m3fn)
    # Drop the padding again before returning.
    values = quantized.view_as(padded)[:rows, :cols].contiguous()
    scales = scale.view(tiles.size(0), tiles.size(2))
    return values, scales
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168


def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a global difference metric for unit tests.

    DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element
    error, causing ``torch.testing.assert_close`` to fail.  Instead of checking
    every element, we compute a cosine-style similarity over the whole tensor
    and report ``1 - sim``.  Once kernel accuracy improves this helper can be
    removed.

    Args:
        x: first tensor; converted to float64 before comparing.
        y: second tensor; converted to float64 before comparing.

    Returns:
        A 0-dim float64 tensor holding ``1 - 2 * sum(x * y) / sum(x^2 + y^2)``.
        When the denominator is zero (both inputs are all zeros or empty),
        the tensors are treated as identical and ``0`` is returned instead
        of NaN.
    """

    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    diff = 1 - sim
    # A zero denominator makes ``sim`` 0/0 = NaN; select 0 for that case.
    # ``torch.where`` keeps the result a tensor on the input's device.
    return torch.where(denominator == 0, torch.zeros_like(diff), diff)


# Public API of this wrapper module.
__all__ = [
    "calc_diff",
    "fp8_gemm_nt",
    "m_grouped_fp8_gemm_nt_contiguous",
    "fp8_m_grouped_gemm_nt_masked",
    "per_block_cast_to_fp8",
    "is_blackwell_deep_gemm_used",
    "is_deep_gemm_supported",
]