Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2800706f
Unverified
Commit
2800706f
authored
Apr 09, 2026
by
Michael Goin
Committed by
GitHub
Apr 09, 2026
Browse files
[Refactor] Move NVFP4 GEMM management into NvFp4LinearKernel (#39129)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
0d310ffb
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
697 additions
and
377 deletions
+697
-377
vllm/model_executor/kernels/linear/__init__.py
vllm/model_executor/kernels/linear/__init__.py
+133
-0
vllm/model_executor/kernels/linear/nvfp4/__init__.py
vllm/model_executor/kernels/linear/nvfp4/__init__.py
+12
-0
vllm/model_executor/kernels/linear/nvfp4/base.py
vllm/model_executor/kernels/linear/nvfp4/base.py
+68
-0
vllm/model_executor/kernels/linear/nvfp4/cutlass.py
vllm/model_executor/kernels/linear/nvfp4/cutlass.py
+80
-0
vllm/model_executor/kernels/linear/nvfp4/emulation.py
vllm/model_executor/kernels/linear/nvfp4/emulation.py
+49
-0
vllm/model_executor/kernels/linear/nvfp4/fbgemm.py
vllm/model_executor/kernels/linear/nvfp4/fbgemm.py
+69
-0
vllm/model_executor/kernels/linear/nvfp4/flashinfer.py
vllm/model_executor/kernels/linear/nvfp4/flashinfer.py
+218
-0
vllm/model_executor/kernels/linear/nvfp4/marlin.py
vllm/model_executor/kernels/linear/nvfp4/marlin.py
+57
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
...mpressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
+4
-19
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+7
-20
vllm/model_executor/layers/quantization/utils/nvfp4_utils.py
vllm/model_executor/layers/quantization/utils/nvfp4_utils.py
+0
-338
No files found.
vllm/model_executor/kernels/linear/__init__.py
View file @
2800706f
...
@@ -55,6 +55,27 @@ from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
...
@@ -55,6 +55,27 @@ from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
XPUW4A8IntLinearKernel
,
XPUW4A8IntLinearKernel
,
XPUwNa16LinearKernel
,
XPUwNa16LinearKernel
,
)
)
from
vllm.model_executor.kernels.linear.nvfp4
import
(
NvFp4LinearKernel
,
NvFp4LinearLayerConfig
,
)
from
vllm.model_executor.kernels.linear.nvfp4.cutlass
import
(
CutlassNvFp4LinearKernel
,
)
from
vllm.model_executor.kernels.linear.nvfp4.emulation
import
(
EmulationNvFp4LinearKernel
,
)
from
vllm.model_executor.kernels.linear.nvfp4.fbgemm
import
(
FbgemmNvFp4LinearKernel
,
)
from
vllm.model_executor.kernels.linear.nvfp4.flashinfer
import
(
FlashInferCudnnNvFp4LinearKernel
,
FlashInferCutlassNvFp4LinearKernel
,
FlashInferTrtllmNvFp4LinearKernel
,
)
from
vllm.model_executor.kernels.linear.nvfp4.marlin
import
(
MarlinNvFp4LinearKernel
,
)
from
vllm.model_executor.kernels.linear.scaled_mm
import
(
from
vllm.model_executor.kernels.linear.scaled_mm
import
(
Fp8BlockScaledMMLinearKernel
,
Fp8BlockScaledMMLinearKernel
,
FP8ScaledMMLinearKernel
,
FP8ScaledMMLinearKernel
,
...
@@ -180,6 +201,22 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
...
@@ -180,6 +201,22 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
],
],
}
}
# NVFP4 linear kernel candidates, keyed by platform. Each list is kept
# in priority/performance order (when available): auto-selection walks it
# front to back and takes the first kernel that is usable.
_POSSIBLE_NVFP4_KERNELS: dict[PlatformEnum, list[type[NvFp4LinearKernel]]] = {
    PlatformEnum.CUDA: [
        FlashInferCutlassNvFp4LinearKernel,
        CutlassNvFp4LinearKernel,
        MarlinNvFp4LinearKernel,
        FlashInferTrtllmNvFp4LinearKernel,
        FlashInferCudnnNvFp4LinearKernel,
        FbgemmNvFp4LinearKernel,
        EmulationNvFp4LinearKernel,
    ],
    PlatformEnum.ROCM: [
        EmulationNvFp4LinearKernel,
    ],
}
# TODO make all kernels inherit from MMLinearKernel
# TODO make all kernels inherit from MMLinearKernel
# then bound _KernelT only to MMLinearKernel
# then bound _KernelT only to MMLinearKernel
_KernelT
=
TypeVar
(
"_KernelT"
,
bound
=
ScaledMMLinearKernel
|
MMLinearKernel
)
_KernelT
=
TypeVar
(
"_KernelT"
,
bound
=
ScaledMMLinearKernel
|
MMLinearKernel
)
...
@@ -426,6 +463,88 @@ def choose_mp_linear_kernel(
...
@@ -426,6 +463,88 @@ def choose_mp_linear_kernel(
)
)
# Lookup table from the string accepted by the VLLM_NVFP4_GEMM_BACKEND
# environment variable to the kernel class it selects. The key order is
# also the order in which valid choices are listed in error messages.
_NVFP4_BACKEND_TO_KERNEL: dict[str, type[NvFp4LinearKernel]] = {
    "flashinfer-cutlass": FlashInferCutlassNvFp4LinearKernel,
    "cutlass": CutlassNvFp4LinearKernel,
    "marlin": MarlinNvFp4LinearKernel,
    "flashinfer-trtllm": FlashInferTrtllmNvFp4LinearKernel,
    "flashinfer-cudnn": FlashInferCudnnNvFp4LinearKernel,
    "emulation": EmulationNvFp4LinearKernel,
}
def _forced_nvfp4_kernel_from_env() -> type[NvFp4LinearKernel] | None:
    """Return the kernel class forced through environment variables, if any.

    Precedence is ``VLLM_USE_FBGEMM``, then ``VLLM_USE_NVFP4_CT_EMULATIONS``,
    then ``VLLM_NVFP4_GEMM_BACKEND``. Returns ``None`` when none of them
    requests a specific kernel.

    Raises:
        ValueError: If ``VLLM_NVFP4_GEMM_BACKEND`` names an unknown backend.
    """
    if envs.VLLM_USE_FBGEMM:
        return FbgemmNvFp4LinearKernel
    if envs.VLLM_USE_NVFP4_CT_EMULATIONS:
        return EmulationNvFp4LinearKernel

    backend_name = envs.VLLM_NVFP4_GEMM_BACKEND
    if backend_name is None:
        return None

    kernel_cls = _NVFP4_BACKEND_TO_KERNEL.get(backend_name)
    if kernel_cls is None:
        raise ValueError(
            f"Unknown VLLM_NVFP4_GEMM_BACKEND={backend_name!r}. "
            f"Valid choices: {list(_NVFP4_BACKEND_TO_KERNEL)}"
        )
    return kernel_cls


