# File: deep_gemm.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for DeepGEMM API changes.

Users of vLLM should always import **only** these wrappers.
"""
from __future__ import annotations

import functools
import importlib
from typing import Any, Callable, NoReturn

import torch

import vllm.envs as envs
from vllm.utils import cuda_get_device_properties, has_deep_gemm


@functools.cache
def is_blackwell_deep_gemm_used() -> bool:
    """Report whether DeepGEMM is enabled and device 0 is Blackwell-class.

    The result is computed on the first call and memoized by
    ``functools.cache`` for every later call.

    Returns:
        ``False`` if ``envs.VLLM_USE_DEEP_GEMM`` is falsy, if
        ``has_deep_gemm()`` is falsy, or if no ``per_block_cast_to_fp8``
        implementation was resolved at import time. Otherwise ``True`` only
        when the ``major`` property reported for device 0 equals 10 (treated
        here as Blackwell-class).
    """
    # The guards run in the same order as a single short-circuiting ``and``.
    if not envs.VLLM_USE_DEEP_GEMM:
        return False
    if not has_deep_gemm():
        return False
    if _per_block_cast_impl is None:
        return False

    major_version = cuda_get_device_properties(0, ("major", ))[0]
    return major_version == 10


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Stand-in callable used when no DeepGEMM implementation is available.

    Accepts (and ignores) any positional or keyword arguments so it can be
    called in place of any kernel entry point.

    Raises:
        RuntimeError: Always.
    """
    message = ("DeepGEMM backend is not available. Please install the "
               "`deep_gemm` package to enable FP8 kernels.")
    raise RuntimeError(message)


def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
    """Look up an attribute on *module*, preferring *new* over *old*.

    Args:
        module: Object (normally an imported module) to inspect.
        new: Attribute name tried first.
        old: Attribute name tried only if *new* is absent.

    Returns:
        The first attribute found, or ``None`` if *module* has neither name.
    """
    for candidate in (new, old):
        if hasattr(module, candidate):
            return getattr(module, candidate)
    return None


# Resolved DeepGEMM entry points. Every slot starts out as ``None`` and is
# only filled in when ``has_deep_gemm()`` reports the package as available.
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
_per_block_cast_impl: Callable[..., Any] | None = None

if has_deep_gemm():
    _dg = importlib.import_module("deep_gemm")  # type: ignore

    # Each lookup prefers the first name and falls back to the second one.
    _fp8_gemm_nt_impl = _resolve_symbol(_dg, "fp8_gemm_nt",
                                        "gemm_fp8_fp8_bf16_nt")
    _grouped_impl = _resolve_symbol(
        _dg, "m_grouped_fp8_gemm_nt_contiguous",
        "m_grouped_gemm_fp8_fp8_bf16_nt_contiguous")
    _grouped_masked_impl = _resolve_symbol(
        _dg, "fp8_m_grouped_gemm_nt_masked",
        "m_grouped_gemm_fp8_fp8_bf16_nt_masked")

    # per_block_cast_to_fp8 is looked up in DeepGEMM's math utils; the slot
    # stays ``None`` if that submodule is missing or lacks the symbol.
    try:
        _math_mod = importlib.import_module(
            "deep_gemm.utils.math")  # type: ignore
        _per_block_cast_impl = getattr(_math_mod, "per_block_cast_to_fp8",
                                       None)
    except ModuleNotFoundError:
        pass


def fp8_gemm_nt(*args, **kwargs):
    """Forward the call to the resolved DeepGEMM ``fp8_gemm_nt`` kernel.

    All arguments are passed through unchanged and the kernel's result is
    returned as-is.

    Raises:
        RuntimeError: If no implementation was resolved at import time.
    """
    impl = _fp8_gemm_nt_impl
    if impl is not None:
        return impl(*args, **kwargs)
    return _missing(*args, **kwargs)


def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    """Forward the call to the resolved contiguous grouped FP8 GEMM kernel.

    All arguments are passed through unchanged and the kernel's result is
    returned as-is.

    Raises:
        RuntimeError: If no implementation was resolved at import time.
    """
    impl = _grouped_impl
    if impl is not None:
        return impl(*args, **kwargs)
    return _missing(*args, **kwargs)


def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    """Forward the call to the resolved masked grouped FP8 GEMM kernel.

    All arguments are passed through unchanged and the kernel's result is
    returned as-is.

    Raises:
        RuntimeError: If no implementation was resolved at import time.
    """
    impl = _grouped_masked_impl
    if impl is not None:
        return impl(*args, **kwargs)
    return _missing(*args, **kwargs)


def per_block_cast_to_fp8(x, *args, **kwargs):
    """Cast ``x`` to FP8 with a per-block implementation.

    When a DeepGEMM ``per_block_cast_to_fp8`` was resolved at import time and
    ``is_blackwell_deep_gemm_used()`` is true, the DeepGEMM implementation is
    called with ``use_ue8m0=True``. Otherwise the implementation from
    ``tests.kernels.quant_utils`` is used.

    Args:
        x: Input to cast; handed to the selected implementation unchanged.
        *args: Extra positional arguments, forwarded only to the fallback.
        **kwargs: Extra keyword arguments, forwarded only to the fallback.

    Returns:
        Whatever the selected implementation returns.
    """
    if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
        # NOTE: ``*args`` / ``**kwargs`` are not forwarded on this path.
        return _per_block_cast_impl(x, use_ue8m0=True)

    # TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
    # Imported here so the ``tests`` package is only needed on this path.
    from tests.kernels.quant_utils import per_block_cast_to_fp8 as _pbcf
    return _pbcf(x, *args, **kwargs)


def calc_diff(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return a global difference metric for unit tests.

    DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element
    error, causing ``torch.testing.assert_close`` to fail.  Instead of checking
    every element, we compute a cosine-style similarity over the whole tensor
    and report ``1 - sim``.  Once kernel accuracy improves this helper can be
    removed.

    Args:
        x: First tensor; converted to float64 before the computation.
        y: Second tensor; converted to float64 before the computation.

    Returns:
        A zero-dimensional float64 tensor holding
        ``1 - 2 * sum(x * y) / sum(x * x + y * y)``. If the denominator is
        zero (both inputs are all zeros, or empty) the inputs are identical,
        so ``0`` is returned instead of the NaN that ``0 / 0`` would produce.
    """

    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:
        # A sum of squares is zero only when every element of x and y is zero.
        return torch.zeros_like(denominator)
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim


# Names exported by ``from <this module> import *``; the underscore-prefixed
# helpers and resolved implementation slots above are deliberately left out.
__all__ = [
    "calc_diff",
    "fp8_gemm_nt",
    "m_grouped_fp8_gemm_nt_contiguous",
    "fp8_m_grouped_gemm_nt_masked",
    "per_block_cast_to_fp8",
    "is_blackwell_deep_gemm_used",
]