deep_gemm.py 4.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for DeepGEMM API changes.

Users of vLLM should always import **only** these wrappers.
"""
from __future__ import annotations

import functools
import importlib
from typing import Any, Callable, NoReturn

import torch

import vllm.envs as envs
16
17
from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm
18
19


20
21
22
23
24
25
26
27
28
29
30
@functools.cache
def is_deep_gemm_supported() -> bool:
    """Report whether DeepGEMM can run on the current machine.

    Two conditions must both hold:

    * the platform is CUDA with device capability 90 or 100 (Hopper or
      Blackwell GPUs);
    * ``has_deep_gemm()`` reports that the ``deep_gemm`` package is usable.

    The answer is memoized by ``functools.cache`` for the life of the
    process.
    """
    arch_ok = False
    if current_platform.is_cuda():
        arch_ok = (current_platform.is_device_capability(90)
                   or current_platform.is_device_capability(100))
    package_ok = has_deep_gemm()
    return package_ok and arch_ok


31
32
33
34
35
@functools.cache
def is_blackwell_deep_gemm_used() -> bool:
    """Return ``True`` if vLLM is configured to use DeepGEMM on a
    Blackwell-class GPU.

    All of the following must hold:

    * ``envs.VLLM_USE_DEEP_GEMM`` is truthy and ``has_deep_gemm()`` reports
      the package as available;
    * DeepGEMM exposes ``per_block_cast_to_fp8`` (resolved by
      ``_lazy_init`` into ``_per_block_cast_impl``);
    * the platform is CUDA with device capability 100.

    The result is memoized by ``functools.cache``, so the configuration is
    only consulted on the first call in a process.
    """
    if not (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm()):
        return False

    # Resolve the DeepGEMM symbols so ``_per_block_cast_impl`` reflects what
    # the installed package actually provides.
    _lazy_init()
    if _per_block_cast_impl is None:
        return False

    return (current_platform.is_cuda()
            and current_platform.is_device_capability(100))
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Stand-in for a DeepGEMM kernel that could not be resolved.

    Accepts and ignores any positional/keyword arguments, so it can be
    called exactly like the real kernel would be, then always raises.

    Raises:
        RuntimeError: unconditionally, asking the user to install
            ``deep_gemm``.
    """
    message = ("DeepGEMM backend is not available. Please install the "
               "`deep_gemm` package to enable FP8 kernels.")
    raise RuntimeError(message)


def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
    """Look up an entry point that may exist under a new or legacy name.

    The current name *new* is tried first, then the legacy name *old*; the
    attribute value of the first name *module* defines is returned as-is.
    Returns ``None`` when *module* defines neither attribute.
    """
    for candidate in (new, old):
        if hasattr(module, candidate):
            return getattr(module, candidate)
    return None


63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Resolved DeepGEMM entry points; ``None`` until ``_lazy_init`` finds them
# (and left ``None`` if the installed package lacks the symbol).
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
_per_block_cast_impl: Callable[..., Any] | None = None


def _lazy_init() -> None:
    """Import ``deep_gemm`` and resolve its kernels on first use.

    Populates the module-level ``_*_impl`` globals, preferring the current
    DeepGEMM symbol names and falling back to the legacy ones via
    ``_resolve_symbol``. Symbols that cannot be found stay ``None``. Returns
    without doing anything when ``has_deep_gemm()`` is false.
    """
    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl, \
        _per_block_cast_impl

    # Fast path: an earlier call already resolved at least one symbol.
    # NOTE(review): if deep_gemm is installed but none of the symbols below
    # resolve, every call repeats the import and lookups.
    if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
            or _grouped_masked_impl is not None
            or _per_block_cast_impl is not None):
        return

    if not has_deep_gemm():
        return

    _dg = importlib.import_module("deep_gemm")

    _fp8_gemm_nt_impl = _resolve_symbol(_dg, "fp8_gemm_nt",
                                        "gemm_fp8_fp8_bf16_nt")
    _grouped_impl = _resolve_symbol(
        _dg, "m_grouped_fp8_gemm_nt_contiguous",
        "m_grouped_gemm_fp8_fp8_bf16_nt_contiguous")
    _grouped_masked_impl = _resolve_symbol(
        _dg, "fp8_m_grouped_gemm_nt_masked",
        "m_grouped_gemm_fp8_fp8_bf16_nt_masked")

    # Try to get per_block_cast_to_fp8 from DeepGEMM's math utils; tolerate
    # the module being absent from the installed DeepGEMM.
    try:
        _math_mod = importlib.import_module("deep_gemm.utils.math")
    except ModuleNotFoundError:
        _per_block_cast_impl = None
    else:
        _per_block_cast_impl = getattr(_math_mod, "per_block_cast_to_fp8",
                                       None)


def fp8_gemm_nt(*args, **kwargs):
    """Dispatch to DeepGEMM's FP8 GEMM ("NT" variant).

    All arguments are forwarded unchanged to ``deep_gemm.fp8_gemm_nt`` or,
    on older DeepGEMM releases, ``gemm_fp8_fp8_bf16_nt`` — whichever
    ``_lazy_init`` resolved.

    Raises:
        RuntimeError: if DeepGEMM is unavailable or exposes neither symbol.
    """
    _lazy_init()
    if _fp8_gemm_nt_impl is None:
        return _missing(*args, **kwargs)
    return _fp8_gemm_nt_impl(*args, **kwargs)


def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    """Dispatch to DeepGEMM's contiguous-layout M-grouped FP8 GEMM.

    All arguments are forwarded unchanged to
    ``deep_gemm.m_grouped_fp8_gemm_nt_contiguous`` or, on older DeepGEMM
    releases, ``m_grouped_gemm_fp8_fp8_bf16_nt_contiguous`` — whichever
    ``_lazy_init`` resolved.

    Raises:
        RuntimeError: if DeepGEMM is unavailable or exposes neither symbol.
    """
    _lazy_init()
    if _grouped_impl is None:
        return _missing(*args, **kwargs)
    return _grouped_impl(*args, **kwargs)


def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    """Dispatch to DeepGEMM's masked M-grouped FP8 GEMM.

    All arguments are forwarded unchanged to
    ``deep_gemm.fp8_m_grouped_gemm_nt_masked`` or, on older DeepGEMM
    releases, ``m_grouped_gemm_fp8_fp8_bf16_nt_masked`` — whichever
    ``_lazy_init`` resolved.

    Raises:
        RuntimeError: if DeepGEMM is unavailable or exposes neither symbol.
    """
    _lazy_init()
    if _grouped_masked_impl is None:
        return _missing(*args, **kwargs)
    return _grouped_masked_impl(*args, **kwargs)


def per_block_cast_to_fp8(x, *args, **kwargs):
    """Cast *x* to FP8 block-wise, choosing between two implementations.

    * If DeepGEMM provides ``per_block_cast_to_fp8`` and
      ``is_blackwell_deep_gemm_used()`` is true, DeepGEMM's version is called
      as ``impl(x, use_ue8m0=True)``. Any extra ``args``/``kwargs`` are
      ignored on this path.
    * Otherwise all arguments are forwarded to the reference implementation
      in ``tests.kernels.quant_utils``, imported lazily on first use.
    """
    _lazy_init()
    if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
        # NOTE(review): caller-supplied args/kwargs are dropped here and
        # ``use_ue8m0`` is forced on — confirm that is intended.
        return _per_block_cast_impl(x, use_ue8m0=True)

    # TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
    from tests.kernels.quant_utils import per_block_cast_to_fp8 as _pbcf
    return _pbcf(x, *args, **kwargs)


def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a global difference metric for unit tests.

    DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable
    per-element error, causing ``torch.testing.assert_close`` to fail.
    Instead of checking every element, we compute, in float64 over the whole
    tensor, ``sim = 2 * sum(x * y) / (sum(x * x) + sum(y * y))`` and report
    ``1 - sim``. The result is ``0`` for identical tensors and ``2`` when
    ``y == -x``. Once kernel accuracy improves this helper can be removed.

    When the denominator is zero (e.g. both tensors are all zeros) the ratio
    would be ``0 / 0`` = NaN. Such inputs are equal, so ``0`` is returned.

    Returns:
        A 0-d float64 tensor.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:
        # Avoid 0/0 -> NaN, which would make identical all-zero outputs
        # fail a ``diff < tol`` check.
        return torch.zeros_like(denominator)
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim


# Public API of this wrapper module; callers should import only these names.
__all__ = [
    "calc_diff",
    "fp8_gemm_nt",
    "m_grouped_fp8_gemm_nt_contiguous",
    "fp8_m_grouped_gemm_nt_masked",
    "per_block_cast_to_fp8",
    "is_blackwell_deep_gemm_used",
    "is_deep_gemm_supported",
]