def init_nvfp4_linear_kernel() -> NvFp4LinearKernel:
    """Select and instantiate the best NVFP4 linear kernel for the
    current platform.

    A kernel forced through environment variables is used if it is
    supported and can implement the layer config; otherwise the platform's
    entry in ``_POSSIBLE_NVFP4_KERNELS`` is scanned in priority order and
    the first usable kernel is instantiated.

    Returns:
        An instance of the selected kernel class.

    Raises:
        ValueError: If the forced kernel is unknown, unsupported or cannot
            implement the config, or if no registered kernel is usable on
            the current platform.
    """
    config = NvFp4LinearLayerConfig()

    # Env-var overrides.
    force_kernel = _forced_nvfp4_kernel_from_env()
    if force_kernel is not None:
        is_supported, reason = force_kernel.is_supported()
        if not is_supported:
            raise ValueError(
                f"Forced NVFP4 kernel {force_kernel.__name__} is not "
                f"supported: {reason}"
            )
        # Mirror the auto-selection path so that an incompatible config is
        # reported with its reason instead of failing inside the constructor.
        can_implement, reason = force_kernel.can_implement(config)
        if not can_implement:
            raise ValueError(
                f"Forced NVFP4 kernel {force_kernel.__name__} cannot "
                f"implement the NVFP4 linear layer: {reason}"
            )
        logger.info_once("Using %s for NVFP4 GEMM", force_kernel.__name__)
        return force_kernel(config)

    # Auto-select from registry.
    platform = current_platform._enum
    possible = _POSSIBLE_NVFP4_KERNELS.get(platform, [])

    failure_reasons: list[str] = []
    if not possible:
        # Without this the final error would list no reasons at all.
        failure_reasons.append(
            f"no NVFP4 linear kernels are registered for platform {platform}"
        )

    for kernel_cls in possible:
        if kernel_cls.__name__ in envs.VLLM_DISABLED_KERNELS:
            failure_reasons.append(
                f"{kernel_cls.__name__} disabled by environment variable"
            )
            continue

        is_supported, reason = kernel_cls.is_supported()
        if not is_supported:
            failure_reasons.append(f"{kernel_cls.__name__}: {reason}")
            continue

        can_implement, reason = kernel_cls.can_implement(config)
        if not can_implement:
            failure_reasons.append(f"{kernel_cls.__name__}: {reason}")
            continue

        # Only warn when emulation is reached because earlier candidates
        # were rejected, not when it is the sole registered kernel.
        if kernel_cls is EmulationNvFp4LinearKernel and failure_reasons:
            logger.warning_once(
                "NVFP4 linear falling back to the slow and unoptimized "
                "emulation backend as no optimized backend is available "
                "(unavailable reasons:\n - %s\n). "
                "In case you expect one of these backends to be used, "
                "please verify your environment.",
                "\n - ".join(failure_reasons),
            )

        logger.info_once("Using %s for NVFP4 GEMM", kernel_cls.__name__)
        return kernel_cls(config)

    raise ValueError(
        "Failed to find a kernel that can implement the "
        "NVFP4 linear layer. Reasons: \n" + "\n".join(failure_reasons)
    )
def
register_linear_kernel
(
def
register_linear_kernel
(
kernel_class
:
type
,
kernel_class
:
type
,
platform
:
PlatformEnum
,
platform
:
PlatformEnum
,
...
@@ -455,6 +574,10 @@ def register_linear_kernel(
...
@@ -455,6 +574,10 @@ def register_linear_kernel(
if
platform
not
in
_POSSIBLE_FP8_KERNELS
:
if
platform
not
in
_POSSIBLE_FP8_KERNELS
:
_POSSIBLE_FP8_KERNELS
[
platform
]
=
[]
_POSSIBLE_FP8_KERNELS
[
platform
]
=
[]
_POSSIBLE_FP8_KERNELS
[
platform
].
append
(
kernel_class
)
_POSSIBLE_FP8_KERNELS
[
platform
].
append
(
kernel_class
)
elif
kernel_type
==
"nvfp4"
:
if
platform
not
in
_POSSIBLE_NVFP4_KERNELS
:
_POSSIBLE_NVFP4_KERNELS
[
platform
]
=
[]
_POSSIBLE_NVFP4_KERNELS
[
platform
].
append
(
kernel_class
)
else
:
else
:
raise
ValueError
(
f
"Unrecognized kernel type:
{
kernel_type
}
"
)
raise
ValueError
(
f
"Unrecognized kernel type:
{
kernel_type
}
"
)
...
@@ -462,6 +585,7 @@ def register_linear_kernel(
...
@@ -462,6 +585,7 @@ def register_linear_kernel(
__all__
=
[
__all__
=
[
"init_fp8_linear_kernel"
,
"init_fp8_linear_kernel"
,
"init_int8_linear_kernel"
,
"init_int8_linear_kernel"
,
"init_nvfp4_linear_kernel"
,
"choose_mp_linear_kernel"
,
"choose_mp_linear_kernel"
,
"register_linear_kernel"
,
"register_linear_kernel"
,
"FP8ScaledMMLinearKernel"
,
"FP8ScaledMMLinearKernel"
,
...
@@ -470,6 +594,8 @@ __all__ = [
...
@@ -470,6 +594,8 @@ __all__ = [
"FP8ScaledMMLinearLayerConfig"
,
"FP8ScaledMMLinearLayerConfig"
,
"Int8ScaledMMLinearLayerConfig"
,
"Int8ScaledMMLinearLayerConfig"
,
"ScaledMMLinearLayerConfig"
,
"ScaledMMLinearLayerConfig"
,
"NvFp4LinearKernel"
,
"NvFp4LinearLayerConfig"
,
"AiterInt8ScaledMMLinearKernel"
,
"AiterInt8ScaledMMLinearKernel"
,
"CPUInt8ScaledMMLinearKernel"
,
"CPUInt8ScaledMMLinearKernel"
,
"CutlassFP8ScaledMMLinearKernel"
,
"CutlassFP8ScaledMMLinearKernel"
,
...
@@ -492,6 +618,13 @@ __all__ = [
...
@@ -492,6 +618,13 @@ __all__ = [
"MarlinLinearKernel"
,
"MarlinLinearKernel"
,
"XPUW4A8IntLinearKernel"
,
"XPUW4A8IntLinearKernel"
,
"XPUwNa16LinearKernel"
,
"XPUwNa16LinearKernel"
,
"CutlassNvFp4LinearKernel"
,
"EmulationNvFp4LinearKernel"
,
"FbgemmNvFp4LinearKernel"
,
"FlashInferCutlassNvFp4LinearKernel"
,
"FlashInferTrtllmNvFp4LinearKernel"
,
"FlashInferCudnnNvFp4LinearKernel"
,
"MarlinNvFp4LinearKernel"
,
"_KernelT"
,
"_KernelT"
,
"DeepGemmFp8BlockScaledMMKernel"
,
"DeepGemmFp8BlockScaledMMKernel"
,
"FlashInferFp8DeepGEMMDynamicBlockScaledKernel"
,
"FlashInferFp8DeepGEMMDynamicBlockScaledKernel"
,
...
...
vllm/model_executor/kernels/linear/nvfp4/__init__.py
0 → 100644
View file @
2800706f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.model_executor.kernels.linear.nvfp4.base
import
(
NvFp4LinearKernel
,
NvFp4LinearLayerConfig
,
)
__all__
=
[
"NvFp4LinearKernel"
,
"NvFp4LinearLayerConfig"
,
]
vllm/model_executor/kernels/linear/nvfp4/base.py
0 → 100644
View file @
2800706f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
import
torch
@dataclass
class NvFp4LinearLayerConfig:
    """Settings object passed to NVFP4 linear kernels.

    It currently declares no fields, because every NVFP4 layer has the same
    structure: uint8 weights packing two FP4 values per byte, FP8-E4M3
    per-block weight scales (group size 16), and scalar global scales for
    both the weights and the activations.
    """
class NvFp4LinearKernel(ABC):
    """Base class for NVFP4 quantized linear kernels.

    Each subclass implements a specific GEMM backend (CUTLASS, Marlin, etc).
    The kernel selection mechanism iterates over registered subclasses in
    priority order, calling ``is_supported`` and ``can_implement`` to find
    the best match for the current hardware.
    """

    def __init__(self, config: NvFp4LinearLayerConfig) -> None:
        """Validate that this kernel is usable, then store *config*.

        Raises:
            ValueError: If the kernel cannot implement *config* or is not
                supported on the current platform.
        """
        # Raise explicitly rather than assert: asserts are stripped under
        # ``python -O`` and would drop the reason reported by the kernel.
        can_implement, reason = self.can_implement(config)
        if not can_implement:
            raise ValueError(
                f"{type(self).__name__} cannot implement the given "
                f"NVFP4 linear layer config: {reason}"
            )
        is_supported, reason = self.is_supported()
        if not is_supported:
            raise ValueError(
                f"{type(self).__name__} is not supported: {reason}"
            )
        self.config = config

    @classmethod
    @abstractmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Return whether this kernel can run on the current platform.

        Args:
            compute_capability: Optional device compute capability;
                implementations may ignore it.

        Returns:
            ``(True, None)`` when supported, otherwise ``(False, reason)``.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def can_implement(
        cls, config: NvFp4LinearLayerConfig
    ) -> tuple[bool, str | None]:
        """Return whether this kernel can handle *config*.

        Returns:
            ``(True, None)`` when it can, otherwise ``(False, reason)``.
        """
        raise NotImplementedError

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Transform weights into the format required by this kernel.

        Called once after checkpoint weights have been loaded onto the
        device. Implementations should repack / swizzle / pad weights
        and scales in-place on *layer*.
        """
        raise NotImplementedError

    @abstractmethod
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the quantized GEMM.

        Args:
            layer: Module holding the processed weights and scales.
            x: Input activations.
            bias: Optional bias added to the GEMM output.
        """
        raise NotImplementedError
vllm/model_executor/kernels/linear/nvfp4/cutlass.py
0 → 100644
View file @
2800706f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm._custom_ops
import
(
cutlass_scaled_fp4_mm
,
scaled_fp4_quant
,
)
from
vllm.model_executor.layers.quantization.utils.nvfp4_utils
import
(
cutlass_fp4_supported
,
pad_nvfp4_activation_for_cutlass
,
pad_nvfp4_weight_for_cutlass
,
slice_nvfp4_output
,
swizzle_blockscale
,
)
from
.base
import
NvFp4LinearKernel
,
NvFp4LinearLayerConfig
class CutlassNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 linear kernel backed by the vLLM CUTLASS FP4 GEMM op."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report availability as given by ``cutlass_fp4_supported()``."""
        if not cutlass_fp4_supported():
            return False, "CUTLASS FP4 kernels not available"
        return True, None

    @classmethod
    def can_implement(
        cls, config: NvFp4LinearLayerConfig
    ) -> tuple[bool, str | None]:
        """Accept every config."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Swizzle the block scales and pad the weight on *layer*."""
        swizzled_scale = swizzle_blockscale(layer.weight_scale.data)
        layer.weight_scale = torch.nn.Parameter(
            swizzled_scale, requires_grad=False
        )

        padded_weight, padding_cols = pad_nvfp4_weight_for_cutlass(
            layer.weight.data
        )
        layer.weight = torch.nn.Parameter(padded_weight, requires_grad=False)
        # Remembered so apply_weights() can pad activations to match.
        layer.weights_padding_cols = padding_cols

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize *x* to FP4, run the CUTLASS GEMM and add *bias*."""
        n_out = layer.output_size_per_partition
        result_dtype = x.dtype
        leading_dims = x.shape[:-1]

        act_fp4, act_blockscale = scaled_fp4_quant(
            x,
            layer.input_global_scale_inv,
            is_sf_swizzled_layout=True,
            backend="cutlass",
        )

        pad_cols = getattr(layer, "weights_padding_cols", 0)
        act_fp4 = pad_nvfp4_activation_for_cutlass(act_fp4, pad_cols)

        result = cutlass_scaled_fp4_mm(
            act_fp4,
            layer.weight,
            act_blockscale,
            layer.weight_scale,
            layer.alpha,
            result_dtype,
        )
        result = slice_nvfp4_output(result, n_out)

        if bias is not None:
            result = result + bias
        return result.view(*leading_dims, n_out)
vllm/model_executor/kernels/linear/nvfp4/emulation.py
0 → 100644
View file @
2800706f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils
import
(
kE2M1ToFloat_handle
,
run_nvfp4_emulations
,
)
from
.base
import
NvFp4LinearKernel
,
NvFp4LinearLayerConfig
class EmulationNvFp4LinearKernel(NvFp4LinearKernel):
    """Software-emulated NVFP4 fallback (dequantize, then BF16 matmul)."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Always report support: this is the last-resort fallback."""
        return True, None

    @classmethod
    def can_implement(
        cls, config: NvFp4LinearLayerConfig
    ) -> tuple[bool, str | None]:
        """Accept every config."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Place the E2M1 lookup table on the device of the layer's weight."""
        # Done here rather than lazily because `.to(device)` is not allowed
        # during CUDA graph capture.
        target_device = layer.weight.device
        kE2M1ToFloat_handle.val = kE2M1ToFloat_handle.val.to(target_device)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the emulated NVFP4 linear and add *bias* when given."""
        result = run_nvfp4_emulations(
            x=x,
            input_global_scale=layer.input_global_scale_inv,
            weight=layer.weight,
            weight_scale_swizzled=layer.weight_scale,
            weight_global_scale=layer.weight_global_scale,
            swizzle=False,
        )
        if bias is None:
            return result
        return result + bias
vllm/model_executor/kernels/linear/nvfp4/fbgemm.py
0 → 100644
View file @
2800706f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm._custom_ops
import
scaled_fp4_quant
from
vllm.model_executor.layers.quantization.utils.nvfp4_utils
import
(
slice_nvfp4_output
,
swizzle_blockscale
,
)
from
vllm.utils.import_utils
import
has_fbgemm_gpu
from
.base
import
NvFp4LinearKernel
,
NvFp4LinearLayerConfig
class FbgemmNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 linear kernel that dispatches to ``torch.ops.fbgemm.f4f4bf16``."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report availability as given by ``has_fbgemm_gpu()``."""
        if not has_fbgemm_gpu():
            return False, "fbgemm_gpu required"
        return True, None

    @classmethod
    def can_implement(
        cls, config: NvFp4LinearLayerConfig
    ) -> tuple[bool, str | None]:
        """Accept every config."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Swizzle the block scales and store them flattened as uint8."""
        swizzled_scale = swizzle_blockscale(layer.weight_scale.data)
        flat_uint8_scale = swizzled_scale.view(-1).view(torch.uint8)
        layer.weight_scale = torch.nn.Parameter(
            flat_uint8_scale, requires_grad=False
        )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize *x* to FP4, run the FBGEMM GEMM and add *bias*."""
        import fbgemm_gpu  # noqa: F401 - registers torch.ops.fbgemm.*

        n_out = layer.output_size_per_partition
        result_dtype = x.dtype
        leading_dims = x.shape[:-1]

        act_fp4, act_blockscale = scaled_fp4_quant(
            x,
            layer.input_global_scale_inv,
            is_sf_swizzled_layout=True,
            backend="fbgemm",
        )
        # The activation scales are flattened to uint8 in the same way as
        # the weight scales prepared in process_weights_after_loading().
        flat_act_scale = act_blockscale.view(-1).view(torch.uint8)

        gemm_out = torch.ops.fbgemm.f4f4bf16(
            act_fp4,
            layer.weight,
            flat_act_scale,
            layer.weight_scale,
            layer.alpha,
            use_mx=False,
        )
        result = slice_nvfp4_output(gemm_out.to(result_dtype), n_out)

        if bias is not None:
            result = result + bias
        return result.view(*leading_dims, n_out)
vllm/model_executor/kernels/linear/nvfp4/flashinfer.py
0 → 100644
View file @
2800706f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm._custom_ops
import
scaled_fp4_quant
from
vllm.model_executor.layers.quantization.utils.nvfp4_utils
import
(
pad_nvfp4_activation_for_cutlass
,
pad_nvfp4_weight_for_cutlass
,
slice_nvfp4_output
,
swizzle_blockscale
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
flashinfer_scaled_fp4_mm
,
has_flashinfer
from
.base
import
NvFp4LinearKernel
,
NvFp4LinearLayerConfig
class FlashInferCutlassNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 linear kernel using FlashInfer's FP4 GEMM, ``cutlass`` backend."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Require CUTLASS FP4 support, device capability 100 and FlashInfer."""
        from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
            cutlass_fp4_supported,
        )

        usable = (
            cutlass_fp4_supported()
            and current_platform.has_device_capability(100)
            and has_flashinfer()
        )
        if not usable:
            return False, "FlashInfer + >=sm_100 required"
        return True, None

    @classmethod
    def can_implement(
        cls, config: NvFp4LinearLayerConfig
    ) -> tuple[bool, str | None]:
        """Accept every config."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Swizzle the block scales and pad the weight on *layer*."""
        swizzled_scale = swizzle_blockscale(layer.weight_scale.data)
        layer.weight_scale = torch.nn.Parameter(
            swizzled_scale, requires_grad=False
        )

        padded_weight, padding_cols = pad_nvfp4_weight_for_cutlass(
            layer.weight.data
        )
        layer.weight = torch.nn.Parameter(padded_weight, requires_grad=False)
        # Remembered so apply_weights() can pad activations to match.
        layer.weights_padding_cols = padding_cols

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize *x* to FP4, run the FlashInfer GEMM and add *bias*."""
        n_out = layer.output_size_per_partition
        result_dtype = x.dtype
        leading_dims = x.shape[:-1]

        act_fp4, act_blockscale = scaled_fp4_quant(
            x,
            layer.input_global_scale_inv,
            is_sf_swizzled_layout=True,
            backend="flashinfer-cutlass",
        )

        pad_cols = getattr(layer, "weights_padding_cols", 0)
        act_fp4 = pad_nvfp4_activation_for_cutlass(act_fp4, pad_cols)

        result = flashinfer_scaled_fp4_mm(
            act_fp4,
            layer.weight,
            act_blockscale,
            layer.weight_scale,
            layer.alpha,
            result_dtype,
            backend="cutlass",
        )
        result = slice_nvfp4_output(result, n_out)

        if bias is not None:
            result = result + bias
        return result.view(*leading_dims, n_out)
class FlashInferTrtllmNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 linear kernel using FlashInfer's FP4 GEMM, ``trtllm`` backend."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report availability as given by ``has_flashinfer()``."""
        if not has_flashinfer():
            return False, "FlashInfer required"
        return True, None

    @classmethod
    def can_implement(
        cls, config: NvFp4LinearLayerConfig
    ) -> tuple[bool, str | None]:
        """Accept every config."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Shuffle the weight and block scales with FlashInfer's helpers."""
        from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

        raw_weight = layer.weight.data
        raw_scale = layer.weight_scale.data
        # Tile size passed to both shuffle helpers.
        epilogue_tile_m = 128

        shuffled_weight = shuffle_matrix_a(
            raw_weight.view(torch.uint8), epilogue_tile_m
        )
        layer.weight = torch.nn.Parameter(shuffled_weight, requires_grad=False)

        # Shuffle the scales as raw bytes, then restore the original shape
        # and reinterpret them as FP8-E4M3.
        shuffled_scale = shuffle_matrix_sf_a(
            raw_scale.view(torch.uint8), epilogue_tile_m
        )
        shuffled_scale = shuffled_scale.reshape(raw_scale.shape)
        layer.weight_scale = torch.nn.Parameter(
            shuffled_scale.view(torch.float8_e4m3fn), requires_grad=False
        )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize *x* to FP4, run the FlashInfer GEMM and add *bias*."""
        n_out = layer.output_size_per_partition
        result_dtype = x.dtype
        leading_dims = x.shape[:-1]

        act_fp4, act_blockscale = scaled_fp4_quant(
            x,
            layer.input_global_scale_inv,
            is_sf_swizzled_layout=True,
            backend="flashinfer-trtllm",
        )

        result = flashinfer_scaled_fp4_mm(
            act_fp4,
            layer.weight,
            act_blockscale,
            layer.weight_scale,
            layer.alpha,
            result_dtype,
            backend="trtllm",
        )
        result = slice_nvfp4_output(result, n_out)

        if bias is not None:
            result = result + bias
        return result.view(*leading_dims, n_out)
class FlashInferCudnnNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 linear kernel using FlashInfer's FP4 GEMM, ``cudnn`` backend."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report availability as given by ``has_flashinfer()``."""
        if not has_flashinfer():
            return False, "FlashInfer required"
        return True, None

    @classmethod
    def can_implement(
        cls, config: NvFp4LinearLayerConfig
    ) -> tuple[bool, str | None]:
        """Accept every config."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Swizzle the block scales and pad the weight on *layer*."""
        # cuDNN uses the same swizzled + padded layout as CUTLASS
        swizzled_scale = swizzle_blockscale(layer.weight_scale.data)
        layer.weight_scale = torch.nn.Parameter(
            swizzled_scale, requires_grad=False
        )

        padded_weight, padding_cols = pad_nvfp4_weight_for_cutlass(
            layer.weight.data
        )
        layer.weight = torch.nn.Parameter(padded_weight, requires_grad=False)
        # Remembered so apply_weights() can pad activations to match.
        layer.weights_padding_cols = padding_cols

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize *x* to FP4, run the FlashInfer GEMM and add *bias*."""
        n_out = layer.output_size_per_partition
        result_dtype = x.dtype
        leading_dims = x.shape[:-1]

        act_fp4, act_blockscale = scaled_fp4_quant(
            x,
            layer.input_global_scale_inv,
            is_sf_swizzled_layout=True,
            backend="flashinfer-cudnn",
        )

        pad_cols = getattr(layer, "weights_padding_cols", 0)
        act_fp4 = pad_nvfp4_activation_for_cutlass(act_fp4, pad_cols)

        result = flashinfer_scaled_fp4_mm(
            act_fp4,
            layer.weight,
            act_blockscale,
            layer.weight_scale,
            layer.alpha,
            result_dtype,
            backend="cudnn",
        )
        result = slice_nvfp4_output(result, n_out)

        if bias is not None:
            result = result + bias
        return result.view(*leading_dims, n_out)
vllm/model_executor/kernels/linear/nvfp4/marlin.py
0 → 100644
View file @
2800706f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp4
import
(
apply_fp4_marlin_linear
,
is_fp4_marlin_supported
,
prepare_fp4_layer_for_marlin
,
)
from
.base
import
NvFp4LinearKernel
,
NvFp4LinearLayerConfig
logger
=
init_logger
(
__name__
)
class MarlinNvFp4LinearKernel(NvFp4LinearKernel):
    """NVFP4 linear kernel using the Marlin FP4 weight-only path (W4A16)."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report availability as given by ``is_fp4_marlin_supported()``."""
        if not is_fp4_marlin_supported():
            return False, "Marlin FP4 not available"
        return True, None

    @classmethod
    def can_implement(
        cls, config: NvFp4LinearLayerConfig
    ) -> tuple[bool, str | None]:
        """Accept every config."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Warn once about the weight-only fallback, then prepare *layer*."""
        logger.warning_once(
            "Your GPU does not have native support for FP4 computation but "
            "FP4 quantization is being used. Weight-only FP4 compression "
            "will be used leveraging the Marlin kernel. This may degrade "
            "performance for compute-heavy workloads."
        )
        prepare_fp4_layer_for_marlin(layer)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the Marlin FP4 linear on *x*; *bias* is forwarded to it."""
        result = apply_fp4_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            weight_global_scale=layer.weight_global_scale,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )
        return result
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
View file @
2800706f
...
@@ -6,15 +6,10 @@ import torch
...
@@ -6,15 +6,10 @@ import torch
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.kernels.linear
import
init_nvfp4_linear_kernel
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
,
CompressedTensorsScheme
,
)
)
from
vllm.model_executor.layers.quantization.utils.nvfp4_utils
import
(
NvFp4LinearBackend
,
apply_nvfp4_linear
,
convert_to_nvfp4_linear_kernel_format
,
select_nvfp4_linear_backend
,
)
from
vllm.model_executor.parameter
import
(
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
GroupQuantScaleParameter
,
ModelWeightParameter
,
ModelWeightParameter
,
...
@@ -29,13 +24,9 @@ __all__ = ["CompressedTensorsW4A4Fp4"]
...
@@ -29,13 +24,9 @@ __all__ = ["CompressedTensorsW4A4Fp4"]
class
CompressedTensorsW4A4Fp4
(
CompressedTensorsScheme
):
class
CompressedTensorsW4A4Fp4
(
CompressedTensorsScheme
):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
backend
=
selec
t_nvfp4_linear_
backend
()
self
.
kernel
=
ini
t_nvfp4_linear_
kernel
()
self
.
group_size
=
16
self
.
group_size
=
16
self
.
swizzle
=
None
if
self
.
backend
==
NvFp4LinearBackend
.
EMULATION
:
self
.
swizzle
=
False
@
classmethod
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
def
get_min_capability
(
cls
)
->
int
:
return
75
return
75
...
@@ -130,7 +121,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
...
@@ -130,7 +121,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
)
)
# Convert layer to NVFP4 linear kernel format
# Convert layer to NVFP4 linear kernel format
convert_to_nvfp4_linear_kernel_format
(
self
.
backend
,
layer
)
self
.
kernel
.
process_weights_after_loading
(
layer
)
def
apply_weights
(
def
apply_weights
(
self
,
self
,
...
@@ -138,10 +129,4 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
...
@@ -138,10 +129,4 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
apply_nvfp4_linear
(
return
self
.
kernel
.
apply_weights
(
layer
=
layer
,
x
=
x
,
bias
=
bias
)
backend
=
self
.
backend
,
layer
=
layer
,
x
=
x
,
bias
=
bias
,
swizzle
=
self
.
swizzle
,
)
vllm/model_executor/layers/quantization/modelopt.py
View file @
2800706f
...
@@ -10,7 +10,10 @@ from torch.nn.parameter import Parameter
...
@@ -10,7 +10,10 @@ from torch.nn.parameter import Parameter
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.config
import
get_current_vllm_config
from
vllm.config
import
get_current_vllm_config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.kernels.linear
import
init_fp8_linear_kernel
from
vllm.model_executor.kernels.linear
import
(
init_fp8_linear_kernel
,
init_nvfp4_linear_kernel
,
)
from
vllm.model_executor.layers.attention
import
Attention
,
MLAAttention
from
vllm.model_executor.layers.attention
import
Attention
,
MLAAttention
from
vllm.model_executor.layers.fused_moe.activation
import
MoEActivation
from
vllm.model_executor.layers.fused_moe.activation
import
MoEActivation
from
vllm.model_executor.layers.fused_moe.config
import
(
from
vllm.model_executor.layers.fused_moe.config
import
(
...
@@ -70,12 +73,6 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
...
@@ -70,12 +73,6 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
Mxfp8LinearOp
,
Mxfp8LinearOp
,
mxfp8_e4m3_quantize
,
mxfp8_e4m3_quantize
,
)
)
from
vllm.model_executor.layers.quantization.utils.nvfp4_utils
import
(
NvFp4LinearBackend
,
apply_nvfp4_linear
,
convert_to_nvfp4_linear_kernel_format
,
select_nvfp4_linear_backend
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
GroupShape
,
create_fp8_quant_key
,
create_fp8_quant_key
,
...
@@ -1090,11 +1087,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
...
@@ -1090,11 +1087,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
ModelOptNvFp4Config
)
->
None
:
def
__init__
(
self
,
quant_config
:
ModelOptNvFp4Config
)
->
None
:
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
marlin_input_dtype
=
None
self
.
marlin_input_dtype
=
None
self
.
backend
=
select_nvfp4_linear_backend
()
self
.
kernel
=
init_nvfp4_linear_kernel
()
self
.
swizzle
=
None
if
self
.
backend
==
NvFp4LinearBackend
.
EMULATION
:
self
.
swizzle
=
False
def
create_weights
(
def
create_weights
(
self
,
self
,
...
@@ -1201,7 +1194,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
...
@@ -1201,7 +1194,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
)
)
# Convert layer to NVFP4 linear kernel format
# Convert layer to NVFP4 linear kernel format
convert_to_nvfp4_linear_kernel_format
(
self
.
backend
,
layer
)
self
.
kernel
.
process_weights_after_loading
(
layer
)
def
apply
(
def
apply
(
self
,
self
,
...
@@ -1209,13 +1202,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
...
@@ -1209,13 +1202,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
apply_nvfp4_linear
(
return
self
.
kernel
.
apply_weights
(
layer
=
layer
,
x
=
x
,
bias
=
bias
)
backend
=
self
.
backend
,
layer
=
layer
,
x
=
x
,
bias
=
bias
,
swizzle
=
self
.
swizzle
,
)
class
ModelOptNvFp4FusedMoE
(
FusedMoEMethodBase
):
class
ModelOptNvFp4FusedMoE
(
FusedMoEMethodBase
):
...
...
vllm/model_executor/layers/quantization/utils/nvfp4_utils.py
View file @
2800706f
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
enum
import
Enum
import
torch
import
torch
import
vllm.envs
as
envs
from
vllm._custom_ops
import
(
from
vllm._custom_ops
import
(
cutlass_scaled_fp4_mm
,
cutlass_scaled_mm_supports_fp4
,
cutlass_scaled_mm_supports_fp4
,
scaled_fp4_quant
,
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp4
import
(
apply_fp4_marlin_linear
,
is_fp4_marlin_supported
,
prepare_fp4_layer_for_marlin
,
)
from
vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils
import
(
kE2M1ToFloat_handle
,
run_nvfp4_emulations
,
)
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
flashinfer_scaled_fp4_mm
,
has_flashinfer
from
vllm.utils.import_utils
import
has_fbgemm_gpu
from
vllm.utils.math_utils
import
round_up
from
vllm.utils.math_utils
import
round_up
logger
=
init_logger
(
__name__
)
# NOTE: This is ordered by preferred backend.
# Example: if both are available, FLASHINFER_CUTLASS is preferred to VLLM_CUTLASS.
class NvFp4LinearBackend(Enum):
    """GEMM backends that an NVFP4 linear layer can be executed with.

    The declaration order is meaningful: it is the preference order used
    when a backend is picked automatically (earlier members win).
    """

    FLASHINFER_CUTLASS = "flashinfer-cutlass"
    VLLM_CUTLASS = "cutlass"
    MARLIN = "marlin"
    FLASHINFER_TRTLLM = "flashinfer-trtllm"
    FLASHINFER_CUDNN = "flashinfer-cudnn"
    FBGEMM = "fbgemm"
    EMULATION = "emulation"


# All backends as a list, in the enum's declaration (i.e. preference) order.
NVFP4_LINEAR_BACKENDS = [*NvFp4LinearBackend]
def is_backend_supported(backend: NvFp4LinearBackend) -> tuple[bool, str | None]:
    """Report whether ``backend`` is usable in the current environment.

    Args:
        backend: The NVFP4 linear backend to probe.

    Returns:
        A ``(supported, reason)`` pair. ``reason`` is ``None`` when the
        backend is supported, otherwise a short description of what is
        missing.
    """
    if backend == NvFp4LinearBackend.EMULATION:
        # Emulation is always reported as supported. The other backends are
        # probed only so the warning below can explain why they were skipped.
        # e.g. AMD Instinct does not support native NVFP4.
        probes = (
            (candidate, is_backend_supported(candidate))
            for candidate in NVFP4_LINEAR_BACKENDS
            if candidate != NvFp4LinearBackend.EMULATION
        )
        unavailable = {
            candidate: why for candidate, (ok, why) in probes if not ok
        }

        # NOTE(review): this fires as soon as *any* other backend is
        # unavailable, while the message says none is available - confirm
        # this is intended when emulation is requested explicitly.
        if unavailable:
            unavailable_str = "\n - ".join(
                f"{candidate.value}: {why}"
                for candidate, why in unavailable.items()
            )
            logger.warning_once(
                f"NVFP4 linear falling back to the slow and unoptimized "
                f"backend=NvFp4LinearBackend.EMULATION as no optimized backend is "
                f"available (unavailable reasons:\n - {unavailable_str}\n). "
                "In case you expect one of these backend to be used, "
                "please verify your environment."
            )
        return True, None

    if backend == NvFp4LinearBackend.FLASHINFER_CUTLASS:
        # cutlass_fp4_supported() checks that the vLLM NVFP4 kernels (both
        # quantization and GEMM) were compiled for the current SM version.
        # FlashInfer backends still rely on the vLLM quantization kernels,
        # so we gate them on the same check.
        available = (
            cutlass_fp4_supported()
            and current_platform.has_device_capability(100)
            and has_flashinfer()
        )
        why_not = "FlashInfer is required, >=sm_100 is required"
    elif backend == NvFp4LinearBackend.VLLM_CUTLASS:
        available = cutlass_fp4_supported()
        why_not = "Cutlass is required"
    elif backend == NvFp4LinearBackend.MARLIN:
        available = is_fp4_marlin_supported()
        why_not = "Marlin is required"
    elif backend in [
        NvFp4LinearBackend.FLASHINFER_TRTLLM,
        NvFp4LinearBackend.FLASHINFER_CUDNN,
    ]:
        available = has_flashinfer()
        why_not = "FlashInfer is required"
    elif backend == NvFp4LinearBackend.FBGEMM:
        available = has_fbgemm_gpu()
        why_not = "fbgemm_gpu is required"
    else:
        # No check is defined for this value: treated as supported.
        return True, None

    # The reason is only reported when the probe failed.
    return available, (None if available else why_not)
def select_nvfp4_linear_backend() -> NvFp4LinearBackend:
    """
    Select the best available NVFP4 GEMM backend based on environment
    configuration and platform capabilities.

    Precedence, highest first:
      1. ``VLLM_BATCH_INVARIANT``: emulation, returned immediately (the
         support check at the end of this function is not run).
      2. ``VLLM_USE_FBGEMM``: FBGEMM; ``fbgemm_gpu`` must be importable.
      3. ``VLLM_USE_NVFP4_CT_EMULATIONS``: emulation.
      4. ``VLLM_NVFP4_GEMM_BACKEND`` is ``None``: the first entry of
         ``NVFP4_LINEAR_BACKENDS`` that ``is_backend_supported`` accepts.
      5. Otherwise: the backend whose enum value equals
         ``VLLM_NVFP4_GEMM_BACKEND``.

    Returns:
        The selected backend.

    Raises:
        ImportError: ``VLLM_USE_FBGEMM`` is set but ``fbgemm_gpu`` cannot be
            imported.
        ValueError: no backend could be selected, the selected backend is
            not supported in this environment, or
            ``VLLM_NVFP4_GEMM_BACKEND`` is not a valid
            ``NvFp4LinearBackend`` value (raised by the enum lookup).
    """
    if envs.VLLM_BATCH_INVARIANT:
        logger.info_once(
            "VLLM_BATCH_INVARIANT forces NVFP4 linear to use the emulation "
            "backend for deterministic execution."
        )
        return NvFp4LinearBackend.EMULATION

    selected_backend: NvFp4LinearBackend | None = None

    if envs.VLLM_USE_FBGEMM:
        # Fail early with an actionable message instead of at GEMM time.
        try:
            import fbgemm_gpu  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "Backend fbgemm requires fbgemm.f4f4bf16 operator, "
                "Please install with: pip install fbgemm-gpu-genai"
            ) from exc
        selected_backend = NvFp4LinearBackend.FBGEMM
    elif envs.VLLM_USE_NVFP4_CT_EMULATIONS:
        selected_backend = NvFp4LinearBackend.EMULATION
    elif envs.VLLM_NVFP4_GEMM_BACKEND is None:
        # Auto-select: the list is ordered by preference, so take the first
        # backend that reports itself as supported.
        for backend in NVFP4_LINEAR_BACKENDS:
            supported, _ = is_backend_supported(backend)
            if supported:
                selected_backend = backend
                break
    else:
        selected_backend = NvFp4LinearBackend(envs.VLLM_NVFP4_GEMM_BACKEND)

    if selected_backend is None:
        raise ValueError(
            f"No NVFP4 GEMM backend selected, "
            f"available backends: {NVFP4_LINEAR_BACKENDS}"
        )

    # An explicitly requested backend still has to be usable here.
    supported, reason = is_backend_supported(selected_backend)
    if not supported:
        # Every env var uses the `=` specifier so the message names each
        # value; a bare `{...}` would print an unlabeled value.
        raise ValueError(
            f"The selected backend={selected_backend} is not supported in current "
            f"environment. Reason: {reason}. Current environment: "
            f"{envs.VLLM_USE_FBGEMM=}, {envs.VLLM_USE_NVFP4_CT_EMULATIONS=}, "
            f"{envs.VLLM_NVFP4_GEMM_BACKEND=}."
        )

    logger.info_once(f"Using {selected_backend} for NVFP4 GEMM")
    return selected_backend
def prepare_weights_for_nvfp4_flashinfer_trtllm(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Shuffle a weight and its block scales for FlashInfer TRTLLM FP4 GEMM.

    Args:
        weight: Weight tensor; it is reinterpreted as ``uint8`` before being
            shuffled.
        weight_scale: Block-scale tensor; it is reinterpreted as ``uint8``
            before being shuffled.

    Returns:
        ``(shuffled_weight, shuffled_weight_scale)``. The scale is reshaped
        back to ``weight_scale.shape`` and viewed as ``float8_e4m3fn``.
    """
    # Local import: flashinfer is only needed on this code path.
    from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

    # Tile size handed to both flashinfer shuffle helpers.
    tile_m = 128

    weight_bytes = weight.view(torch.uint8)
    weight_out = shuffle_matrix_a(weight_bytes, tile_m)

    scale_bytes = weight_scale.view(torch.uint8)
    scale_shuffled = shuffle_matrix_sf_a(scale_bytes, tile_m)
    # Restore the caller's scale shape and dtype after the byte-level shuffle.
    scale_out = scale_shuffled.reshape(weight_scale.shape).view(torch.float8_e4m3fn)

    return weight_out, scale_out
def prepare_weights_for_nvfp4_cutlass(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """
    Get a weight and its block scales ready for the CUTLASS /
    FlashInfer-CUTLASS FP4 GEMM.

    The block scales go through ``swizzle_blockscale`` and the weight through
    ``pad_nvfp4_weight_for_cutlass`` (padding for alignment: K and N
    divisible by 32).

    Returns:
        ``(padded_weight, swizzled_weight_scale, weights_padding_cols)``,
        the last item being the padding column count reported by
        ``pad_nvfp4_weight_for_cutlass``.
    """
    scale_swizzled = swizzle_blockscale(weight_scale)
    weight_padded, pad_cols = pad_nvfp4_weight_for_cutlass(weight)
    return weight_padded, scale_swizzled, pad_cols
def prepare_weights_for_nvfp4_fbgemm(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Get a weight and its block scales ready for the FBGEMM FP4 GEMM.

    Returns:
        ``(weight, scale)`` where ``weight`` is passed through untouched and
        ``scale`` is the swizzled block scale, flattened to 1-D and
        reinterpreted as ``uint8``.
    """
    # Swizzle, flatten, then reinterpret the bytes as uint8 in one chain.
    scale_bytes = swizzle_blockscale(weight_scale).view(-1).view(torch.uint8)
    return weight, scale_bytes
def convert_to_nvfp4_linear_kernel_format(
    backend: NvFp4LinearBackend,
    layer: torch.nn.Module,
) -> None:
    """Rewrite ``layer`` in place into the layout ``backend`` expects.

    Args:
        backend: Backend the layer will later be run with.
        layer: Module carrying ``weight`` and ``weight_scale``. It always
            gets a ``weights_padding_cols`` attribute (0 unless a
            CUTLASS-style backend pads the weight).
    """
    assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
        "Weight Block scale must be represented as FP8-E4M3"
    )

    # Default to no padding
    layer.weights_padding_cols = 0

    if backend == NvFp4LinearBackend.MARLIN:
        logger.warning_once(
            "Your GPU does not have native support for FP4 computation but "
            "FP4 quantization is being used. Weight-only FP4 compression "
            "will be used leveraging the Marlin kernel. This may degrade "
            "performance for compute-heavy workloads."
        )
        prepare_fp4_layer_for_marlin(layer)
        return

    if backend == NvFp4LinearBackend.EMULATION:
        # We can not call `.to(device)` during cuda graph capture - do it here instead.
        # (operation not permitted when stream is capturing)
        kE2M1ToFloat_handle.val = kE2M1ToFloat_handle.val.to(layer.weight.device)
        return

    # The remaining backends all produce a new (weight, weight_scale) pair;
    # only the CUTLASS-style ones additionally report padding columns.
    pad_cols = None
    if backend == NvFp4LinearBackend.FLASHINFER_TRTLLM:
        new_weight, new_scale = prepare_weights_for_nvfp4_flashinfer_trtllm(
            layer.weight.data, layer.weight_scale.data
        )
    elif backend == NvFp4LinearBackend.FBGEMM:
        new_weight, new_scale = prepare_weights_for_nvfp4_fbgemm(
            layer.weight.data, layer.weight_scale.data
        )
    elif backend in (
        NvFp4LinearBackend.VLLM_CUTLASS,
        NvFp4LinearBackend.FLASHINFER_CUTLASS,
        NvFp4LinearBackend.FLASHINFER_CUDNN,
    ):
        new_weight, new_scale, pad_cols = prepare_weights_for_nvfp4_cutlass(
            layer.weight.data, layer.weight_scale.data
        )
    else:
        # Unrecognized backend: leave the layer as it is.
        return

    layer.weight = torch.nn.Parameter(new_weight, requires_grad=False)
    layer.weight_scale = torch.nn.Parameter(new_scale, requires_grad=False)
    if pad_cols is not None:
        layer.weights_padding_cols = pad_cols
def apply_nvfp4_linear(
    backend: NvFp4LinearBackend,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
    swizzle: bool | None = None,
) -> torch.Tensor:
    """
    Run the NVFP4 linear transformation of ``layer`` on ``x`` with ``backend``.

    Args:
        backend: Backend to dispatch to.
        layer: Module providing ``weight``, ``weight_scale``,
            ``weight_global_scale``, ``input_global_scale_inv``, ``alpha``,
            ``output_size_per_partition`` and ``input_size_per_partition``
            (plus ``workspace`` on the Marlin path).
        x: Input activations; the output keeps ``x``'s leading dimensions.
        bias: Optional bias. Marlin receives it as an argument; the other
            paths add it to the GEMM output.
        swizzle: Forwarded to the emulation path only.

    Returns:
        Tensor of shape ``[*x.shape[:-1], output_size_per_partition]``
        (the Marlin path returns whatever ``apply_fp4_marlin_linear`` returns).
    """
    # All layer attributes are read up front, before any backend branching.
    w = layer.weight
    w_scale = layer.weight_scale
    w_global_scale = layer.weight_global_scale
    in_global_scale_inv = layer.input_global_scale_inv
    alpha = layer.alpha
    n_out = layer.output_size_per_partition
    k_in = layer.input_size_per_partition

    out_dtype = x.dtype
    out_shape = list(x.shape[:-1]) + [n_out]

    if backend == NvFp4LinearBackend.MARLIN:
        return apply_fp4_marlin_linear(
            input=x,
            weight=w,
            weight_scale=w_scale,
            weight_global_scale=w_global_scale,
            workspace=layer.workspace,
            size_n=n_out,
            size_k=k_in,
            bias=bias,
        )

    if backend == NvFp4LinearBackend.EMULATION:
        # The emulation helper is fed a 2-D view of the activations.
        flat_x = x.reshape(-1, x.shape[-1])
        result = run_nvfp4_emulations(
            x=flat_x,
            input_global_scale=in_global_scale_inv,
            weight=w,
            weight_scale_swizzled=w_scale,
            weight_global_scale=w_global_scale,
            swizzle=swizzle,
        )
        # Keep only the first n_out output columns.
        result = result[:, :n_out]
        if bias is not None:
            result = result + bias
        return result.view(*out_shape)

    # Quantize BF16 or FP16 to (FP4 and interleaved block scale)
    x_fp4, x_blockscale = scaled_fp4_quant(
        x,
        in_global_scale_inv,
        is_sf_swizzled_layout=True,
        backend=backend.value,
    )

    # Validate dtypes
    assert x_fp4.dtype == torch.uint8
    assert w.dtype == torch.uint8
    assert x_blockscale.dtype == torch.float8_e4m3fn
    # weight_scale is fp8 for most backends, but uint8 for fbgemm
    assert w_scale.dtype in (torch.float8_e4m3fn, torch.uint8)
    assert alpha.dtype == torch.float32

    # Pad activations to match weight K-dimension padding
    pad_cols = getattr(layer, "weights_padding_cols", 0)
    x_fp4 = pad_nvfp4_activation_for_cutlass(x_fp4, pad_cols)

    # Call the appropriate backend
    flashinfer_prefix = "flashinfer-"
    if backend.value.startswith(flashinfer_prefix):
        # FlashInfer takes the backend name without the "flashinfer-" prefix.
        result = flashinfer_scaled_fp4_mm(
            x_fp4,
            w,
            x_blockscale,
            w_scale,
            alpha,
            out_dtype,
            backend=backend.value.removeprefix(flashinfer_prefix),
        )
    elif backend == NvFp4LinearBackend.FBGEMM:
        # The activation block scale is passed flattened and viewed as uint8.
        result = torch.ops.fbgemm.f4f4bf16(
            x_fp4,
            w,
            x_blockscale.view(-1).view(torch.uint8),
            w_scale,
            alpha,
            use_mx=False,
        ).to(out_dtype)
    else:
        assert backend == NvFp4LinearBackend.VLLM_CUTLASS
        result = cutlass_scaled_fp4_mm(
            x_fp4, w, x_blockscale, w_scale, alpha, out_dtype
        )

    # Slice output to remove N-dimension padding
    result = slice_nvfp4_output(result, n_out)

    if bias is not None:
        result = result + bias

    return result.view(*out_shape)
def
swizzle_blockscale
(
scale
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
swizzle_blockscale
(
scale
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment