deep_gemm.py 3.98 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for DeepGEMM API changes.

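vLLM code should access DeepGEMM **only** through these wrappers, never by importing ``deep_gemm`` directly, so that both the new and the legacy DeepGEMM entry-point names keep working.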
"""
from __future__ import annotations

import functools
import importlib
from typing import Any, Callable, NoReturn

import torch

import vllm.envs as envs
from vllm.utils import cuda_get_device_properties, has_deep_gemm


@functools.cache
def is_blackwell_deep_gemm_used() -> bool:
    """Report whether DeepGEMM is enabled and running on a Blackwell GPU.

    Returns ``True`` only when all of the following hold:

    * ``envs.VLLM_USE_DEEP_GEMM`` is truthy,
    * ``has_deep_gemm()`` reports the package as available,
    * DeepGEMM's ``per_block_cast_to_fp8`` helper was resolved at import,
    * CUDA device 0 reports a compute-capability major version of 10.

    The answer is memoized by ``functools.cache``.
    """
    deep_gemm_usable = (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm()
                        and _per_block_cast_impl is not None)
    if not deep_gemm_usable:
        return False

    # Only device 0 is inspected; major == 10 identifies Blackwell-class.
    major = cuda_get_device_properties(0, ("major", ))[0]
    return major == 10


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Stand-in for a DeepGEMM entry point that could not be resolved.

    Accepts (and ignores) any arguments so it can take the place of any
    kernel call, and unconditionally raises ``RuntimeError``.
    """
    message = ("DeepGEMM backend is not available. Please install the "
               "`deep_gemm` package to enable FP8 kernels.")
    raise RuntimeError(message)


def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
    """Look up an attribute of ``module`` that may have been renamed.

    The attribute named ``new`` wins if present; ``old`` is consulted only
    when ``new`` is absent. Returns ``None`` when neither name exists.
    """
    for candidate in (new, old):
        if hasattr(module, candidate):
            return getattr(module, candidate)
    return None


if not has_deep_gemm():
    _fp8_gemm_nt_impl: Callable[..., Any] | None = None
    _grouped_impl: Callable[..., Any] | None = None
    _grouped_masked_impl: Callable[..., Any] | None = None
    _per_block_cast_impl: Callable[..., Any] | None = None
else:
    _dg = importlib.import_module("deep_gemm")  # type: ignore

    # Prefer the new DeepGEMM entry-point names and fall back to the legacy
    # ones. A symbol found under neither name stays ``None``; the public
    # wrappers in this module then raise via ``_missing``.
    _fp8_gemm_nt_impl = _resolve_symbol(
        _dg,
        "fp8_gemm_nt",
        "gemm_fp8_fp8_bf16_nt",
    )
    _grouped_impl = _resolve_symbol(
        _dg,
        "m_grouped_fp8_gemm_nt_contiguous",
        "m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
    )
    _grouped_masked_impl = _resolve_symbol(
        _dg,
        "fp8_m_grouped_gemm_nt_masked",
        "m_grouped_gemm_fp8_fp8_bf16_nt_masked",
    )

    # Try to get per_block_cast_to_fp8 from DeepGEMM's math utils. This
    # helper is optional: if the submodule is missing or fails to import
    # (``ImportError`` also covers ``ModuleNotFoundError``), or the
    # attribute is absent, ``_per_block_cast_impl`` stays ``None``.
    try:
        _math_mod = importlib.import_module(
            "deep_gemm.utils.math")  # type: ignore
    except ImportError:
        _per_block_cast_impl = None
    else:
        _per_block_cast_impl = getattr(_math_mod, "per_block_cast_to_fp8",
                                       None)


def fp8_gemm_nt(*args, **kwargs):
    """Forward to DeepGEMM's non-grouped FP8 NT GEMM.

    All arguments are passed through unchanged to whichever of
    ``fp8_gemm_nt`` (new name) or ``gemm_fp8_fp8_bf16_nt`` (legacy name)
    was resolved at import time. Raises ``RuntimeError`` if neither exists.
    """
    impl = _fp8_gemm_nt_impl
    if impl is not None:
        return impl(*args, **kwargs)
    return _missing(*args, **kwargs)


def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    """Forward to DeepGEMM's M-grouped, contiguous-layout FP8 NT GEMM.

    Arguments pass through unchanged to ``m_grouped_fp8_gemm_nt_contiguous``
    (new name) or ``m_grouped_gemm_fp8_fp8_bf16_nt_contiguous`` (legacy
    name), whichever was resolved. Raises ``RuntimeError`` if neither exists.
    """
    target = _missing if _grouped_impl is None else _grouped_impl
    return target(*args, **kwargs)


def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    """Forward to DeepGEMM's M-grouped, masked-layout FP8 NT GEMM.

    Arguments pass through unchanged to ``fp8_m_grouped_gemm_nt_masked``
    (new name) or ``m_grouped_gemm_fp8_fp8_bf16_nt_masked`` (legacy name),
    whichever was resolved. Raises ``RuntimeError`` if neither exists.
    """
    masked_kernel = _grouped_masked_impl
    if masked_kernel is None:
        return _missing(*args, **kwargs)
    result = masked_kernel(*args, **kwargs)
    return result


def per_block_cast_to_fp8(x, *args, **kwargs):
    """Per-block FP8 cast of ``x``, dispatched to DeepGEMM or a fallback.

    When DeepGEMM's helper was resolved and ``is_blackwell_deep_gemm_used()``
    is true, only ``x`` is handed to DeepGEMM's ``per_block_cast_to_fp8``;
    any extra positional/keyword arguments are ignored on that path.
    Otherwise ``x`` and every extra argument are forwarded to the
    implementation in ``tests.kernels.quant_utils``.
    """
    use_deep_gemm = (_per_block_cast_impl is not None
                     and is_blackwell_deep_gemm_used())
    if not use_deep_gemm:
        # TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
        from tests.kernels.quant_utils import (
            per_block_cast_to_fp8 as _reference_cast)
        return _reference_cast(x, *args, **kwargs)
    return _per_block_cast_impl(x)


def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a global difference metric for unit tests.

    DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element
    error, causing ``torch.testing.assert_close`` to fail.  Instead of checking
    every element, we compute one similarity over the whole tensor and report
    ``1 - sim``, where::

        sim = 2 * sum(x * y) / (sum(x * x) + sum(y * y))

    ``sim`` equals 1 exactly when ``x == y``. This is stricter than cosine
    similarity, which only requires ``x`` and ``y`` to be parallel. The result
    therefore lies in ``[0, 2]`` and is 0 for identical tensors. Both inputs
    are promoted to float64 before reducing. If both tensors are all zeros
    (or empty), they are identical and 0 is returned instead of the NaN that
    ``0 / 0`` would produce. NaN inputs still yield NaN.  Once kernel
    accuracy improves this helper can be removed.

    Returns:
        A 0-dim float64 tensor.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    numerator = 2 * (x * y).sum()
    # A zero denominator means x and y are both all zeros (up to float64
    # underflow), i.e. identical, so sim is 1. Test ``== 0`` rather than
    # ``> 0`` so a NaN denominator still propagates. ``torch.where`` keeps
    # this free of a device->host sync.
    sim = torch.where(denominator == 0, torch.ones_like(denominator),
                      numerator / denominator)
    return 1 - sim


# Public API of this compatibility module. Callers should use only these
# names rather than importing from ``deep_gemm`` directly.
__all__ = [
    "calc_diff",
    "fp8_gemm_nt",
    "m_grouped_fp8_gemm_nt_contiguous",
    "fp8_m_grouped_gemm_nt_masked",
    "per_block_cast_to_fp8",
    "is_blackwell_deep_gemm_used",
]