Unverified Commit 00d7b497 authored by fxmarty-amd's avatar fxmarty-amd Committed by GitHub
Browse files

[NVFP4] Support NVFP4 dense models from `modelopt` and `compressed-tensors` on...


[NVFP4] Support NVFP4 dense models from `modelopt` and `compressed-tensors` on AMD Instinct MI300, MI355X and Hopper through emulation (#35733)
Signed-off-by: default avatarFelix Marty <Felix.Marty@amd.com>
Signed-off-by: default avatarfxmarty-amd <felmarty@amd.com>
Co-authored-by: default avatarKyle Sayers <kylesayrs@gmail.com>
parent 9c81f35b
...@@ -89,22 +89,33 @@ def test_models(example_prompts, model_name) -> None: ...@@ -89,22 +89,33 @@ def test_models(example_prompts, model_name) -> None:
EAGER = [True, False] EAGER = [True, False]
SM_100_NVFP4_BACKENDS = [
"flashinfer-cudnn",
"flashinfer-trtllm",
"flashinfer-cutlass",
]
@pytest.mark.skipif(
not current_platform.has_device_capability(100),
reason="modelopt_fp4 is not supported on this GPU type.",
)
@pytest.mark.parametrize("model", ["nvidia/Llama-3.1-8B-Instruct-NVFP4"]) @pytest.mark.parametrize("model", ["nvidia/Llama-3.1-8B-Instruct-NVFP4"])
@pytest.mark.parametrize("eager", EAGER) @pytest.mark.parametrize("eager", EAGER)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"backend", "backend",
[ [
"emulation",
"flashinfer-cudnn", "flashinfer-cudnn",
"flashinfer-trtllm", # the small seq_len ensures trtllm_8x4_layout backend is used "flashinfer-trtllm", # the small seq_len ensures trtllm_8x4_layout backend is used
"flashinfer-cutlass", "flashinfer-cutlass",
], ],
) )
def test_nvfp4(vllm_runner, model, eager, backend, monkeypatch): def test_nvfp4(vllm_runner, model, eager, backend, monkeypatch):
if (
not current_platform.has_device_capability(100)
and backend in SM_100_NVFP4_BACKENDS
):
pytest.skip(
f"The backend {backend} is not supported with current_platform.has_device_capability(100) == False"
)
monkeypatch.setenv("VLLM_NVFP4_GEMM_BACKEND", backend) monkeypatch.setenv("VLLM_NVFP4_GEMM_BACKEND", backend)
with vllm_runner(model, enforce_eager=eager) as llm: with vllm_runner(model, enforce_eager=eager) as llm:
output = llm.generate_greedy(["1 2 3 4 5"], max_tokens=2) output = llm.generate_greedy(["1 2 3 4 5"], max_tokens=2)
......
...@@ -366,9 +366,6 @@ def test_compressed_tensors_kv_cache_fp8_per_attn_head(vllm_runner): ...@@ -366,9 +366,6 @@ def test_compressed_tensors_kv_cache_fp8_per_attn_head(vllm_runner):
assert output assert output
@pytest.mark.skipif(
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"args", "args",
[ [
...@@ -398,7 +395,7 @@ def test_compressed_tensors_nvfp4(vllm_runner, args): ...@@ -398,7 +395,7 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
assert qkv_proj.scheme.group_size == 16 assert qkv_proj.scheme.group_size == 16
llm.apply_model(check_model) llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=4) output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
print(output) print(output)
assert output assert output
......
...@@ -1464,6 +1464,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1464,6 +1464,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "flashinfer-trtllm": use flashinfer trtllm GEMM backend # - "flashinfer-trtllm": use flashinfer trtllm GEMM backend
# - "flashinfer-cutlass": use flashinfer cutlass GEMM backend # - "flashinfer-cutlass": use flashinfer cutlass GEMM backend
# - "marlin": use marlin GEMM backend (for GPUs without native FP4 support) # - "marlin": use marlin GEMM backend (for GPUs without native FP4 support)
# - "emulation":
# use BF16/FP16 GEMM, dequantizing weights and running QDQ on activations.
# This is only meant for research purposes to run on devices where NVFP4
# GEMM kernels are not available.
# - <none>: automatically pick an available backend # - <none>: automatically pick an available backend
"VLLM_NVFP4_GEMM_BACKEND": env_with_choices( "VLLM_NVFP4_GEMM_BACKEND": env_with_choices(
"VLLM_NVFP4_GEMM_BACKEND", "VLLM_NVFP4_GEMM_BACKEND",
...@@ -1474,6 +1478,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1474,6 +1478,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
"flashinfer-cutlass", "flashinfer-cutlass",
"cutlass", "cutlass",
"marlin", "marlin",
"emulation",
], ],
), ),
# Controls garbage collection during CUDA graph capture. # Controls garbage collection during CUDA graph capture.
......
...@@ -5,10 +5,12 @@ from collections.abc import Callable ...@@ -5,10 +5,12 @@ from collections.abc import Callable
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme, CompressedTensorsScheme,
) )
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
NvFp4LinearBackend,
apply_nvfp4_linear, apply_nvfp4_linear,
convert_to_nvfp4_linear_kernel_format, convert_to_nvfp4_linear_kernel_format,
select_nvfp4_linear_backend, select_nvfp4_linear_backend,
...@@ -19,6 +21,9 @@ from vllm.model_executor.parameter import ( ...@@ -19,6 +21,9 @@ from vllm.model_executor.parameter import (
PerTensorScaleParameter, PerTensorScaleParameter,
) )
logger = init_logger(__name__)
__all__ = ["CompressedTensorsW4A4Fp4"] __all__ = ["CompressedTensorsW4A4Fp4"]
...@@ -27,6 +32,10 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ...@@ -27,6 +32,10 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
self.backend = select_nvfp4_linear_backend() self.backend = select_nvfp4_linear_backend()
self.group_size = 16 self.group_size = 16
self.swizzle = None
if self.backend == NvFp4LinearBackend.EMULATION:
self.swizzle = False
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
return 75 return 75
...@@ -89,6 +98,19 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ...@@ -89,6 +98,19 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
# Rename CT checkpoint names to standardized names # Rename CT checkpoint names to standardized names
layer.weight = layer.weight_packed layer.weight = layer.weight_packed
del layer.weight_packed del layer.weight_packed
if (
torch.unique(layer.input_global_scale).numel() != 1
or torch.unique(layer.weight_global_scale).numel() != 1
):
logger.warning_once(
"In NVFP4 linear, the global scale for input or weight are different"
" for parallel layers (e.g. q_proj, k_proj, v_proj). This "
" will likely result in reduced accuracy. Please verify the model"
" accuracy. Consider using a checkpoint with a shared global NVFP4"
" scale for fused layers."
)
# Process global scales (CT stores as divisors, i.e. 1/scale) # Process global scales (CT stores as divisors, i.e. 1/scale)
input_global_scale_inv = layer.input_global_scale.max().to(torch.float32) input_global_scale_inv = layer.input_global_scale.max().to(torch.float32)
layer.input_global_scale = Parameter( layer.input_global_scale = Parameter(
...@@ -121,4 +143,5 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ...@@ -121,4 +143,5 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
layer=layer, layer=layer,
x=x, x=x,
bias=bias, bias=bias,
swizzle=self.swizzle,
) )
...@@ -71,6 +71,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( ...@@ -71,6 +71,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
mxfp8_e4m3_quantize, mxfp8_e4m3_quantize,
) )
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
NvFp4LinearBackend,
apply_nvfp4_linear, apply_nvfp4_linear,
convert_to_nvfp4_linear_kernel_format, convert_to_nvfp4_linear_kernel_format,
select_nvfp4_linear_backend, select_nvfp4_linear_backend,
...@@ -1074,6 +1075,10 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -1074,6 +1075,10 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
self.marlin_input_dtype = None self.marlin_input_dtype = None
self.backend = select_nvfp4_linear_backend() self.backend = select_nvfp4_linear_backend()
self.swizzle = None
if self.backend == NvFp4LinearBackend.EMULATION:
self.swizzle = False
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -1149,10 +1154,23 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -1149,10 +1154,23 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_scale", weight_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if (
torch.unique(layer.input_scale).numel() != 1
or torch.unique(layer.weight_scale_2).numel() != 1
):
logger.warning_once(
"In NVFP4 linear, the global scale for input or weight are different"
" for parallel layers (e.g. q_proj, k_proj, v_proj). This "
" will likely results in reduce accuracy. Please verify the model"
" accuracy. Consider using a checkpoint with a shared global NVFP4"
" scale for parallel layers."
)
# Rename ModelOpt checkpoint names to standardized names # Rename ModelOpt checkpoint names to standardized names
input_global_scale = layer.input_scale.max().to(torch.float32) input_global_scale = layer.input_scale.max().to(torch.float32)
layer.input_global_scale = Parameter(input_global_scale, requires_grad=False) layer.input_global_scale = Parameter(input_global_scale, requires_grad=False)
del layer.input_scale del layer.input_scale
weight_global_scale = layer.weight_scale_2.max().to(torch.float32) weight_global_scale = layer.weight_scale_2.max().to(torch.float32)
layer.weight_global_scale = Parameter(weight_global_scale, requires_grad=False) layer.weight_global_scale = Parameter(weight_global_scale, requires_grad=False)
del layer.weight_scale_2 del layer.weight_scale_2
...@@ -1179,6 +1197,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -1179,6 +1197,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer=layer, layer=layer,
x=x, x=x,
bias=bias, bias=bias,
swizzle=self.swizzle,
) )
......
...@@ -24,7 +24,7 @@ logger = init_logger(__name__) ...@@ -24,7 +24,7 @@ logger = init_logger(__name__)
def is_fp4_marlin_supported(): def is_fp4_marlin_supported():
return current_platform.has_device_capability(75) return current_platform.is_cuda() and current_platform.has_device_capability(75)
def _nvfp4_compute_scale_factor( def _nvfp4_compute_scale_factor(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from types import SimpleNamespace
import torch import torch
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
...@@ -11,9 +13,10 @@ __all__ = [ ...@@ -11,9 +13,10 @@ __all__ = [
] ]
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT4_E2M1_MAX_RECIPROCAL = 1 / FLOAT4_E2M1_MAX
kE2M1ToFloat = torch.tensor( kE2M1ToFloat_handle = SimpleNamespace(
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 val=torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32)
) )
...@@ -29,8 +32,9 @@ def break_fp4_bytes(a, dtype): ...@@ -29,8 +32,9 @@ def break_fp4_bytes(a, dtype):
# Vectorized sign and magnitude extraction # Vectorized sign and magnitude extraction
signs = (combined & 0x08).to(torch.bool) # Sign bits signs = (combined & 0x08).to(torch.bool) # Sign bits
abs_vals = (combined & 0x07).to(torch.long) abs_vals = (combined & 0x07).to(torch.long)
kE2M1 = kE2M1ToFloat_handle.val
# Device-aware lookup and sign application # Device-aware lookup and sign application
kE2M1 = kE2M1ToFloat.to(device=a.device)
values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
# Reshape to final form # Reshape to final form
return values.reshape(m, n * 2).to(dtype=dtype) return values.reshape(m, n * 2).to(dtype=dtype)
...@@ -47,7 +51,12 @@ def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): ...@@ -47,7 +51,12 @@ def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
def dequantize_to_dtype( def dequantize_to_dtype(
tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16 tensor_fp4: torch.Tensor,
tensor_sf: torch.Tensor,
global_scale: torch.Tensor | float,
dtype: torch.dtype,
block_size: int = 16,
swizzle: bool | None = True,
): ):
"""Dequantize the fp4 tensor back to high precision.""" """Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8. # Two fp4 values are packed into one uint8.
...@@ -57,8 +66,10 @@ def dequantize_to_dtype( ...@@ -57,8 +66,10 @@ def dequantize_to_dtype(
tensor_f32 = break_fp4_bytes(tensor_fp4, torch.float32) tensor_f32 = break_fp4_bytes(tensor_fp4, torch.float32)
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
tensor_sf = tensor_sf.view(torch.float8_e4m3fn) tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
if swizzle:
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale tensor_sf_dtype = tensor_sf.to(torch.float32) * global_scale
# scale the tensor # scale the tensor
out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
...@@ -67,7 +78,8 @@ def dequantize_to_dtype( ...@@ -67,7 +78,8 @@ def dequantize_to_dtype(
def get_reciprocal(x): def get_reciprocal(x):
if isinstance(x, torch.Tensor): if isinstance(x, torch.Tensor):
return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x) # torch.where yields operation not permitted when stream is capturing.
return 1.0 / (x + (x == 0) * 1e8)
elif isinstance(x, (float, int)): elif isinstance(x, (float, int)):
return 0.0 if x == 0 else 1.0 / x return 0.0 if x == 0 else 1.0 / x
else: else:
...@@ -94,7 +106,7 @@ def ref_nvfp4_quant(x, global_scale, block_size): ...@@ -94,7 +106,7 @@ def ref_nvfp4_quant(x, global_scale, block_size):
m, n = x.shape m, n = x.shape
x = torch.reshape(x, (m, n // block_size, block_size)) x = torch.reshape(x, (m, n // block_size, block_size))
vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32) vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32)
scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX)) scale = global_scale * (vec_max * FLOAT4_E2M1_MAX_RECIPROCAL)
scale = torch.clamp(scale, max=448, min=-448) scale = torch.clamp(scale, max=448, min=-448)
scale = scale.to(torch.float8_e4m3fn).to(torch.float32) scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
output_scale = get_reciprocal(scale * get_reciprocal(global_scale)) output_scale = get_reciprocal(scale * get_reciprocal(global_scale))
...@@ -111,6 +123,7 @@ def run_nvfp4_emulations( ...@@ -111,6 +123,7 @@ def run_nvfp4_emulations(
weight: torch.Tensor, weight: torch.Tensor,
weight_scale_swizzled: torch.Tensor, weight_scale_swizzled: torch.Tensor,
weight_global_scale: torch.Tensor, weight_global_scale: torch.Tensor,
swizzle: bool | None = True,
): ):
group_size = 16 group_size = 16
x_m, x_k = x.shape x_m, x_k = x.shape
...@@ -132,8 +145,8 @@ def run_nvfp4_emulations( ...@@ -132,8 +145,8 @@ def run_nvfp4_emulations(
weight_scale_swizzled.data, weight_scale_swizzled.data,
weight_global_scale, weight_global_scale,
output_dtype, output_dtype,
x.device,
group_size, group_size,
swizzle=swizzle,
) )
# matmul # matmul
......
...@@ -17,31 +17,99 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( ...@@ -17,31 +17,99 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_fp4_layer_for_marlin, prepare_fp4_layer_for_marlin,
) )
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
kE2M1ToFloat_handle,
run_nvfp4_emulations, run_nvfp4_emulations,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer
from vllm.utils.import_utils import has_fbgemm_gpu
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import round_up
logger = init_logger(__name__) logger = init_logger(__name__)
# NOTE: This is ordered by preferred backend.
# Example: if both are available, FLASHINFER_CUTLASS is preferred to VLLM_CUTLASS.
class NvFp4LinearBackend(Enum): class NvFp4LinearBackend(Enum):
VLLM_CUTLASS = "cutlass"
FLASHINFER_CUTLASS = "flashinfer-cutlass" FLASHINFER_CUTLASS = "flashinfer-cutlass"
VLLM_CUTLASS = "cutlass"
MARLIN = "marlin"
FLASHINFER_TRTLLM = "flashinfer-trtllm" FLASHINFER_TRTLLM = "flashinfer-trtllm"
FLASHINFER_CUDNN = "flashinfer-cudnn" FLASHINFER_CUDNN = "flashinfer-cudnn"
FBGEMM = "fbgemm" FBGEMM = "fbgemm"
MARLIN = "marlin"
EMULATION = "emulation" EMULATION = "emulation"
NVFP4_LINEAR_BACKENDS = list(NvFp4LinearBackend)
def is_backend_supported(backend: NvFp4LinearBackend) -> tuple[bool, str | None]:
reason = None
supported = True
if backend == NvFp4LinearBackend.FLASHINFER_CUTLASS:
# cutlass_fp4_supported() checks that the vLLM NVFP4 kernels (both
# quantization and GEMM) were compiled for the current SM version.
# FlashInfer backends still rely on the vLLM quantization kernels,
# so we gate them on the same check.
supported = (
cutlass_fp4_supported()
and current_platform.has_device_capability(100)
and has_flashinfer()
)
if not supported:
reason = "FlashInfer is required, >=sm_100 is required"
elif backend == NvFp4LinearBackend.VLLM_CUTLASS:
supported = cutlass_fp4_supported()
if not supported:
reason = "Cutlass is required"
elif backend == NvFp4LinearBackend.MARLIN:
supported = is_fp4_marlin_supported()
if not supported:
reason = "Marlin is required"
elif backend in [
NvFp4LinearBackend.FLASHINFER_TRTLLM,
NvFp4LinearBackend.FLASHINFER_CUDNN,
]:
supported = has_flashinfer()
if not supported:
reason = "FlashInfer is required"
elif backend == NvFp4LinearBackend.FBGEMM:
supported = has_fbgemm_gpu()
if not supported:
reason = "fbgemm_gpu is required"
elif backend == NvFp4LinearBackend.EMULATION:
# e.g. AMD Instinct does not support native NVFP4.
unsupported_reasons = {}
for other_backend in NVFP4_LINEAR_BACKENDS:
if other_backend == NvFp4LinearBackend.EMULATION:
continue
other_supported, other_reason = is_backend_supported(other_backend)
if not other_supported:
unsupported_reasons[other_backend] = other_reason
if unsupported_reasons:
unsupported_reasons_str = "\n - ".join(
[f"{b.value}: {r}" for b, r in unsupported_reasons.items()]
)
logger.warning_once(
f"NVFP4 linear falling back to the slow and unoptimized "
f"backend=NvFp4LinearBackend.EMULATION as no optimized backend is "
f"available (unavailable reasons:\n - {unsupported_reasons_str}\n). "
"In case you expect one of these backend to be used, "
"please verify your environment."
)
return supported, reason
def select_nvfp4_linear_backend() -> NvFp4LinearBackend: def select_nvfp4_linear_backend() -> NvFp4LinearBackend:
""" """
Select the best available NVFP4 GEMM backend based on environment Select the best available NVFP4 GEMM backend based on environment
configuration and platform capabilities. configuration and platform capabilities.
""" """
backend: NvFp4LinearBackend | None = None selected_backend: NvFp4LinearBackend | None = None
if envs.VLLM_USE_FBGEMM: if envs.VLLM_USE_FBGEMM:
try: try:
...@@ -51,51 +119,36 @@ def select_nvfp4_linear_backend() -> NvFp4LinearBackend: ...@@ -51,51 +119,36 @@ def select_nvfp4_linear_backend() -> NvFp4LinearBackend:
"Backend fbgemm requires fbgemm.f4f4bf16 operator, " "Backend fbgemm requires fbgemm.f4f4bf16 operator, "
"Please install with: pip install fbgemm-gpu-genai" "Please install with: pip install fbgemm-gpu-genai"
) from exc ) from exc
backend = NvFp4LinearBackend.FBGEMM selected_backend = NvFp4LinearBackend.FBGEMM
elif envs.VLLM_USE_NVFP4_CT_EMULATIONS: elif envs.VLLM_USE_NVFP4_CT_EMULATIONS:
backend = NvFp4LinearBackend.EMULATION selected_backend = NvFp4LinearBackend.EMULATION
elif envs.VLLM_NVFP4_GEMM_BACKEND is None: elif envs.VLLM_NVFP4_GEMM_BACKEND is None:
# Auto-select best available backend. for backend in NVFP4_LINEAR_BACKENDS:
# cutlass_fp4_supported() checks that the vLLM NVFP4 kernels (both supported, reason = is_backend_supported(backend)
# quantization and GEMM) were compiled for the current SM version. if supported:
# FlashInfer backends still rely on the vLLM quantization kernels, selected_backend = backend
# so we gate them on the same check. break
if (
cutlass_fp4_supported()
and current_platform.has_device_capability(100)
and has_flashinfer()
):
backend = NvFp4LinearBackend.FLASHINFER_CUTLASS
elif cutlass_fp4_supported():
backend = NvFp4LinearBackend.VLLM_CUTLASS
elif is_fp4_marlin_supported():
backend = NvFp4LinearBackend.MARLIN
else: else:
backend = NvFp4LinearBackend(envs.VLLM_NVFP4_GEMM_BACKEND) selected_backend = NvFp4LinearBackend(envs.VLLM_NVFP4_GEMM_BACKEND)
# Validate that the backend is supported if selected_backend is None:
if backend in (
NvFp4LinearBackend.FLASHINFER_CUTLASS,
NvFp4LinearBackend.FLASHINFER_TRTLLM,
NvFp4LinearBackend.FLASHINFER_CUDNN,
):
assert has_flashinfer(), f"FlashInfer is required for {backend}"
assert cutlass_fp4_supported(), (
f"{backend} requires vLLM NVFP4 quantization kernels compiled "
f"for the current GPU (SM {current_platform.get_device_capability()})"
)
elif backend == NvFp4LinearBackend.VLLM_CUTLASS:
assert cutlass_fp4_supported(), f"Cutlass is required for {backend}"
elif backend == NvFp4LinearBackend.MARLIN:
assert is_fp4_marlin_supported(), f"Marlin is required for {backend}"
elif backend is None:
raise ValueError( raise ValueError(
f"No NVFP4 GEMM backend selected, " f"No NVFP4 GEMM backend selected, "
f"available backends: {list(NvFp4LinearBackend)}" f"available backends: {NVFP4_LINEAR_BACKENDS}"
)
supported, reason = is_backend_supported(selected_backend)
if not supported:
raise ValueError(
f"The selected backend={selected_backend} is not supported in current "
f"environment. Reason: {reason}. Current environment: "
f"{envs.VLLM_USE_FBGEMM=}, {envs.VLLM_USE_NVFP4_CT_EMULATIONS=}, "
f"{envs.VLLM_NVFP4_GEMM_BACKEND}."
) )
logger.info_once(f"Using {backend} for NVFP4 GEMM") logger.info_once(f"Using {selected_backend} for NVFP4 GEMM")
return backend return selected_backend
def prepare_weights_for_nvfp4_flashinfer_trtllm( def prepare_weights_for_nvfp4_flashinfer_trtllm(
...@@ -183,6 +236,10 @@ def convert_to_nvfp4_linear_kernel_format( ...@@ -183,6 +236,10 @@ def convert_to_nvfp4_linear_kernel_format(
layer.weight = torch.nn.Parameter(weight, requires_grad=False) layer.weight = torch.nn.Parameter(weight, requires_grad=False)
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
layer.weights_padding_cols = weights_padding_cols layer.weights_padding_cols = weights_padding_cols
elif backend == NvFp4LinearBackend.EMULATION:
# We can not call `.to(device)` during cuda graph capture - do it here instead.
# (operation not permitted when stream is capturing)
kE2M1ToFloat_handle.val = kE2M1ToFloat_handle.val.to(layer.weight.device)
def apply_nvfp4_linear( def apply_nvfp4_linear(
...@@ -190,6 +247,7 @@ def apply_nvfp4_linear( ...@@ -190,6 +247,7 @@ def apply_nvfp4_linear(
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
swizzle: bool | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Apply NVFP4 linear transformation using the specified backend. Apply NVFP4 linear transformation using the specified backend.
...@@ -220,6 +278,7 @@ def apply_nvfp4_linear( ...@@ -220,6 +278,7 @@ def apply_nvfp4_linear(
weight=weight, weight=weight,
weight_scale_swizzled=weight_scale, weight_scale_swizzled=weight_scale,
weight_global_scale=weight_global_scale, weight_global_scale=weight_global_scale,
swizzle=swizzle,
) )
if bias is not None: if bias is not None:
out = out + bias out = out + bias
......
...@@ -409,6 +409,7 @@ class RocmPlatform(Platform): ...@@ -409,6 +409,7 @@ class RocmPlatform(Platform):
"mxfp4", "mxfp4",
"torchao", "torchao",
"bitsandbytes", "bitsandbytes",
"modelopt_fp4",
] ]
@classmethod @classmethod
......
...@@ -461,3 +461,8 @@ def has_aiter() -> bool: ...@@ -461,3 +461,8 @@ def has_aiter() -> bool:
def has_mori() -> bool: def has_mori() -> bool:
"""Whether the optional `mori` package is available.""" """Whether the optional `mori` package is available."""
return _has_module("mori") return _has_module("mori")
def has_fbgemm_gpu() -> bool:
"""Whether the optional `fbgemm_gpu` package is available."""
return _has_module("fbgemm_gpu")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment