# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for DeepGEMM API changes.

Users of vLLM should always import **only** these wrappers.
"""
from __future__ import annotations

import functools
import importlib
from typing import Any, Callable, NoReturn

import torch

import vllm.envs as envs
from vllm.utils import cuda_get_device_properties, has_deep_gemm


@functools.cache
def is_blackwell_deep_gemm_used() -> bool:
    """Return ``True`` if vLLM is configured to use DeepGEMM on a
    Blackwell-class GPU.

    All of the following must hold: ``envs.VLLM_USE_DEEP_GEMM`` is truthy,
    ``has_deep_gemm()`` is true, DeepGEMM's ``per_block_cast_to_fp8`` was
    resolved at import time, and the ``major`` property reported for CUDA
    device 0 equals 10.

    The result is memoised with ``functools.cache``, so the check runs at
    most once per process.
    """
    # `_per_block_cast_impl` is assigned further down at module import time;
    # it is only looked up when this function is called.
    if not (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm()
            and _per_block_cast_impl is not None):
        return False

    return cuda_get_device_properties(0, ("major", ))[0] == 10


def _missing(*_: Any, **__: Any) -> NoReturn:
    """Placeholder for unavailable DeepGEMM backend.

    Accepts and ignores any arguments.

    Raises:
        RuntimeError: always.
    """
    raise RuntimeError(
        "DeepGEMM backend is not available. Please install the `deep_gemm` "
        "package to enable FP8 kernels.")


def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
    """Return the *new* symbol if it exists, otherwise the *old* one.

    Returns ``None`` when *module* has neither attribute.
    """
    if hasattr(module, new):
        return getattr(module, new)
    if hasattr(module, old):
        return getattr(module, old)
    return None


# Resolve the DeepGEMM entry points once, at import time. Every `_*_impl`
# name is always bound; it is `None` when the symbol is unavailable.
if not has_deep_gemm():
    _fp8_gemm_nt_impl: Callable[..., Any] | None = None
    _grouped_impl: Callable[..., Any] | None = None
    _grouped_masked_impl: Callable[..., Any] | None = None
    _per_token_cast_impl: Callable[..., Any] | None = None
    _per_block_cast_impl: Callable[..., Any] | None = None
else:
    _dg = importlib.import_module("deep_gemm")  # type: ignore

    _fp8_gemm_nt_impl = _resolve_symbol(
        _dg,
        "fp8_gemm_nt",
        "gemm_fp8_fp8_bf16_nt",
    )
    _grouped_impl = _resolve_symbol(
        _dg,
        "m_grouped_fp8_gemm_nt_contiguous",
        "m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
    )
    _grouped_masked_impl = _resolve_symbol(
        _dg,
        "fp8_m_grouped_gemm_nt_masked",
        "m_grouped_gemm_fp8_fp8_bf16_nt_masked",
    )

    # Try to get per_token_cast_to_fp8 from DeepGEMM math utils.
    try:
        _math_mod = importlib.import_module(
            "deep_gemm.utils.math")  # type: ignore
        _per_token_cast_impl = getattr(_math_mod, "per_token_cast_to_fp8",
                                       None)
        _per_block_cast_impl = getattr(_math_mod, "per_block_cast_to_fp8",
                                       None)
    except ModuleNotFoundError:
        # The installed DeepGEMM has no `deep_gemm.utils.math` module.
        _per_token_cast_impl = None
        _per_block_cast_impl = None


def fp8_gemm_nt(*args, **kwargs):
    """Forward all arguments to the resolved DeepGEMM FP8 GEMM symbol.

    Raises:
        RuntimeError: if no implementation was resolved at import time.
    """
    if _fp8_gemm_nt_impl is None:
        return _missing(*args, **kwargs)
    return _fp8_gemm_nt_impl(*args, **kwargs)


def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    """Forward all arguments to the resolved DeepGEMM grouped (contiguous)
    FP8 GEMM symbol.

    Raises:
        RuntimeError: if no implementation was resolved at import time.
    """
    if _grouped_impl is None:
        return _missing(*args, **kwargs)
    return _grouped_impl(*args, **kwargs)


def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    """Forward all arguments to the resolved DeepGEMM grouped (masked)
    FP8 GEMM symbol.

    Raises:
        RuntimeError: if no implementation was resolved at import time.
    """
    if _grouped_masked_impl is None:
        return _missing(*args, **kwargs)
    return _grouped_masked_impl(*args, **kwargs)


def per_token_group_cast_to_fp8(x, group_size, *args, **kwargs):
    """Wrapper for token-wise FP8 quantisation.

    • If DeepGEMM provides ``per_token_cast_to_fp8`` (new API) and
      ``is_blackwell_deep_gemm_used()`` is true, use it. On this path
      ``group_size`` must be 128 and any extra ``args``/``kwargs`` are
      not forwarded.
    • Otherwise, fall back to vLLM's ``per_token_group_quant_fp8``,
      forwarding every argument.
    """
    if _per_token_cast_impl is not None and is_blackwell_deep_gemm_used():
        assert group_size == 128, "group_size must be 128 for deepgemm"
        return _per_token_cast_impl(x)

    # Imported at call time rather than at module import.
    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
        per_token_group_quant_fp8 as _ptg)
    return _ptg(x, group_size, *args, **kwargs)


def per_block_cast_to_fp8(x, *args, **kwargs):
    """Wrapper for block-wise FP8 quantisation.

    Uses DeepGEMM's ``per_block_cast_to_fp8`` when it was resolved and
    ``is_blackwell_deep_gemm_used()`` is true; on this path only ``x`` is
    passed and extra ``args``/``kwargs`` are not forwarded. Otherwise the
    helper from ``tests.kernels.quant_utils`` is called with every argument.
    """
    if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
        return _per_block_cast_impl(x)

    # TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
    # NOTE(review): this fallback needs the `tests` package to be importable.
    from tests.kernels.quant_utils import per_block_cast_to_fp8 as _pbcf
    return _pbcf(x, *args, **kwargs)


def calc_diff(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return a global difference metric for unit tests.

    DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable
    per-element error, causing ``torch.testing.assert_close`` to fail.
    Instead of checking every element, we compute a cosine-style
    similarity over the whole tensor and report ``1 - sim``, where
    ``sim = 2 * sum(x * y) / sum(x * x + y * y)``.

    When the denominator is zero (both inputs are entirely zero, or empty)
    the inputs are identical, so ``0`` is returned instead of the NaN that
    ``0 / 0`` would otherwise produce.

    Once kernel accuracy improves this helper can be removed.

    Args:
        x: first tensor; converted to float64 before the computation.
        y: second tensor; converted to float64 before the computation.

    Returns:
        A 0-dim float64 tensor holding ``1 - sim``.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    diff = 1 - sim
    # Select elementwise instead of branching in Python so the result stays
    # a tensor and no host-side truth test of `denominator` is needed.
    return torch.where(denominator == 0, torch.zeros_like(diff), diff)


__all__ = [
    "calc_diff",
    "fp8_gemm_nt",
    "m_grouped_fp8_gemm_nt_contiguous",
    "fp8_m_grouped_gemm_nt_masked",
    "per_token_group_cast_to_fp8",
    "per_block_cast_to_fp8",
    "is_blackwell_deep_gemm_used",
]