deep_gemm.py 5.51 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for DeepGEMM API changes.

Users of vLLM should always import **only** these wrappers.
"""
from __future__ import annotations

import functools
import importlib
11
import os
12
13
14
15
16
from typing import Any, Callable, NoReturn

import torch

import vllm.envs as envs
17
from vllm.platforms import current_platform
18
from vllm.utils import cdiv, has_deep_gemm
19
20


21
22
23
24
25
26
27
28
29
30
31
@functools.cache
def is_deep_gemm_supported() -> bool:
    """Report whether DeepGEMM can run on this machine.

    Two conditions must hold:
    * the platform is CUDA and the device capability is 90 (Hopper) or
      100 (Blackwell);
    * the ``deep_gemm`` package is available (``has_deep_gemm()``).

    The answer is memoized by ``functools.cache``.
    """
    arch_ok = False
    if current_platform.is_cuda():
        arch_ok = any(
            current_platform.is_device_capability(capability)
            for capability in (90, 100))
    return has_deep_gemm() and arch_ok


32
33
34
35
36
@functools.cache
def is_blackwell_deep_gemm_used() -> bool:
    """Return ``True`` if vLLM is configured to use DeepGEMM on a
    Blackwell-class GPU.

    All of the following must hold:
    * ``envs.VLLM_USE_DEEP_GEMM`` is enabled and ``has_deep_gemm()`` is true;
    * after lazy initialization, the ``fp8_gemm_nt`` kernel was resolved
      from ``deep_gemm``;
    * the current platform is CUDA with device capability 100.

    The result is memoized by ``functools.cache``, so subsequent calls return
    the cached answer without re-running these checks.
    """
    if not (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm()):
        return False

    # Importing deep_gemm is deferred until here. An installed deep_gemm may
    # still export neither the new nor the legacy dense-GEMM symbol.
    _lazy_init()
    if _fp8_gemm_nt_impl is None:
        return False

    return (current_platform.is_cuda()
            and current_platform.is_device_capability(100))
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Placeholder for unavailable DeepGEMM backend."""
    raise RuntimeError(
        "DeepGEMM backend is not available. Please install the `deep_gemm` "
        "package to enable FP8 kernels.")


def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
    """Return the *new* symbol if it exists, otherwise the *old* one."""
    if hasattr(module, new):
        return getattr(module, new)
    if hasattr(module, old):
        return getattr(module, old)
    return None


64
65
66
67
68
69
70
# Lazily-resolved DeepGEMM kernels. Each stays ``None`` until ``_lazy_init``
# resolves it, and remains ``None`` if deep_gemm is absent or lacks both the
# new and the legacy symbol name.
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None


def _lazy_init() -> None:
    """Import ``deep_gemm`` and resolve its kernels on first use.

    Fills ``_fp8_gemm_nt_impl``, ``_grouped_impl`` and
    ``_grouped_masked_impl`` with whichever of the new or old DeepGEMM
    symbol names the module exports (``None`` if neither). Before importing,
    ``DG_JIT_CACHE_DIR`` is set to ``<VLLM_CACHE_ROOT>/deep_gemm`` unless it
    already holds a non-empty value.

    Returns immediately if a previous call resolved any kernel, or if
    ``deep_gemm`` is not installed.
    """
    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl

    # Fast path: already initialized. If deep_gemm exported none of the
    # expected names, the work below simply repeats on each call.
    if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
            or _grouped_masked_impl is not None):
        return

    if not has_deep_gemm():
        return

    # Set up deep_gemm cache path. This is done before the import in case
    # deep_gemm reads the variable at import time. A user-provided
    # (non-empty) value is respected.
    DEEP_GEMM_JIT_CACHE_ENV_NAME = 'DG_JIT_CACHE_DIR'
    if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME, None):
        os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join(
            envs.VLLM_CACHE_ROOT, "deep_gemm")

    _dg = importlib.import_module("deep_gemm")

    # Each kernel is looked up under its new name first, then its old name.
    _fp8_gemm_nt_impl = _resolve_symbol(_dg, "fp8_gemm_nt",
                                        "gemm_fp8_fp8_bf16_nt")
    _grouped_impl = _resolve_symbol(
        _dg, "m_grouped_fp8_gemm_nt_contiguous",
        "m_grouped_gemm_fp8_fp8_bf16_nt_contiguous")
    _grouped_masked_impl = _resolve_symbol(
        _dg, "fp8_m_grouped_gemm_nt_masked",
        "m_grouped_gemm_fp8_fp8_bf16_nt_masked")


def fp8_gemm_nt(*args, **kwargs):
    """Call DeepGEMM's dense FP8 GEMM kernel.

    Dispatches to ``deep_gemm.fp8_gemm_nt`` or, if absent, the older
    ``deep_gemm.gemm_fp8_fp8_bf16_nt``. All arguments are forwarded
    unchanged, and the kernel's return value is returned.

    Raises:
        RuntimeError: if DeepGEMM is not installed or exports neither name.
    """
    _lazy_init()
    if _fp8_gemm_nt_impl is None:
        return _missing(*args, **kwargs)
    return _fp8_gemm_nt_impl(*args, **kwargs)


def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    """Call DeepGEMM's contiguous-layout grouped FP8 GEMM kernel.

    Dispatches to ``deep_gemm.m_grouped_fp8_gemm_nt_contiguous`` or, if
    absent, the older ``deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous``.
    All arguments are forwarded unchanged, and the kernel's return value is
    returned.

    Raises:
        RuntimeError: if DeepGEMM is not installed or exports neither name.
    """
    _lazy_init()
    if _grouped_impl is None:
        return _missing(*args, **kwargs)
    return _grouped_impl(*args, **kwargs)


def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    """Call DeepGEMM's masked grouped FP8 GEMM kernel.

    Dispatches to ``deep_gemm.fp8_m_grouped_gemm_nt_masked`` or, if absent,
    the older ``deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked``. All
    arguments are forwarded unchanged, and the kernel's return value is
    returned.

    Raises:
        RuntimeError: if DeepGEMM is not installed or exports neither name.
    """
    _lazy_init()
    if _grouped_masked_impl is None:
        return _missing(*args, **kwargs)
    return _grouped_masked_impl(*args, **kwargs)


120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
def _ceil_to_ue8m0(x: torch.Tensor):
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))


def _align(x: int, y: int) -> int:
    """Round *x* up to the nearest multiple of *y*."""
    num_chunks = cdiv(x, y)
    return num_chunks * y


# Default ``[block_m, block_n]`` tile shape for ``per_block_cast_to_fp8``,
# which emits one scale factor per tile. This list object is shared as a
# default argument, so it must never be mutated.
DEFAULT_BLOCK_SIZE = [128, 128]


# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38
# TODO(wentao): optimize this function, using triton or cuda kernel
def per_block_cast_to_fp8(
        x: torch.Tensor,
        block_size: list[int] = DEFAULT_BLOCK_SIZE,
        use_ue8m0: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to ``float8_e4m3fn`` with one scale per tile.

    *x* is zero-padded up to a multiple of ``block_size`` in each dimension
    and split into ``block_m x block_n`` tiles. Each tile's scale is
    ``amax / 448`` (448 is the largest finite ``float8_e4m3fn`` value), with
    ``amax`` clamped to at least ``1e-4`` so all-zero tiles do not divide by
    zero. The tile is then multiplied by ``1 / scale`` and cast to FP8.

    Args:
        x: 2-D input tensor.
        block_size: ``[block_m, block_n]`` tile shape.
        use_ue8m0: if True, round every scale up to a power of two.

    Returns:
        ``(x_fp8, scales)``. ``x_fp8`` has the shape of *x* and dtype
        ``torch.float8_e4m3fn``. ``scales`` is a float32 tensor of shape
        ``(ceil(m / block_m), ceil(n / block_n))``.
    """
    assert x.dim() == 2
    rows, cols = x.shape
    tile_rows, tile_cols = block_size

    padded_rows = _align(rows, tile_rows)
    padded_cols = _align(cols, tile_cols)
    padded = torch.zeros((padded_rows, padded_cols),
                         dtype=x.dtype,
                         device=x.device)
    padded[:rows, :cols] = x

    # (row_tiles, tile_rows, col_tiles, tile_cols)
    tiles = padded.view(-1, tile_rows, padded_cols // tile_cols, tile_cols)
    amax = tiles.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    scale = amax / 448.0
    if use_ue8m0:
        scale = _ceil_to_ue8m0(scale)

    quantized = (tiles * (1.0 / scale)).to(torch.float8_e4m3fn)
    quantized = quantized.view_as(padded)[:rows, :cols].contiguous()
    return quantized, scale.view(tiles.size(0), tiles.size(2))
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175


def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a global difference metric for unit tests.

    DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element
    error, causing ``torch.testing.assert_close`` to fail.  Instead of checking
    every element, we compute a similarity over the whole tensor and report
    ``1 - sim``, where ``sim = 2 * <x, y> / (||x||^2 + ||y||^2)``.  This equals
    ``||x - y||^2 / (||x||^2 + ||y||^2)``, so, unlike plain cosine similarity,
    it also penalizes a magnitude mismatch and is ``0`` only when ``x == y``.
    Once kernel accuracy improves this helper can be removed.

    Both inputs are promoted to float64.  If both tensors are entirely zero
    they are treated as identical and ``0`` is returned, instead of the
    ``NaN`` the bare formula would produce from ``0 / 0``.

    Returns:
        A 0-dim float64 tensor on the inputs' device.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    # torch.where keeps the guard on-device (no host sync). The 0/0 branch
    # is still evaluated, but its NaN is discarded.
    sim = torch.where(denominator == 0, torch.ones_like(denominator),
                      2 * (x * y).sum() / denominator)
    return 1 - sim


# Public API of this compatibility wrapper.
__all__ = [
    "calc_diff",
    "fp8_gemm_nt",
    "m_grouped_fp8_gemm_nt_contiguous",
    "fp8_m_grouped_gemm_nt_masked",
    "per_block_cast_to_fp8",
    "is_blackwell_deep_gemm_used",
    "is_deep_gemm_supported",
]