deep_gemm.py 6.75 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for DeepGEMM API changes.

Users of vLLM should always import **only** these wrappers.
"""
from __future__ import annotations

import functools
import importlib
11
import os
12
from typing import Any, Callable, NoReturn
13
14
15
16

import torch

import vllm.envs as envs
17
from vllm.logger import logger
18
from vllm.platforms import current_platform
19
from vllm.utils import cdiv, has_deep_gemm
20
21


22
23
24
25
26
@functools.cache
def is_deep_gemm_supported() -> bool:
    """Report whether DeepGEMM can be used on the current platform.

    All three conditions must hold:

    * ``VLLM_USE_DEEP_GEMM`` is enabled,
    * the ``deep_gemm`` package is available (per ``has_deep_gemm``),
    * the device is a CUDA GPU with compute capability 9.0 (Hopper) or
      10.0 (Blackwell).

    The result is cached for the lifetime of the process.
    """
    # Probe the hardware first (same order as before), short-circuiting on
    # the first matching capability.
    arch_ok = current_platform.is_cuda() and any(
        current_platform.is_device_capability(cap) for cap in (90, 100))
    return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and arch_ok
31
32


33
@functools.cache
def is_deep_gemm_e8m0_used() -> bool:
    """Return ``True`` if vLLM is configured to use DeepGEMM E8M0 scales.

    E8M0 scaling is enabled only when all of the following hold:

    * DeepGEMM is supported (see ``is_deep_gemm_supported``);
    * the ``fp8_gemm_nt`` entry point was resolved from ``deep_gemm``;
    * either the GPU has compute capability 10.0 (Blackwell) and
      ``VLLM_USE_DEEP_GEMM_E8M0`` is set, or it has compute capability 9.0
      (Hopper) and ``VLLM_USE_DEEP_GEMM_E8M0_HOPPER`` is set.

    The decision (and the log message explaining it) is cached for the
    lifetime of the process.
    """
    if not is_deep_gemm_supported():
        logger.debug_once(
            "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system.")
        return False

    # Resolve the deep_gemm symbols so the availability check below is
    # meaningful.
    _lazy_init()

    if _fp8_gemm_nt_impl is None:
        logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
        return False

    if (current_platform.is_device_capability(100)
            and envs.VLLM_USE_DEEP_GEMM_E8M0):
        logger.info_once("DeepGEMM E8M0 enabled on Blackwell GPU.")
        return True

    if (current_platform.is_device_capability(90)
            and envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER):
        logger.info_once("DeepGEMM E8M0 enabled on Hopper GPU.")
        return True

    logger.info_once("DeepGEMM E8M0 disabled on current configuration.")
    return False
61
62
63
64
65
66
67
68
69


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Placeholder for unavailable DeepGEMM backend."""
    raise RuntimeError(
        "DeepGEMM backend is not available. Please install the `deep_gemm` "
        "package to enable FP8 kernels.")


70
71
72
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
73
_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
74
75
76
77


def _lazy_init() -> None:
    """Import ``deep_gemm`` and resolve its entry points on first use.

    Populates the module-level ``_*_impl`` globals with ``getattr`` lookups
    on the ``deep_gemm`` module (``None`` for any symbol the installed
    version does not provide).

    Returns immediately, doing nothing, if:

    * any of those globals is already resolved, or
    * ``deep_gemm`` is not installed (per ``has_deep_gemm``).

    Before importing, it points DeepGEMM's JIT cache (``DG_JIT_CACHE_DIR``)
    at ``<VLLM_CACHE_ROOT>/deep_gemm``, unless that variable is already set
    to a non-empty value.
    """
    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
    global _get_mn_major_tma_aligned_tensor_impl

    # Fast path: a previous call already resolved at least one symbol. The
    # TMA helper is included so a deep_gemm build exposing only that symbol
    # does not re-import and re-resolve on every call.
    if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
            or _grouped_masked_impl is not None
            or _get_mn_major_tma_aligned_tensor_impl is not None):
        return

    if not has_deep_gemm():
        return

    # Set up deep_gemm cache path; respect a user-provided value.
    jit_cache_env_name = "DG_JIT_CACHE_DIR"
    if not os.environ.get(jit_cache_env_name):
        os.environ[jit_cache_env_name] = os.path.join(envs.VLLM_CACHE_ROOT,
                                                      "deep_gemm")

    dg = importlib.import_module("deep_gemm")

    _fp8_gemm_nt_impl = getattr(dg, "fp8_gemm_nt", None)
    _grouped_impl = getattr(dg, "m_grouped_fp8_gemm_nt_contiguous", None)
    _grouped_masked_impl = getattr(dg, "fp8_m_grouped_gemm_nt_masked", None)
    _get_mn_major_tma_aligned_tensor_impl = getattr(
        dg, "get_mn_major_tma_aligned_tensor", None)


def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
    """Forward ``x`` to DeepGEMM's ``get_mn_major_tma_aligned_tensor``.

    Returns whatever that helper returns.

    Raises:
        RuntimeError: via ``_missing`` when the helper could not be resolved
            (``deep_gemm`` absent, or lacking that symbol).
    """
    _lazy_init()
    helper = _get_mn_major_tma_aligned_tensor_impl
    if helper is not None:
        return helper(x)
    return _missing()
110
111
112


def fp8_gemm_nt(*args, **kwargs):
    """Call DeepGEMM's ``fp8_gemm_nt``, forwarding all arguments.

    This wrapper always supplies ``disable_ue8m0_cast`` itself, set to the
    negation of ``is_deep_gemm_e8m0_used()``. Callers must therefore not
    pass that keyword; doing so raises a duplicate-keyword TypeError.

    Raises:
        RuntimeError: via ``_missing`` if the kernel is unavailable.
    """
    _lazy_init()
    kernel = _fp8_gemm_nt_impl
    if kernel is None:
        return _missing(*args, **kwargs)
    disable_cast = not is_deep_gemm_e8m0_used()
    return kernel(*args, disable_ue8m0_cast=disable_cast, **kwargs)
119
120
121


def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    """Call DeepGEMM's ``m_grouped_fp8_gemm_nt_contiguous`` with all args.

    This wrapper always supplies ``disable_ue8m0_cast`` itself, set to the
    negation of ``is_deep_gemm_e8m0_used()``. Callers must therefore not
    pass that keyword.

    Raises:
        RuntimeError: via ``_missing`` if the kernel is unavailable.
    """
    _lazy_init()
    kernel = _grouped_impl
    if kernel is None:
        return _missing(*args, **kwargs)
    disable_cast = not is_deep_gemm_e8m0_used()
    return kernel(*args, disable_ue8m0_cast=disable_cast, **kwargs)
128
129
130


def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    """Call DeepGEMM's ``fp8_m_grouped_gemm_nt_masked`` with all args.

    This wrapper always supplies ``disable_ue8m0_cast`` itself, set to the
    negation of ``is_deep_gemm_e8m0_used()``. Callers must therefore not
    pass that keyword.

    Raises:
        RuntimeError: via ``_missing`` if the kernel is unavailable.
    """
    _lazy_init()
    kernel = _grouped_masked_impl
    if kernel is None:
        return _missing(*args, **kwargs)
    disable_cast = not is_deep_gemm_e8m0_used()
    return kernel(*args, disable_ue8m0_cast=disable_cast, **kwargs)
136
137


138
139
140
141
142
143
144
145
146
147
148
149
def _ceil_to_ue8m0(x: torch.Tensor):
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))


def _align(x: int, y: int) -> int:
    """Round ``x`` up to a multiple of ``y`` (``ceil(x / y) * y``)."""
    num_chunks = cdiv(x, y)
    return num_chunks * y


# Default ``[block_m, block_n]`` tile shape for per_block_cast_to_fp8; each
# tile gets its own scale factor. This list is shared as a default argument,
# so treat it as read-only.
DEFAULT_BLOCK_SIZE = [128, 128]


# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def per_block_cast_to_fp8(
        x: torch.Tensor,
        block_size: list[int] = DEFAULT_BLOCK_SIZE,
        use_ue8m0: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to ``float8_e4m3fn`` with one scale per tile.

    Steps:

    1. Zero-pad ``x`` up to a multiple of ``block_size`` in both dimensions
       and split it into ``block_m x block_n`` tiles.
    2. Compute each tile's scale as ``max(amax(|tile|), 1e-4) / 448.0``
       (448 is the largest finite ``float8_e4m3fn`` value). When
       ``use_ue8m0`` is set, round the scale up to a power of two.
    3. Multiply the tile by the reciprocal of its scale and cast to FP8.

    Args:
        x: 2-D input tensor (asserted).
        block_size: ``[block_m, block_n]`` tile shape.
        use_ue8m0: Round scales up to powers of two via ``_ceil_to_ue8m0``.

    Returns:
        ``(x_fp8, scales)``: ``x_fp8`` has the shape of ``x``; ``scales`` is
        float32 with shape ``(ceil(m / block_m), ceil(n / block_n))``.
    """
    assert x.dim() == 2
    rows, cols = x.shape
    block_m, block_n = block_size

    padded_rows = _align(rows, block_m)
    padded_cols = _align(cols, block_n)
    padded = torch.zeros((padded_rows, padded_cols),
                         dtype=x.dtype,
                         device=x.device)
    padded[:rows, :cols] = x

    # (row_tiles, block_m, col_tiles, block_n)
    tiles = padded.view(-1, block_m, padded_cols // block_n, block_n)
    tile_amax = tiles.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    scales = tile_amax / 448.0
    if use_ue8m0:
        scales = _ceil_to_ue8m0(scales)

    quantized = (tiles * (1.0 / scales)).to(torch.float8_e4m3fn)
    x_fp8 = quantized.view_as(padded)[:rows, :cols].contiguous()
    return x_fp8, scales.view(tiles.size(0), tiles.size(2))
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186


def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a global difference metric for unit tests.

    DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable
    per-element error, causing ``torch.testing.assert_close`` to fail.
    Instead of checking every element, this computes, in float64 over the
    whole tensor, the similarity ``sim = 2 * <x, y> / (|x|^2 + |y|^2)`` and
    returns ``1 - sim``.

    That value equals ``|x - y|^2 / (|x|^2 + |y|^2)``: it is 0 for identical
    tensors and never exceeds 2. When both tensors are entirely zero (or
    empty) the ratio is 0/0. Because the inputs are then identical, 0 is
    returned instead of NaN.

    Once kernel accuracy improves this helper can be removed.

    Returns:
        A 0-dim float64 tensor.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:
        # Both inputs are all-zero/empty, hence identical; avoid 0/0 -> NaN,
        # which would make ``diff < eps`` style assertions fail spuriously.
        return torch.zeros_like(denominator)
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim


187
188
189
def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype,
                                       weight: torch.Tensor):
    """Decide whether an FP8 linear layer should be routed through DeepGEMM.

    True only when all of the following hold:

    * DeepGEMM is supported,
    * the requested output dtype is bfloat16,
    * both ``weight.shape[0]`` and ``weight.shape[1]`` are multiples of 128.
    """
    # Cheap platform/dtype gates first; ``weight`` is only inspected after.
    if not (is_deep_gemm_supported() and output_dtype == torch.bfloat16):
        return False
    out_features, in_features = weight.shape[0], weight.shape[1]
    return out_features % 128 == 0 and in_features % 128 == 0


193
194
195
196
197
198
# Public surface of this compatibility module; vLLM code should import only
# these names rather than calling deep_gemm directly.
__all__ = [
    "calc_diff",
    "fp8_gemm_nt",
    "m_grouped_fp8_gemm_nt_contiguous",
    "fp8_m_grouped_gemm_nt_masked",
    "per_block_cast_to_fp8",
    "is_deep_gemm_e8m0_used",
    "is_deep_gemm_supported",
    "should_use_deepgemm_for_fp8_linear",
    "get_col_major_tma_aligned_tensor",
]