"vscode:/vscode.git/clone" did not exist on "9f7b4ba86578fbb0b6e80a2b0c1a334d88787a57"
Unverified commit 2800706f, authored by Michael Goin, committed by GitHub
Browse files

[Refactor] Move NVFP4 GEMM management into NvFp4LinearKernel (#39129)


Signed-off-by: mgoin <mgoin64@gmail.com>
parent 0d310ffb
......@@ -55,6 +55,27 @@ from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
XPUW4A8IntLinearKernel,
XPUwNa16LinearKernel,
)
from vllm.model_executor.kernels.linear.nvfp4 import (
NvFp4LinearKernel,
NvFp4LinearLayerConfig,
)
from vllm.model_executor.kernels.linear.nvfp4.cutlass import (
CutlassNvFp4LinearKernel,
)
from vllm.model_executor.kernels.linear.nvfp4.emulation import (
EmulationNvFp4LinearKernel,
)
from vllm.model_executor.kernels.linear.nvfp4.fbgemm import (
FbgemmNvFp4LinearKernel,
)
from vllm.model_executor.kernels.linear.nvfp4.flashinfer import (
FlashInferCudnnNvFp4LinearKernel,
FlashInferCutlassNvFp4LinearKernel,
FlashInferTrtllmNvFp4LinearKernel,
)
from vllm.model_executor.kernels.linear.nvfp4.marlin import (
MarlinNvFp4LinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm import (
Fp8BlockScaledMMLinearKernel,
FP8ScaledMMLinearKernel,
......@@ -180,6 +201,22 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
],
}
# Candidate NVFP4 linear kernels per platform,
# in priority/performance order (when available).
# init_nvfp4_linear_kernel() walks the list for the current platform from
# first to last and returns the first kernel that passes its checks, so the
# position of an entry is significant. register_linear_kernel() appends to
# these lists for kernel_type == "nvfp4".
_POSSIBLE_NVFP4_KERNELS: dict[PlatformEnum, list[type[NvFp4LinearKernel]]] = {
    PlatformEnum.CUDA: [
        FlashInferCutlassNvFp4LinearKernel,
        CutlassNvFp4LinearKernel,
        MarlinNvFp4LinearKernel,
        FlashInferTrtllmNvFp4LinearKernel,
        FlashInferCudnnNvFp4LinearKernel,
        FbgemmNvFp4LinearKernel,
        # Emulation is listed last so it is only reached as a fallback.
        EmulationNvFp4LinearKernel,
    ],
    PlatformEnum.ROCM: [
        EmulationNvFp4LinearKernel,
    ],
}
# TODO make all kernels inherit from MMLinearKernel
# then bound _KernelT only to MMLinearKernel
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel | MMLinearKernel)
......@@ -426,6 +463,88 @@ def choose_mp_linear_kernel(
)
# Maps VLLM_NVFP4_GEMM_BACKEND env var values to kernel classes.
# FbgemmNvFp4LinearKernel has no entry here: init_nvfp4_linear_kernel()
# selects it through the separate VLLM_USE_FBGEMM flag instead.
# NOTE(review): the legacy NvFp4LinearBackend enum also accepted "fbgemm" as
# a VLLM_NVFP4_GEMM_BACKEND value; confirm that rejecting it is intentional.
_NVFP4_BACKEND_TO_KERNEL: dict[str, type[NvFp4LinearKernel]] = {
    "flashinfer-cutlass": FlashInferCutlassNvFp4LinearKernel,
    "cutlass": CutlassNvFp4LinearKernel,
    "marlin": MarlinNvFp4LinearKernel,
    "flashinfer-trtllm": FlashInferTrtllmNvFp4LinearKernel,
    "flashinfer-cudnn": FlashInferCudnnNvFp4LinearKernel,
    "emulation": EmulationNvFp4LinearKernel,
}
def init_nvfp4_linear_kernel() -> NvFp4LinearKernel:
    """Select and instantiate the NVFP4 linear kernel for the current platform.

    Environment overrides are honoured first, in this order:
    ``VLLM_USE_FBGEMM``, ``VLLM_USE_NVFP4_CT_EMULATIONS``, then
    ``VLLM_NVFP4_GEMM_BACKEND`` (resolved through
    ``_NVFP4_BACKEND_TO_KERNEL``). A forced kernel is validated and used
    as-is; there is no fallback if it turns out to be unusable.

    Without an override, the kernels registered in
    ``_POSSIBLE_NVFP4_KERNELS`` for the current platform are tried in list
    order, skipping any whose class name appears in
    ``VLLM_DISABLED_KERNELS``. The first kernel passing both
    ``is_supported()`` and ``can_implement()`` is returned.

    Returns:
        An instance of the selected kernel class.

    Raises:
        ValueError: If ``VLLM_NVFP4_GEMM_BACKEND`` names an unknown backend,
            if a forced kernel is unsupported or cannot implement the
            config, or if no registered kernel is usable on this platform.
    """
    config = NvFp4LinearLayerConfig()

    # NOTE(review): the previous selection helper
    # (select_nvfp4_linear_backend) forced the emulation backend whenever
    # VLLM_BATCH_INVARIANT was set. No such override is applied here;
    # confirm that batch-invariant mode is handled elsewhere.

    # Env-var overrides.
    force_kernel: type[NvFp4LinearKernel] | None = None
    if envs.VLLM_USE_FBGEMM:
        force_kernel = FbgemmNvFp4LinearKernel
    elif envs.VLLM_USE_NVFP4_CT_EMULATIONS:
        force_kernel = EmulationNvFp4LinearKernel
    elif envs.VLLM_NVFP4_GEMM_BACKEND is not None:
        backend_name = envs.VLLM_NVFP4_GEMM_BACKEND
        force_kernel = _NVFP4_BACKEND_TO_KERNEL.get(backend_name)
        if force_kernel is None:
            raise ValueError(
                f"Unknown VLLM_NVFP4_GEMM_BACKEND={backend_name!r}. "
                f"Valid choices: {list(_NVFP4_BACKEND_TO_KERNEL.keys())}"
            )

    if force_kernel is not None:
        # Validate explicitly rather than relying on the asserts in the
        # kernel constructor, which are stripped under ``python -O``.
        is_supported, reason = force_kernel.is_supported()
        if not is_supported:
            raise ValueError(
                f"Forced NVFP4 kernel {force_kernel.__name__} is not "
                f"supported: {reason}"
            )
        can_implement, reason = force_kernel.can_implement(config)
        if not can_implement:
            raise ValueError(
                f"Forced NVFP4 kernel {force_kernel.__name__} cannot "
                f"implement the layer config: {reason}"
            )
        logger.info_once("Using %s for NVFP4 GEMM", force_kernel.__name__)
        return force_kernel(config)

    # Auto-select from registry.
    platform = current_platform._enum
    possible = _POSSIBLE_NVFP4_KERNELS.get(platform, [])
    if not possible:
        # Without this, the final error would end in an empty "Reasons:".
        raise ValueError(
            "Failed to find a kernel that can implement the "
            "NVFP4 linear layer: no NVFP4 kernels are registered for "
            f"platform {platform}."
        )

    failure_reasons: list[str] = []
    for kernel_cls in possible:
        if kernel_cls.__name__ in envs.VLLM_DISABLED_KERNELS:
            failure_reasons.append(
                f"{kernel_cls.__name__} disabled by environment variable"
            )
            continue

        is_supported, reason = kernel_cls.is_supported()
        if not is_supported:
            failure_reasons.append(f"{kernel_cls.__name__}: {reason}")
            continue

        can_implement, reason = kernel_cls.can_implement(config)
        if not can_implement:
            failure_reasons.append(f"{kernel_cls.__name__}: {reason}")
            continue

        # Only warn when emulation is reached because earlier candidates
        # were rejected, not when it is the sole registered kernel.
        if kernel_cls is EmulationNvFp4LinearKernel and failure_reasons:
            logger.warning_once(
                "NVFP4 linear falling back to the slow and unoptimized "
                "emulation backend as no optimized backend is available "
                "(unavailable reasons:\n - %s\n). "
                "In case you expect one of these backends to be used, "
                "please verify your environment.",
                "\n - ".join(failure_reasons),
            )

        logger.info_once("Using %s for NVFP4 GEMM", kernel_cls.__name__)
        return kernel_cls(config)

    raise ValueError(
        "Failed to find a kernel that can implement the "
        "NVFP4 linear layer. Reasons: \n" + "\n".join(failure_reasons)
    )
def register_linear_kernel(
kernel_class: type,
platform: PlatformEnum,
......@@ -455,6 +574,10 @@ def register_linear_kernel(
if platform not in _POSSIBLE_FP8_KERNELS:
_POSSIBLE_FP8_KERNELS[platform] = []
_POSSIBLE_FP8_KERNELS[platform].append(kernel_class)
elif kernel_type == "nvfp4":
if platform not in _POSSIBLE_NVFP4_KERNELS:
_POSSIBLE_NVFP4_KERNELS[platform] = []
_POSSIBLE_NVFP4_KERNELS[platform].append(kernel_class)
else:
raise ValueError(f"Unrecognized kernel type: {kernel_type}")
......@@ -462,6 +585,7 @@ def register_linear_kernel(
__all__ = [
"init_fp8_linear_kernel",
"init_int8_linear_kernel",
"init_nvfp4_linear_kernel",
"choose_mp_linear_kernel",
"register_linear_kernel",
"FP8ScaledMMLinearKernel",
......@@ -470,6 +594,8 @@ __all__ = [
"FP8ScaledMMLinearLayerConfig",
"Int8ScaledMMLinearLayerConfig",
"ScaledMMLinearLayerConfig",
"NvFp4LinearKernel",
"NvFp4LinearLayerConfig",
"AiterInt8ScaledMMLinearKernel",
"CPUInt8ScaledMMLinearKernel",
"CutlassFP8ScaledMMLinearKernel",
......@@ -492,6 +618,13 @@ __all__ = [
"MarlinLinearKernel",
"XPUW4A8IntLinearKernel",
"XPUwNa16LinearKernel",
"CutlassNvFp4LinearKernel",
"EmulationNvFp4LinearKernel",
"FbgemmNvFp4LinearKernel",
"FlashInferCutlassNvFp4LinearKernel",
"FlashInferTrtllmNvFp4LinearKernel",
"FlashInferCudnnNvFp4LinearKernel",
"MarlinNvFp4LinearKernel",
"_KernelT",
"DeepGemmFp8BlockScaledMMKernel",
"FlashInferFp8DeepGEMMDynamicBlockScaledKernel",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.model_executor.kernels.linear.nvfp4.base import (
NvFp4LinearKernel,
NvFp4LinearLayerConfig,
)
# Names re-exported from ``.base`` as this package's public interface.
__all__ = [
    "NvFp4LinearKernel",
    "NvFp4LinearLayerConfig",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from dataclasses import dataclass
import torch
@dataclass
class NvFp4LinearLayerConfig:
    """Field-less configuration object for an NVFP4 linear layer.

    No fields are needed because every NVFP4 layer has the same layout:
    packed uint8 weights (2 FP4 values per byte), FP8-E4M3 per-block weight
    scales (group size 16), and scalar global scales for both weights and
    activations.
    """
class NvFp4LinearKernel(ABC):
    """Base class for NVFP4 quantized linear kernels.

    Each subclass implements a specific GEMM backend (CUTLASS, Marlin, etc).
    The kernel selection mechanism iterates over registered subclasses in
    priority order, calling ``is_supported`` and ``can_implement`` to find
    the best match for the current hardware.
    """

    def __init__(self, config: NvFp4LinearLayerConfig) -> None:
        """Validate *config* against this kernel and store it.

        Args:
            config: Layer configuration the kernel is asked to serve.

        Raises:
            ValueError: If ``can_implement(config)`` or ``is_supported()``
                reports failure; the message carries the reported reason.
        """
        # Explicit checks instead of ``assert`` so that validation still
        # runs under ``python -O`` and the failure reason is surfaced.
        can_implement, reason = self.can_implement(config)
        if not can_implement:
            raise ValueError(
                f"{type(self).__name__} cannot implement the given "
                f"config: {reason}"
            )
        is_supported, reason = self.is_supported()
        if not is_supported:
            raise ValueError(
                f"{type(self).__name__} is not supported: {reason}"
            )
        self.config = config

    @classmethod
    @abstractmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Return whether this kernel can run on the current platform.

        Returns:
            ``(True, None)`` when usable, otherwise ``(False, reason)``.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
        """Return whether this kernel can handle *config*.

        Returns:
            ``(True, None)`` when it can, otherwise ``(False, reason)``.
        """
        raise NotImplementedError

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Transform weights into the format required by this kernel.

        Called once after checkpoint weights have been loaded onto the
        device. Implementations should repack / swizzle / pad weights
        and scales in-place on *layer*.
        """
        raise NotImplementedError

    @abstractmethod
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the quantized GEMM."""
        raise NotImplementedError
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm._custom_ops import (
cutlass_scaled_fp4_mm,
scaled_fp4_quant,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
cutlass_fp4_supported,
pad_nvfp4_activation_for_cutlass,
pad_nvfp4_weight_for_cutlass,
slice_nvfp4_output,
swizzle_blockscale,
)
from .base import NvFp4LinearKernel, NvFp4LinearLayerConfig
class CutlassNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 linear kernel that dispatches to ``cutlass_scaled_fp4_mm``."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report availability as decided by ``cutlass_fp4_supported()``.

        ``compute_capability`` is part of the shared signature and is not
        consulted here.
        """
        if not cutlass_fp4_supported():
            return False, "CUTLASS FP4 kernels not available"
        return True, None

    @classmethod
    def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
        """Accept every config unconditionally."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Swizzle the block scales and pad the weight in place on *layer*.

        Also records the padding reported by
        ``pad_nvfp4_weight_for_cutlass`` as ``layer.weights_padding_cols``
        so ``apply_weights`` can pad activations to match.
        """
        swizzled_scale = swizzle_blockscale(layer.weight_scale.data)
        layer.weight_scale = torch.nn.Parameter(swizzled_scale, requires_grad=False)

        weight, pad_cols = pad_nvfp4_weight_for_cutlass(layer.weight.data)
        layer.weight = torch.nn.Parameter(weight, requires_grad=False)
        layer.weights_padding_cols = pad_cols

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize *x*, run the CUTLASS FP4 GEMM, and add *bias* if given.

        The result has the leading dimensions of *x* and
        ``layer.output_size_per_partition`` as its last dimension.
        """
        n_out = layer.output_size_per_partition

        act_fp4, act_scale = scaled_fp4_quant(
            x,
            layer.input_global_scale_inv,
            is_sf_swizzled_layout=True,
            backend="cutlass",
        )
        # Layers without the attribute are treated as having no padding.
        pad_cols = getattr(layer, "weights_padding_cols", 0)
        act_fp4 = pad_nvfp4_activation_for_cutlass(act_fp4, pad_cols)

        result = cutlass_scaled_fp4_mm(
            act_fp4,
            layer.weight,
            act_scale,
            layer.weight_scale,
            layer.alpha,
            x.dtype,
        )
        result = slice_nvfp4_output(result, n_out)
        if bias is not None:
            result = result + bias
        return result.view(*x.shape[:-1], n_out)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
kE2M1ToFloat_handle,
run_nvfp4_emulations,
)
from .base import NvFp4LinearKernel, NvFp4LinearLayerConfig
class EmulationNvFp4LinearKernel(NvFp4LinearKernel):
    """Software emulation fallback for NVFP4 (dequant → BF16 matmul)."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Always report support; *compute_capability* is not consulted."""
        # Always available as a last-resort fallback.
        return True, None

    @classmethod
    def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
        """Accept every config unconditionally."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Place the shared E2M1 lookup table on the layer's weight device.

        Weights and scales on *layer* are left untouched.
        """
        # Move the E2M1 lookup table to the device now, because
        # `.to(device)` is not allowed during CUDA graph capture.
        kE2M1ToFloat_handle.val = kE2M1ToFloat_handle.val.to(layer.weight.device)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the emulated NVFP4 linear and add *bias* if given.

        *x* is flattened to 2-D before the emulation call and the result is
        reshaped to the leading dimensions of *x* with
        ``layer.output_size_per_partition`` as its last dimension. This
        matches the other NVFP4 kernels and the legacy emulation path.
        """
        n_out = layer.output_size_per_partition

        # Collapse any leading batch dimensions; a 2-D input is unchanged.
        x_2d = x.reshape(-1, x.shape[-1])
        out = run_nvfp4_emulations(
            x=x_2d,
            input_global_scale=layer.input_global_scale_inv,
            weight=layer.weight,
            weight_scale_swizzled=layer.weight_scale,
            weight_global_scale=layer.weight_global_scale,
            swizzle=False,
        )
        # Drop any extra output columns, as the sibling kernels do.
        out = out[:, :n_out]
        if bias is not None:
            out = out + bias
        return out.reshape(*x.shape[:-1], n_out)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm._custom_ops import scaled_fp4_quant
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
slice_nvfp4_output,
swizzle_blockscale,
)
from vllm.utils.import_utils import has_fbgemm_gpu
from .base import NvFp4LinearKernel, NvFp4LinearLayerConfig
class FbgemmNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 linear kernel that dispatches to ``torch.ops.fbgemm.f4f4bf16``."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report availability as decided by ``has_fbgemm_gpu()``.

        ``compute_capability`` is part of the shared signature and is not
        consulted here.
        """
        if not has_fbgemm_gpu():
            return False, "fbgemm_gpu required"
        return True, None

    @classmethod
    def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
        """Accept every config unconditionally."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Replace ``layer.weight_scale`` with its swizzled, flat uint8 form.

        ``layer.weight`` is left as loaded.
        """
        flat_scale = swizzle_blockscale(layer.weight_scale.data).view(-1)
        layer.weight_scale = torch.nn.Parameter(
            flat_scale.view(torch.uint8), requires_grad=False
        )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize *x*, run the FBGEMM FP4 GEMM, and add *bias* if given.

        The op result is cast to the dtype of *x*; the returned tensor has
        the leading dimensions of *x* and
        ``layer.output_size_per_partition`` as its last dimension.
        """
        import fbgemm_gpu  # noqa: F401 - registers torch.ops.fbgemm.*

        n_out = layer.output_size_per_partition

        act_fp4, act_scale = scaled_fp4_quant(
            x,
            layer.input_global_scale_inv,
            is_sf_swizzled_layout=True,
            backend="fbgemm",
        )
        # Activation scales get the same flat uint8 view as weight_scale.
        act_scale_u8 = act_scale.view(-1).view(torch.uint8)

        result = torch.ops.fbgemm.f4f4bf16(
            act_fp4,
            layer.weight,
            act_scale_u8,
            layer.weight_scale,
            layer.alpha,
            use_mx=False,
        ).to(x.dtype)
        result = slice_nvfp4_output(result, n_out)
        if bias is not None:
            result = result + bias
        return result.view(*x.shape[:-1], n_out)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm._custom_ops import scaled_fp4_quant
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
pad_nvfp4_activation_for_cutlass,
pad_nvfp4_weight_for_cutlass,
slice_nvfp4_output,
swizzle_blockscale,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer
from .base import NvFp4LinearKernel, NvFp4LinearLayerConfig
class FlashInferCutlassNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 GEMM via FlashInfer's CUTLASS wrapper."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Check the three requirements of this backend.

        The checks run in order (vLLM CUTLASS FP4 support, device capability
        >= 100, FlashInfer availability) and stop at the first failure, so
        the returned reason names the requirement that is actually missing.
        ``compute_capability`` is not consulted.
        """
        from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
            cutlass_fp4_supported,
        )

        if not cutlass_fp4_supported():
            return False, "vLLM CUTLASS FP4 kernels not available"
        if not current_platform.has_device_capability(100):
            return False, ">=sm_100 required"
        if not has_flashinfer():
            return False, "FlashInfer required"
        return True, None

    @classmethod
    def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
        """Accept every config unconditionally."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Swizzle the block scales and pad the weight in place on *layer*.

        Also records the padding reported by
        ``pad_nvfp4_weight_for_cutlass`` as ``layer.weights_padding_cols``
        so ``apply_weights`` can pad activations to match.
        """
        layer.weight_scale = torch.nn.Parameter(
            swizzle_blockscale(layer.weight_scale.data), requires_grad=False
        )
        padded_weight, weights_padding_cols = pad_nvfp4_weight_for_cutlass(
            layer.weight.data
        )
        layer.weight = torch.nn.Parameter(padded_weight, requires_grad=False)
        layer.weights_padding_cols = weights_padding_cols

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize *x*, run the FlashInfer CUTLASS GEMM, add *bias* if given.

        The result has the leading dimensions of *x* and
        ``layer.output_size_per_partition`` as its last dimension.
        """
        output_size = layer.output_size_per_partition
        output_dtype = x.dtype
        output_shape = [*x.shape[:-1], output_size]

        x_fp4, x_blockscale = scaled_fp4_quant(
            x,
            layer.input_global_scale_inv,
            is_sf_swizzled_layout=True,
            backend="flashinfer-cutlass",
        )
        # Layers without the attribute are treated as having no padding.
        x_fp4 = pad_nvfp4_activation_for_cutlass(
            x_fp4, getattr(layer, "weights_padding_cols", 0)
        )
        out = flashinfer_scaled_fp4_mm(
            x_fp4,
            layer.weight,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            output_dtype,
            backend="cutlass",
        )
        out = slice_nvfp4_output(out, output_size)
        if bias is not None:
            out = out + bias
        return out.view(*output_shape)
class FlashInferTrtllmNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 linear kernel using ``flashinfer_scaled_fp4_mm`` (``trtllm``)."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report availability as decided by ``has_flashinfer()``.

        ``compute_capability`` is part of the shared signature and is not
        consulted here.
        """
        if not has_flashinfer():
            return False, "FlashInfer required"
        return True, None

    @classmethod
    def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
        """Accept every config unconditionally."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Shuffle the weight and its block scales with FlashInfer helpers.

        Both tensors are viewed as uint8 for shuffling; the scales are then
        restored to their original shape and viewed as ``float8_e4m3fn``.
        """
        from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

        tile_m = 128  # epilogue tile size handed to both shuffle helpers
        raw_scale = layer.weight_scale.data

        shuffled_weight = shuffle_matrix_a(layer.weight.data.view(torch.uint8), tile_m)
        shuffled_scale = shuffle_matrix_sf_a(raw_scale.view(torch.uint8), tile_m)
        shuffled_scale = shuffled_scale.reshape(raw_scale.shape).view(
            torch.float8_e4m3fn
        )

        layer.weight = torch.nn.Parameter(shuffled_weight, requires_grad=False)
        layer.weight_scale = torch.nn.Parameter(shuffled_scale, requires_grad=False)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize *x*, run the FlashInfer TRT-LLM GEMM, add *bias* if given.

        The result has the leading dimensions of *x* and
        ``layer.output_size_per_partition`` as its last dimension.
        """
        n_out = layer.output_size_per_partition

        act_fp4, act_scale = scaled_fp4_quant(
            x,
            layer.input_global_scale_inv,
            is_sf_swizzled_layout=True,
            backend="flashinfer-trtllm",
        )
        result = flashinfer_scaled_fp4_mm(
            act_fp4,
            layer.weight,
            act_scale,
            layer.weight_scale,
            layer.alpha,
            x.dtype,
            backend="trtllm",
        )
        result = slice_nvfp4_output(result, n_out)
        if bias is not None:
            result = result + bias
        return result.view(*x.shape[:-1], n_out)
class FlashInferCudnnNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 linear kernel using ``flashinfer_scaled_fp4_mm`` (``cudnn``)."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report availability as decided by ``has_flashinfer()``.

        ``compute_capability`` is part of the shared signature and is not
        consulted here.
        """
        if not has_flashinfer():
            return False, "FlashInfer required"
        return True, None

    @classmethod
    def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
        """Accept every config unconditionally."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Swizzle the block scales and pad the weight in place on *layer*.

        Also records the padding reported by
        ``pad_nvfp4_weight_for_cutlass`` as ``layer.weights_padding_cols``.
        """
        # cuDNN uses the same swizzled + padded layout as CUTLASS
        swizzled_scale = swizzle_blockscale(layer.weight_scale.data)
        layer.weight_scale = torch.nn.Parameter(swizzled_scale, requires_grad=False)

        weight, pad_cols = pad_nvfp4_weight_for_cutlass(layer.weight.data)
        layer.weight = torch.nn.Parameter(weight, requires_grad=False)
        layer.weights_padding_cols = pad_cols

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize *x*, run the FlashInfer cuDNN GEMM, add *bias* if given.

        The result has the leading dimensions of *x* and
        ``layer.output_size_per_partition`` as its last dimension.
        """
        n_out = layer.output_size_per_partition

        act_fp4, act_scale = scaled_fp4_quant(
            x,
            layer.input_global_scale_inv,
            is_sf_swizzled_layout=True,
            backend="flashinfer-cudnn",
        )
        # Layers without the attribute are treated as having no padding.
        pad_cols = getattr(layer, "weights_padding_cols", 0)
        act_fp4 = pad_nvfp4_activation_for_cutlass(act_fp4, pad_cols)

        result = flashinfer_scaled_fp4_mm(
            act_fp4,
            layer.weight,
            act_scale,
            layer.weight_scale,
            layer.alpha,
            x.dtype,
            backend="cudnn",
        )
        result = slice_nvfp4_output(result, n_out)
        if bias is not None:
            result = result + bias
        return result.view(*x.shape[:-1], n_out)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear,
is_fp4_marlin_supported,
prepare_fp4_layer_for_marlin,
)
from .base import NvFp4LinearKernel, NvFp4LinearLayerConfig
logger = init_logger(__name__)
class MarlinNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 weight-only (W4A16) linear kernel built on the Marlin helpers."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report availability as decided by ``is_fp4_marlin_supported()``.

        ``compute_capability`` is part of the shared signature and is not
        consulted here.
        """
        if not is_fp4_marlin_supported():
            return False, "Marlin FP4 not available"
        return True, None

    @classmethod
    def can_implement(cls, config: NvFp4LinearLayerConfig) -> tuple[bool, str | None]:
        """Accept every config unconditionally."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Warn once about weight-only execution, then prepare *layer*.

        The actual conversion is delegated to
        ``prepare_fp4_layer_for_marlin``.
        """
        logger.warning_once(
            "Your GPU does not have native support for FP4 computation but "
            "FP4 quantization is being used. Weight-only FP4 compression "
            "will be used leveraging the Marlin kernel. This may degrade "
            "performance for compute-heavy workloads."
        )
        prepare_fp4_layer_for_marlin(layer)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the Marlin FP4 linear on *x*.

        *bias* is forwarded to ``apply_fp4_marlin_linear`` rather than
        added here.
        """
        marlin_kwargs = dict(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            weight_global_scale=layer.weight_global_scale,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )
        return apply_fp4_marlin_linear(**marlin_kwargs)
......@@ -6,15 +6,10 @@ import torch
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import init_nvfp4_linear_kernel
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
NvFp4LinearBackend,
apply_nvfp4_linear,
convert_to_nvfp4_linear_kernel_format,
select_nvfp4_linear_backend,
)
from vllm.model_executor.parameter import (
GroupQuantScaleParameter,
ModelWeightParameter,
......@@ -29,13 +24,9 @@ __all__ = ["CompressedTensorsW4A4Fp4"]
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
def __init__(self):
self.backend = select_nvfp4_linear_backend()
self.kernel = init_nvfp4_linear_kernel()
self.group_size = 16
self.swizzle = None
if self.backend == NvFp4LinearBackend.EMULATION:
self.swizzle = False
@classmethod
def get_min_capability(cls) -> int:
return 75
......@@ -130,7 +121,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
)
# Convert layer to NVFP4 linear kernel format
convert_to_nvfp4_linear_kernel_format(self.backend, layer)
self.kernel.process_weights_after_loading(layer)
def apply_weights(
self,
......@@ -138,10 +129,4 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return apply_nvfp4_linear(
backend=self.backend,
layer=layer,
x=x,
bias=bias,
swizzle=self.swizzle,
)
return self.kernel.apply_weights(layer=layer, x=x, bias=bias)
......@@ -10,7 +10,10 @@ from torch.nn.parameter import Parameter
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import init_fp8_linear_kernel
from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel,
init_nvfp4_linear_kernel,
)
from vllm.model_executor.layers.attention import Attention, MLAAttention
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
......@@ -70,12 +73,6 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
Mxfp8LinearOp,
mxfp8_e4m3_quantize,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
NvFp4LinearBackend,
apply_nvfp4_linear,
convert_to_nvfp4_linear_kernel_format,
select_nvfp4_linear_backend,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
create_fp8_quant_key,
......@@ -1090,11 +1087,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
self.quant_config = quant_config
self.marlin_input_dtype = None
self.backend = select_nvfp4_linear_backend()
self.swizzle = None
if self.backend == NvFp4LinearBackend.EMULATION:
self.swizzle = False
self.kernel = init_nvfp4_linear_kernel()
def create_weights(
self,
......@@ -1201,7 +1194,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
)
# Convert layer to NVFP4 linear kernel format
convert_to_nvfp4_linear_kernel_format(self.backend, layer)
self.kernel.process_weights_after_loading(layer)
def apply(
self,
......@@ -1209,13 +1202,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return apply_nvfp4_linear(
backend=self.backend,
layer=layer,
x=x,
bias=bias,
swizzle=self.swizzle,
)
return self.kernel.apply_weights(layer=layer, x=x, bias=bias)
class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
import torch
import vllm.envs as envs
from vllm._custom_ops import (
cutlass_scaled_fp4_mm,
cutlass_scaled_mm_supports_fp4,
scaled_fp4_quant,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear,
is_fp4_marlin_supported,
prepare_fp4_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
kE2M1ToFloat_handle,
run_nvfp4_emulations,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer
from vllm.utils.import_utils import has_fbgemm_gpu
from vllm.utils.math_utils import round_up
logger = init_logger(__name__)
# NOTE: This is ordered by preferred backend.
# Example: if both are available, FLASHINFER_CUTLASS is preferred to VLLM_CUTLASS.
class NvFp4LinearBackend(Enum):
    """NVFP4 GEMM backends, declared from most to least preferred.

    Each member's value is the string accepted in
    ``VLLM_NVFP4_GEMM_BACKEND`` and passed as ``backend=`` to
    ``scaled_fp4_quant``.
    """

    FLASHINFER_CUTLASS = "flashinfer-cutlass"
    VLLM_CUTLASS = "cutlass"
    MARLIN = "marlin"
    FLASHINFER_TRTLLM = "flashinfer-trtllm"
    FLASHINFER_CUDNN = "flashinfer-cudnn"
    FBGEMM = "fbgemm"
    EMULATION = "emulation"


# All backends in declaration (preference) order; auto-selection walks this
# list and takes the first supported entry.
NVFP4_LINEAR_BACKENDS = list(NvFp4LinearBackend)
def is_backend_supported(backend: NvFp4LinearBackend) -> tuple[bool, str | None]:
    """Check whether *backend* can be used in the current environment.

    Returns:
        ``(supported, reason)``; ``reason`` is ``None`` when the backend is
        supported. The emulation backend is always reported as supported,
        and a one-time warning is logged if any other backend is not.
    """
    if backend == NvFp4LinearBackend.FLASHINFER_CUTLASS:
        # cutlass_fp4_supported() checks that the vLLM NVFP4 kernels (both
        # quantization and GEMM) were compiled for the current SM version.
        # FlashInfer backends still rely on the vLLM quantization kernels,
        # so we gate them on the same check.
        ok = (
            cutlass_fp4_supported()
            and current_platform.has_device_capability(100)
            and has_flashinfer()
        )
        why = "FlashInfer is required, >=sm_100 is required"
        return ok, (None if ok else why)

    if backend == NvFp4LinearBackend.VLLM_CUTLASS:
        ok = cutlass_fp4_supported()
        return ok, (None if ok else "Cutlass is required")

    if backend == NvFp4LinearBackend.MARLIN:
        ok = is_fp4_marlin_supported()
        return ok, (None if ok else "Marlin is required")

    if backend in [
        NvFp4LinearBackend.FLASHINFER_TRTLLM,
        NvFp4LinearBackend.FLASHINFER_CUDNN,
    ]:
        ok = has_flashinfer()
        return ok, (None if ok else "FlashInfer is required")

    if backend == NvFp4LinearBackend.FBGEMM:
        ok = has_fbgemm_gpu()
        return ok, (None if ok else "fbgemm_gpu is required")

    if backend == NvFp4LinearBackend.EMULATION:
        # e.g. AMD Instinct does not support native NVFP4.
        # Probe every other backend to explain why emulation is in use.
        unavailable = []
        for candidate in NVFP4_LINEAR_BACKENDS:
            if candidate == NvFp4LinearBackend.EMULATION:
                continue
            candidate_ok, candidate_why = is_backend_supported(candidate)
            if not candidate_ok:
                unavailable.append(f"{candidate.value}: {candidate_why}")

        if unavailable:
            unsupported_reasons_str = "\n - ".join(unavailable)
            logger.warning_once(
                f"NVFP4 linear falling back to the slow and unoptimized "
                f"backend=NvFp4LinearBackend.EMULATION as no optimized backend is "
                f"available (unavailable reasons:\n - {unsupported_reasons_str}\n). "
                "In case you expect one of these backend to be used, "
                "please verify your environment."
            )

    # Emulation, and any value not matched above, counts as supported.
    return True, None
def select_nvfp4_linear_backend() -> NvFp4LinearBackend:
    """Select the NVFP4 GEMM backend from environment and platform.

    Precedence, first match wins: ``VLLM_BATCH_INVARIANT`` (emulation,
    returned immediately), ``VLLM_USE_FBGEMM``,
    ``VLLM_USE_NVFP4_CT_EMULATIONS``, an explicit
    ``VLLM_NVFP4_GEMM_BACKEND``, and otherwise the first entry of
    ``NVFP4_LINEAR_BACKENDS`` that ``is_backend_supported`` accepts.

    Returns:
        The selected backend.

    Raises:
        ImportError: If ``VLLM_USE_FBGEMM`` is set but ``fbgemm_gpu`` cannot
            be imported.
        ValueError: If no backend could be selected, if the selected backend
            is not supported here, or if ``VLLM_NVFP4_GEMM_BACKEND`` is not
            a valid ``NvFp4LinearBackend`` value.
    """
    if envs.VLLM_BATCH_INVARIANT:
        logger.info_once(
            "VLLM_BATCH_INVARIANT forces NVFP4 linear to use the emulation "
            "backend for deterministic execution."
        )
        return NvFp4LinearBackend.EMULATION

    selected_backend: NvFp4LinearBackend | None = None

    if envs.VLLM_USE_FBGEMM:
        try:
            import fbgemm_gpu  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "Backend fbgemm requires fbgemm.f4f4bf16 operator, "
                "Please install with: pip install fbgemm-gpu-genai"
            ) from exc
        selected_backend = NvFp4LinearBackend.FBGEMM
    elif envs.VLLM_USE_NVFP4_CT_EMULATIONS:
        selected_backend = NvFp4LinearBackend.EMULATION
    elif envs.VLLM_NVFP4_GEMM_BACKEND is None:
        # Auto-select: first supported backend in preference order.
        for backend in NVFP4_LINEAR_BACKENDS:
            supported, reason = is_backend_supported(backend)
            if supported:
                selected_backend = backend
                break
    else:
        selected_backend = NvFp4LinearBackend(envs.VLLM_NVFP4_GEMM_BACKEND)

    if selected_backend is None:
        raise ValueError(
            f"No NVFP4 GEMM backend selected, "
            f"available backends: {NVFP4_LINEAR_BACKENDS}"
        )

    # Forced backends were not checked above, so validate the final choice.
    supported, reason = is_backend_supported(selected_backend)
    if not supported:
        raise ValueError(
            f"The selected backend={selected_backend} is not supported in current "
            f"environment. Reason: {reason}. Current environment: "
            f"{envs.VLLM_USE_FBGEMM=}, {envs.VLLM_USE_NVFP4_CT_EMULATIONS=}, "
            f"{envs.VLLM_NVFP4_GEMM_BACKEND=}."
        )

    logger.info_once(f"Using {selected_backend} for NVFP4 GEMM")
    return selected_backend
def prepare_weights_for_nvfp4_flashinfer_trtllm(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Shuffle weights and block scales for the FlashInfer TRTLLM FP4 GEMM.

    Both inputs are viewed as uint8 for shuffling. The shuffled scales are
    reshaped back to ``weight_scale.shape`` and viewed as ``float8_e4m3fn``.

    Returns:
        ``(shuffled_weight, shuffled_weight_scale)``.
    """
    from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

    tile_m = 128  # epilogue tile size handed to both shuffle helpers

    weight_out = shuffle_matrix_a(weight.view(torch.uint8), tile_m)

    scale_out = shuffle_matrix_sf_a(weight_scale.view(torch.uint8), tile_m)
    scale_out = scale_out.reshape(weight_scale.shape).view(torch.float8_e4m3fn)

    return weight_out, scale_out
def prepare_weights_for_nvfp4_cutlass(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Prepare weights and scales for CUTLASS/FlashInfer-CUTLASS FP4 GEMM.

    The scales go through ``swizzle_blockscale`` and the weight through
    ``pad_nvfp4_weight_for_cutlass`` (padding for alignment: K and N
    divisible by 32).

    Returns:
        ``(padded_weight, swizzled_weight_scale, weights_padding_cols)``.
    """
    scale_out = swizzle_blockscale(weight_scale)
    weight_out, pad_cols = pad_nvfp4_weight_for_cutlass(weight)
    return weight_out, scale_out, pad_cols
def prepare_weights_for_nvfp4_fbgemm(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Prepare weights and scales for the FBGEMM FP4 GEMM.

    Returns:
        *weight* unchanged, and the swizzled scales flattened and viewed
        as uint8.
    """
    flat_scale = swizzle_blockscale(weight_scale).view(-1)
    return weight, flat_scale.view(torch.uint8)
def convert_to_nvfp4_linear_kernel_format(
    backend: NvFp4LinearBackend,
    layer: torch.nn.Module,
) -> None:
    """Rewrite *layer*'s weights and scales in place for *backend*.

    ``layer.weights_padding_cols`` is always set: to the value reported by
    the CUTLASS-style preparation, or to 0 for every other backend.
    """
    assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
        "Weight Block scale must be represented as FP8-E4M3"
    )

    # Default to no padding
    layer.weights_padding_cols = 0

    if backend == NvFp4LinearBackend.MARLIN:
        logger.warning_once(
            "Your GPU does not have native support for FP4 computation but "
            "FP4 quantization is being used. Weight-only FP4 compression "
            "will be used leveraging the Marlin kernel. This may degrade "
            "performance for compute-heavy workloads."
        )
        prepare_fp4_layer_for_marlin(layer)
        return

    if backend == NvFp4LinearBackend.EMULATION:
        # We can not call `.to(device)` during cuda graph capture - do it here instead.
        # (operation not permitted when stream is capturing)
        kE2M1ToFloat_handle.val = kE2M1ToFloat_handle.val.to(layer.weight.device)
        return

    raw_weight = layer.weight.data
    raw_scale = layer.weight_scale.data
    pad_cols = 0

    if backend == NvFp4LinearBackend.FLASHINFER_TRTLLM:
        new_weight, new_scale = prepare_weights_for_nvfp4_flashinfer_trtllm(
            raw_weight, raw_scale
        )
    elif backend == NvFp4LinearBackend.FBGEMM:
        new_weight, new_scale = prepare_weights_for_nvfp4_fbgemm(raw_weight, raw_scale)
    elif backend in (
        NvFp4LinearBackend.VLLM_CUTLASS,
        NvFp4LinearBackend.FLASHINFER_CUTLASS,
        NvFp4LinearBackend.FLASHINFER_CUDNN,
    ):
        new_weight, new_scale, pad_cols = prepare_weights_for_nvfp4_cutlass(
            raw_weight, raw_scale
        )
    else:
        # Unrecognized backend: leave the layer as loaded.
        return

    layer.weight = torch.nn.Parameter(new_weight, requires_grad=False)
    layer.weight_scale = torch.nn.Parameter(new_scale, requires_grad=False)
    layer.weights_padding_cols = pad_cols
def apply_nvfp4_linear(
    backend: NvFp4LinearBackend,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
    swizzle: bool | None = None,
) -> torch.Tensor:
    """Apply the NVFP4 linear transformation of *layer* to *x* via *backend*.

    Args:
        backend: GEMM backend to dispatch to.
        layer: Module holding the prepared weights, scales and sizes.
        x: Input activations.
        bias: Optional bias; forwarded to the Marlin helper, otherwise
            added to the GEMM output.
        swizzle: Forwarded to ``run_nvfp4_emulations``; only used by the
            emulation backend.

    Returns:
        A tensor with the leading dimensions of *x* and
        ``layer.output_size_per_partition`` as its last dimension (the
        Marlin path returns whatever ``apply_fp4_marlin_linear`` returns).
    """
    # Every layer attribute is read up front, whichever backend runs.
    w = layer.weight
    w_scale = layer.weight_scale
    w_global_scale = layer.weight_global_scale
    in_global_scale_inv = layer.input_global_scale_inv
    alpha = layer.alpha
    n_out = layer.output_size_per_partition
    k_in = layer.input_size_per_partition

    out_dtype = x.dtype
    out_shape = [*x.shape[:-1], n_out]

    if backend == NvFp4LinearBackend.MARLIN:
        return apply_fp4_marlin_linear(
            input=x,
            weight=w,
            weight_scale=w_scale,
            weight_global_scale=w_global_scale,
            workspace=layer.workspace,
            size_n=n_out,
            size_k=k_in,
            bias=bias,
        )

    if backend == NvFp4LinearBackend.EMULATION:
        emulated = run_nvfp4_emulations(
            x=x.reshape(-1, x.shape[-1]),
            input_global_scale=in_global_scale_inv,
            weight=w,
            weight_scale_swizzled=w_scale,
            weight_global_scale=w_global_scale,
            swizzle=swizzle,
        )
        emulated = emulated[:, :n_out]
        if bias is not None:
            emulated = emulated + bias
        return emulated.view(*out_shape)

    # Quantize BF16 or FP16 to (FP4 and interleaved block scale)
    x_fp4, x_blockscale = scaled_fp4_quant(
        x, in_global_scale_inv, is_sf_swizzled_layout=True, backend=backend.value
    )

    # Validate dtypes
    assert x_fp4.dtype == torch.uint8
    assert w.dtype == torch.uint8
    assert x_blockscale.dtype == torch.float8_e4m3fn
    # weight_scale is fp8 for most backends, but uint8 for fbgemm
    assert w_scale.dtype in (torch.float8_e4m3fn, torch.uint8)
    assert alpha.dtype == torch.float32

    # Pad activations to match weight K-dimension padding
    pad_cols = getattr(layer, "weights_padding_cols", 0)
    x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, pad_cols)

    # Call the appropriate backend
    backend_id = backend.value
    if backend_id.startswith("flashinfer-"):
        out = flashinfer_scaled_fp4_mm(
            x_fp4,
            w,
            x_blockscale,
            w_scale,
            alpha,
            out_dtype,
            backend=backend_id.removeprefix("flashinfer-"),
        )
    elif backend == NvFp4LinearBackend.FBGEMM:
        out = torch.ops.fbgemm.f4f4bf16(
            x_fp4,
            w,
            x_blockscale.view(-1).view(torch.uint8),
            w_scale,
            alpha,
            use_mx=False,
        ).to(out_dtype)
    else:
        assert backend == NvFp4LinearBackend.VLLM_CUTLASS
        out = cutlass_scaled_fp4_mm(
            x_fp4,
            w,
            x_blockscale,
            w_scale,
            alpha,
            out_dtype,
        )

    # Slice output to remove N-dimension padding
    out = slice_nvfp4_output(out, n_out)

    if bias is not None:
        out = out + bias

    return out.view(*out_shape)
def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment