Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
67ebaff5
Unverified
Commit
67ebaff5
authored
Jan 30, 2026
by
Michael Goin
Committed by
GitHub
Jan 30, 2026
Browse files
Refactor NVFP4 Linear utils for ModelOpt and CT (#33201)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
2b465570
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
462 additions
and
483 deletions
+462
-483
tests/kernels/moe/modular_kernel_tools/mk_objects.py
tests/kernels/moe/modular_kernel_tools/mk_objects.py
+1
-1
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+1
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+1
-12
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_mxfp4.py
...pressed_tensors/schemes/compressed_tensors_w4a16_mxfp4.py
+2
-2
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
...pressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
+6
-21
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
...mpressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
+29
-153
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+34
-162
vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
..._executor/layers/quantization/utils/flashinfer_fp4_moe.py
+3
-1
vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
...el_executor/layers/quantization/utils/marlin_utils_fp4.py
+9
-7
vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py
...l_executor/layers/quantization/utils/nvfp4_moe_support.py
+1
-1
vllm/model_executor/layers/quantization/utils/nvfp4_utils.py
vllm/model_executor/layers/quantization/utils/nvfp4_utils.py
+375
-0
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+0
-122
No files found.
tests/kernels/moe/modular_kernel_tools/mk_objects.py
View file @
67ebaff5
...
@@ -25,7 +25,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
...
@@ -25,7 +25,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
from
vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe
import
(
from
vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe
import
(
TritonOrDeepGemmExperts
,
TritonOrDeepGemmExperts
,
)
)
from
vllm.model_executor.layers.quantization.utils.
quant
_utils
import
(
from
vllm.model_executor.layers.quantization.utils.
nvfp4
_utils
import
(
cutlass_fp4_supported
,
cutlass_fp4_supported
,
)
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
...
...
tests/quantization/test_compressed_tensors.py
View file @
67ebaff5
...
@@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
...
@@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
)
)
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
W8A8BlockFp8LinearOp
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
W8A8BlockFp8LinearOp
from
vllm.model_executor.layers.quantization.utils.
quant
_utils
import
(
from
vllm.model_executor.layers.quantization.utils.
nvfp4
_utils
import
(
cutlass_fp4_supported
,
cutlass_fp4_supported
,
)
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
67ebaff5
...
@@ -18,7 +18,6 @@ from compressed_tensors.quantization import (
...
@@ -18,7 +18,6 @@ from compressed_tensors.quantization import (
)
)
from
compressed_tensors.transform
import
TransformConfig
from
compressed_tensors.transform
import
TransformConfig
import
vllm.envs
as
envs
from
vllm.distributed
import
(
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
...
@@ -63,9 +62,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
...
@@ -63,9 +62,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
should_ignore_layer
,
should_ignore_layer
,
)
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
cutlass_fp4_supported
,
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -627,14 +623,7 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -627,14 +623,7 @@ class CompressedTensorsConfig(QuantizationConfig):
if
self
.
_is_nvfp4_format
(
weight_quant
)
and
self
.
_is_nvfp4_format
(
if
self
.
_is_nvfp4_format
(
weight_quant
)
and
self
.
_is_nvfp4_format
(
input_quant
input_quant
):
):
if
cutlass_fp4_supported
()
or
envs
.
VLLM_USE_NVFP4_CT_EMULATIONS
:
return
CompressedTensorsW4A4Fp4
()
return
CompressedTensorsW4A4Fp4
()
else
:
logger
.
warning_once
(
"Current platform does not support cutlass NVFP4."
" Running CompressedTensorsW4A16Fp4."
)
return
CompressedTensorsW4A16Fp4
(
has_input_global_scale
=
True
)
if
self
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
if
self
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
is_fp8_w8a8_supported
=
self
.
_check_scheme_supported
(
is_fp8_w8a8_supported
=
self
.
_check_scheme_supported
(
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_mxfp4.py
View file @
67ebaff5
...
@@ -81,7 +81,7 @@ class CompressedTensorsW4A16Mxfp4(CompressedTensorsScheme):
...
@@ -81,7 +81,7 @@ class CompressedTensorsW4A16Mxfp4(CompressedTensorsScheme):
)
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
def
process_weights_after_loading
(
self
,
layer
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# Rename weight_packed to weight that marlin expects
# Rename weight_packed to weight that marlin expects
layer
.
weight
=
Parameter
(
layer
.
weight_packed
.
data
,
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
layer
.
weight_packed
.
data
,
requires_grad
=
False
)
del
layer
.
weight_packed
del
layer
.
weight_packed
...
@@ -98,7 +98,7 @@ class CompressedTensorsW4A16Mxfp4(CompressedTensorsScheme):
...
@@ -98,7 +98,7 @@ class CompressedTensorsW4A16Mxfp4(CompressedTensorsScheme):
input
=
x
,
input
=
x
,
weight
=
layer
.
weight
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
weight_scale
=
layer
.
weight_scale
,
weight_scale
_2
=
None
,
weight_
global_
scale
=
None
,
workspace
=
layer
.
workspace
,
workspace
=
layer
.
workspace
,
size_n
=
layer
.
output_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
View file @
67ebaff5
...
@@ -22,8 +22,7 @@ __all__ = ["CompressedTensorsW4A16Fp4"]
...
@@ -22,8 +22,7 @@ __all__ = ["CompressedTensorsW4A16Fp4"]
class
CompressedTensorsW4A16Fp4
(
CompressedTensorsScheme
):
class
CompressedTensorsW4A16Fp4
(
CompressedTensorsScheme
):
def
__init__
(
self
,
has_input_global_scale
:
bool
=
False
):
def
__init__
(
self
):
self
.
has_input_global_scale
=
has_input_global_scale
self
.
group_size
=
16
self
.
group_size
=
16
@
classmethod
@
classmethod
...
@@ -79,30 +78,16 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
...
@@ -79,30 +78,16 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
if
self
.
has_input_global_scale
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
input_global_scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"input_global_scale"
,
input_global_scale
)
def
process_weights_after_loading
(
self
,
layer
)
->
None
:
# Process parameters for marlin repacking
# Process parameters for marlin repacking
# Rename weight_packed to weight that marlin expects
# Rename weight_packed to weight that marlin expects
layer
.
weight
=
Parameter
(
layer
.
weight_packed
.
data
,
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
layer
.
weight_packed
.
data
,
requires_grad
=
False
)
del
layer
.
weight_packed
del
layer
.
weight_packed
# Rename weight_global_scale to weight_scale_2 that marlin expects
# ct stores the inverse of what is expected by the marlin kernel
# Note: ct stores the inverse of what is expected by the marlin kernel
layer
.
weight_global_scale
=
Parameter
(
layer
.
weight_scale_2
=
Parameter
(
1.0
/
layer
.
weight_global_scale
.
max
().
to
(
torch
.
float32
),
requires_grad
=
False
1
/
layer
.
weight_global_scale
.
max
().
to
(
torch
.
float32
),
requires_grad
=
False
)
)
del
layer
.
weight_global_scale
if
self
.
has_input_global_scale
:
layer
.
input_global_scale
=
torch
.
nn
.
Parameter
(
layer
.
input_global_scale
.
data
,
requires_grad
=
False
)
prepare_fp4_layer_for_marlin
(
layer
)
prepare_fp4_layer_for_marlin
(
layer
)
...
@@ -116,7 +101,7 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
...
@@ -116,7 +101,7 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
input
=
x
,
input
=
x
,
weight
=
layer
.
weight
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
weight_scale
=
layer
.
weight_scale
,
weight_scale
_2
=
layer
.
weight_scale
_2
,
weight_
global_
scale
=
layer
.
weight_
global_
scale
,
workspace
=
layer
.
workspace
,
workspace
=
layer
.
workspace
,
size_n
=
layer
.
output_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
View file @
67ebaff5
...
@@ -5,75 +5,31 @@ from collections.abc import Callable
...
@@ -5,75 +5,31 @@ from collections.abc import Callable
import
torch
import
torch
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
import
vllm.envs
as
envs
from
vllm._custom_ops
import
cutlass_scaled_fp4_mm
,
scaled_fp4_quant
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
,
CompressedTensorsScheme
,
)
)
from
vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.utils.nvfp4_utils
import
(
run_nvfp4_emulations
,
apply_nvfp4_linear
,
)
convert_to_nvfp4_linear_kernel_format
,
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
select_nvfp4_linear_backend
,
cutlass_fp4_supported
,
pad_nvfp4_activation_for_cutlass
,
pad_nvfp4_weight_for_cutlass
,
slice_nvfp4_output
,
swizzle_blockscale
,
)
)
from
vllm.model_executor.parameter
import
(
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
GroupQuantScaleParameter
,
ModelWeightParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
,
PerTensorScaleParameter
,
)
)
from
vllm.utils.flashinfer
import
(
flashinfer_scaled_fp4_mm
,
has_flashinfer
,
)
logger
=
init_logger
(
__name__
)
__all__
=
[
"CompressedTensorsW4A4Fp4"
]
__all__
=
[
"CompressedTensorsW4A4Fp4"
]
class
CompressedTensorsW4A4Fp4
(
CompressedTensorsScheme
):
class
CompressedTensorsW4A4Fp4
(
CompressedTensorsScheme
):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
backend
=
"none"
self
.
backend
=
select_nvfp4_linear_backend
()
if
envs
.
VLLM_NVFP4_GEMM_BACKEND
is
None
:
if
has_flashinfer
():
self
.
backend
=
"flashinfer-cutlass"
elif
cutlass_fp4_supported
():
self
.
backend
=
"cutlass"
elif
envs
.
VLLM_USE_FBGEMM
:
self
.
backend
=
"fbgemm"
try
:
import
fbgemm_gpu
# noqa: F401
except
ImportError
as
exc
:
raise
ImportError
(
"Backend fbgemm requires fbgemm.f4f4bf16 operator, "
"Please install with: pip install fbgemm-gpu-genai"
)
from
exc
elif
envs
.
VLLM_NVFP4_GEMM_BACKEND
.
startswith
(
"flashinfer-"
):
self
.
backend
=
envs
.
VLLM_NVFP4_GEMM_BACKEND
assert
has_flashinfer
(),
f
"FlashInfer is required for
{
self
.
backend
}
"
elif
envs
.
VLLM_NVFP4_GEMM_BACKEND
==
"cutlass"
:
self
.
backend
=
"cutlass"
assert
cutlass_fp4_supported
(),
f
"Cutlass is required for
{
self
.
backend
}
"
if
self
.
backend
==
"none"
:
raise
ValueError
(
"No valid NVFP4 GEMM backend found. "
"Please check your platform capability."
)
logger
.
info_once
(
f
"Using
{
self
.
backend
}
for NVFP4 GEMM"
)
self
.
group_size
=
16
self
.
group_size
=
16
@
classmethod
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
def
get_min_capability
(
cls
)
->
int
:
if
envs
.
VLLM_USE_NVFP4_CT_EMULATIONS
:
return
75
return
80
return
100
def
create_weights
(
def
create_weights
(
self
,
self
,
...
@@ -129,120 +85,40 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
...
@@ -129,120 +85,40 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
)
)
layer
.
register_parameter
(
"input_global_scale"
,
input_global_scale
)
layer
.
register_parameter
(
"input_global_scale"
,
input_global_scale
)
def
process_weights_after_loading
(
self
,
layer
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
global_input_scale
=
layer
.
input_global_scale
.
max
().
to
(
torch
.
float32
)
# Rename CT checkpoint names to standardized names
layer
.
input_global_scale
=
Parameter
(
global_input_scale
,
requires_grad
=
False
)
layer
.
weight
=
layer
.
weight_packed
del
layer
.
weight_packed
# Process global scales (CT stores as divisors, i.e. 1/scale)
input_global_scale_inv
=
layer
.
input_global_scale
.
max
().
to
(
torch
.
float32
)
layer
.
input_global_scale
=
Parameter
(
(
1.0
/
input_global_scale_inv
).
to
(
torch
.
float32
),
requires_grad
=
False
)
weight_global_scale
=
layer
.
weight_global_scale
.
max
().
to
(
torch
.
float32
)
layer
.
weight_global_scale
=
Parameter
(
layer
.
weight_global_scale
=
Parameter
(
layer
.
weight_global_scale
.
max
().
to
(
torch
.
float32
)
,
requires_grad
=
False
1.0
/
weight_global_scale
,
requires_grad
=
False
)
)
if
self
.
backend
==
"flashinfer-trtllm"
:
# Pre-compute alpha and inverse for runtime quantization
# FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
layer
.
input_global_scale_inv
=
Parameter
(
# FlashInfer provides nvfp4_quantize to quantize + shuffle the
input_global_scale_inv
,
requires_grad
=
False
# layout but we use our own quantization so we have to call
)
# shuffles ourselves.
from
flashinfer
import
shuffle_matrix_a
,
shuffle_matrix_sf_a
weight
=
layer
.
weight_packed
.
data
weight_scale
=
layer
.
weight_scale
.
data
epilogue_tile_m
=
128
weight
=
shuffle_matrix_a
(
weight
.
view
(
torch
.
uint8
),
epilogue_tile_m
)
weight_scale
=
(
shuffle_matrix_sf_a
(
weight_scale
.
view
(
torch
.
uint8
),
epilogue_tile_m
)
.
reshape
(
weight_scale
.
shape
)
.
view
(
torch
.
float8_e4m3fn
)
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
weight_packed
=
Parameter
(
weight
,
requires_grad
=
False
)
else
:
swizzled_weight_scale
=
swizzle_blockscale
(
layer
.
weight_scale
)
if
self
.
backend
==
"fbgemm"
:
swizzled_weight_scale
=
swizzled_weight_scale
.
view
(
-
1
).
view
(
torch
.
uint8
)
layer
.
weight_scale
=
Parameter
(
swizzled_weight_scale
,
requires_grad
=
False
)
# Pad weights for CUTLASS/FlashInfer kernel alignment (K and N
# divisible by 32). fbgemm has its own layout requirements.
if
self
.
backend
in
(
"cutlass"
,
"flashinfer-cutlass"
):
weight
,
weights_padding_cols
=
pad_nvfp4_weight_for_cutlass
(
layer
.
weight_packed
.
data
)
layer
.
weights_padding_cols
=
weights_padding_cols
layer
.
weight_packed
=
Parameter
(
weight
,
requires_grad
=
False
)
else
:
layer
.
weights_padding_cols
=
0
layer
.
weight_packed
=
Parameter
(
layer
.
weight_packed
.
data
,
requires_grad
=
False
)
layer
.
alpha
=
Parameter
(
layer
.
alpha
=
Parameter
(
1
/
(
layer
.
input_global_scale
*
layer
.
weight_global_scale
),
layer
.
input_global_scale
*
layer
.
weight_global_scale
,
requires_grad
=
False
requires_grad
=
False
,
)
)
# Convert layer to NVFP4 linear kernel format
convert_to_nvfp4_linear_kernel_format
(
self
.
backend
,
layer
)
def
apply_weights
(
def
apply_weights
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
envs
.
VLLM_USE_NVFP4_CT_EMULATIONS
:
return
apply_nvfp4_linear
(
out
=
run_nvfp4_emulations
(
x
=
x
,
input_global_scale
=
layer
.
input_global_scale
,
weight
=
layer
.
weight_packed
,
weight_scale_swizzled
=
layer
.
weight_scale
,
weight_global_scale
=
layer
.
weight_global_scale
,
)
if
bias
is
not
None
:
out
=
out
+
bias
return
out
output_dtype
=
x
.
dtype
output_size
=
layer
.
output_size_per_partition
output_shape
=
[
*
x
.
shape
[:
-
1
],
output_size
]
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4
,
x_blockscale
=
scaled_fp4_quant
(
x
,
layer
.
input_global_scale
,
is_sf_swizzled_layout
=
True
,
backend
=
self
.
backend
,
backend
=
self
.
backend
,
layer
=
layer
,
x
=
x
,
bias
=
bias
,
)
)
# Pad activations to match weight K-dimension padding
weights_padding_cols
=
getattr
(
layer
,
"weights_padding_cols"
,
0
)
x_fp4
=
pad_nvfp4_activation_for_cutlass
(
x_fp4
,
weights_padding_cols
)
mm_args
=
(
x_fp4
,
layer
.
weight_packed
,
x_blockscale
,
layer
.
weight_scale
,
layer
.
alpha
,
output_dtype
,
)
if
self
.
backend
.
startswith
(
"flashinfer-"
):
backend_name
=
self
.
backend
[
len
(
"flashinfer-"
)
:]
out
=
flashinfer_scaled_fp4_mm
(
*
mm_args
,
backend
=
backend_name
)
elif
self
.
backend
==
"fbgemm"
:
out
=
torch
.
ops
.
fbgemm
.
f4f4bf16
(
x_fp4
,
layer
.
weight_packed
,
x_blockscale
.
view
(
-
1
).
view
(
torch
.
uint8
),
layer
.
weight_scale
,
layer
.
alpha
,
use_mx
=
False
,
).
to
(
output_dtype
)
else
:
assert
self
.
backend
==
"cutlass"
out
=
cutlass_scaled_fp4_mm
(
*
mm_args
)
# Slice output to remove N-dimension padding
out
=
slice_nvfp4_output
(
out
,
output_size
)
if
bias
is
not
None
:
out
=
out
+
bias
return
out
.
view
(
*
output_shape
)
vllm/model_executor/layers/quantization/modelopt.py
View file @
67ebaff5
...
@@ -5,12 +5,9 @@ from fnmatch import fnmatch
...
@@ -5,12 +5,9 @@ from fnmatch import fnmatch
from
typing
import
TYPE_CHECKING
,
Any
from
typing
import
TYPE_CHECKING
,
Any
import
torch
import
torch
from
torch.nn
import
Module
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
import
vllm.envs
as
envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm._custom_ops
import
cutlass_scaled_fp4_mm
,
scaled_fp4_quant
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.fused_moe.config
import
(
from
vllm.model_executor.layers.fused_moe.config
import
(
...
@@ -66,24 +63,19 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
...
@@ -66,24 +63,19 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
get_marlin_input_dtype
,
get_marlin_input_dtype
,
)
)
from
vllm.model_executor.layers.quantization.utils.
marlin
_utils
_fp4
import
(
from
vllm.model_executor.layers.quantization.utils.
nvfp4
_utils
import
(
apply_fp4_
marlin_
linear
,
apply_
nv
fp4_linear
,
is_fp4_marlin_supported
,
convert_to_nvfp4_linear_kernel_format
,
prepare_fp4_layer_for_marlin
,
select_nvfp4_linear_backend
,
)
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
GroupShape
,
cutlass_fp4_supported
,
is_layer_skipped
,
is_layer_skipped
,
kFp8DynamicTokenSym
,
kFp8DynamicTokenSym
,
kFp8StaticTensorSym
,
kFp8StaticTensorSym
,
kFp8StaticTokenSym
,
kFp8StaticTokenSym
,
kNvfp4Dynamic
,
kNvfp4Dynamic
,
kNvfp4Static
,
kNvfp4Static
,
pad_nvfp4_activation_for_cutlass
,
pad_nvfp4_weight_for_cutlass
,
slice_nvfp4_output
,
swizzle_blockscale
,
)
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
cutlass_block_fp8_supported
,
cutlass_block_fp8_supported
,
...
@@ -96,11 +88,6 @@ from vllm.model_executor.parameter import (
...
@@ -96,11 +88,6 @@ from vllm.model_executor.parameter import (
PerTensorScaleParameter
,
PerTensorScaleParameter
,
)
)
from
vllm.model_executor.utils
import
replace_parameter
from
vllm.model_executor.utils
import
replace_parameter
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
(
flashinfer_scaled_fp4_mm
,
has_flashinfer
,
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.model_executor.models.utils
import
WeightsMapper
from
vllm.model_executor.models.utils
import
WeightsMapper
...
@@ -498,7 +485,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
...
@@ -498,7 +485,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
layer
.
register_parameter
(
"input_scale"
,
scale
)
layer
.
register_parameter
(
"input_scale"
,
scale
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
weight
=
layer
.
weight
weight
=
layer
.
weight
max_w_scale
=
layer
.
weight_scale
.
max
()
max_w_scale
=
layer
.
weight_scale
.
max
()
if
not
(
layer
.
weight_scale
==
layer
.
weight_scale
[
0
]).
all
():
if
not
(
layer
.
weight_scale
==
layer
.
weight_scale
[
0
]).
all
():
...
@@ -580,7 +567,7 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
...
@@ -580,7 +567,7 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
weight_scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
weight_scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
layer
.
weight
=
Parameter
(
layer
.
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
layer
.
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
layer
.
weight_scale
.
data
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
layer
.
weight_scale
.
data
,
requires_grad
=
False
)
...
@@ -681,7 +668,7 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
...
@@ -681,7 +668,7 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
weight_scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
weight_scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# Keep weight in [out, in] layout for W8A8BlockFp8LinearOp.
# Keep weight in [out, in] layout for W8A8BlockFp8LinearOp.
layer
.
weight
=
Parameter
(
layer
.
weight
.
data
,
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
layer
.
weight
.
data
,
requires_grad
=
False
)
...
@@ -1108,32 +1095,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
...
@@ -1108,32 +1095,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
ModelOptNvFp4Config
)
->
None
:
def
__init__
(
self
,
quant_config
:
ModelOptNvFp4Config
)
->
None
:
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
marlin_input_dtype
=
None
self
.
marlin_input_dtype
=
None
self
.
backend
=
select_nvfp4_linear_backend
()
self
.
backend
=
"none"
if
envs
.
VLLM_NVFP4_GEMM_BACKEND
is
None
:
if
current_platform
.
has_device_capability
(
100
)
and
has_flashinfer
():
self
.
backend
=
"flashinfer-cutlass"
elif
cutlass_fp4_supported
():
self
.
backend
=
"cutlass"
elif
is_fp4_marlin_supported
():
self
.
backend
=
"marlin"
elif
envs
.
VLLM_NVFP4_GEMM_BACKEND
.
startswith
(
"flashinfer-"
):
self
.
backend
=
envs
.
VLLM_NVFP4_GEMM_BACKEND
assert
has_flashinfer
(),
f
"FlashInfer is required for
{
self
.
backend
}
"
elif
envs
.
VLLM_NVFP4_GEMM_BACKEND
==
"cutlass"
:
self
.
backend
=
"cutlass"
assert
cutlass_fp4_supported
(),
f
"Cutlass is required for
{
self
.
backend
}
"
elif
envs
.
VLLM_NVFP4_GEMM_BACKEND
==
"marlin"
:
self
.
backend
=
"marlin"
assert
is_fp4_marlin_supported
(),
f
"Marlin is required for
{
self
.
backend
}
"
if
self
.
backend
==
"none"
:
raise
ValueError
(
"No valid NVFP4 GEMM backend found. "
"Please check your platform capability."
)
logger
.
info_once
(
f
"Using
{
self
.
backend
}
for NVFP4 GEMM"
)
def
create_weights
(
def
create_weights
(
self
,
self
,
...
@@ -1181,19 +1143,19 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
...
@@ -1181,19 +1143,19 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
)
)
layer
.
register_parameter
(
"weight"
,
weight
)
layer
.
register_parameter
(
"weight"
,
weight
)
# Input
Weight
Scale
# Input
Global
Scale
input_scale
=
PerTensorScaleParameter
(
input_
global_
scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
,
weight_loader
=
weight_loader
,
)
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
layer
.
register_parameter
(
"input_scale"
,
input_
global_
scale
)
#
Global Weight
Scale
#
Weight Global
Scale
weight_scale
_2
=
PerTensorScaleParameter
(
weight_
global_
scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
,
weight_loader
=
weight_loader
,
)
)
layer
.
register_parameter
(
"weight_scale_2"
,
weight_scale
_2
)
layer
.
register_parameter
(
"weight_scale_2"
,
weight_
global_
scale
)
# Per Block Weight Scale
# Per Block Weight Scale
weight_scale
=
ModelWeightParameter
(
weight_scale
=
ModelWeightParameter
(
...
@@ -1209,65 +1171,25 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
...
@@ -1209,65 +1171,25 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# global scales:
# Rename ModelOpt checkpoint names to standardized names
input_scale_2
=
layer
.
input_scale
.
max
().
to
(
torch
.
float32
)
input_global_scale
=
layer
.
input_scale
.
max
().
to
(
torch
.
float32
)
layer
.
input_scale
=
Parameter
(
input_scale_2
,
requires_grad
=
False
)
layer
.
input_global_scale
=
Parameter
(
input_global_scale
,
requires_grad
=
False
)
del
layer
.
input_scale
weight_scale_2
=
layer
.
weight_scale_2
.
max
().
to
(
torch
.
float32
)
weight_global_scale
=
layer
.
weight_scale_2
.
max
().
to
(
torch
.
float32
)
layer
.
weight_scale_2
=
Parameter
(
weight_scale_2
,
requires_grad
=
False
)
layer
.
weight_global_scale
=
Parameter
(
weight_global_scale
,
requires_grad
=
False
)
del
layer
.
weight_scale_2
# Pre-compute alpha and inverse for runtime quantization
layer
.
alpha
=
Parameter
(
layer
.
alpha
=
Parameter
(
layer
.
input_scale
*
layer
.
weight_scale_2
,
requires_grad
=
False
layer
.
input_global_scale
*
layer
.
weight_global_scale
,
requires_grad
=
False
)
# Calculate `1 / input_scale` so that we don't need to do so at runtime
layer
.
input_scale_inv
=
Parameter
(
(
1
/
layer
.
input_scale
).
to
(
torch
.
float32
),
requires_grad
=
False
)
)
layer
.
input_global_scale_inv
=
Parameter
(
# Swizzle the weight blockscale.
(
1.0
/
layer
.
input_global_scale
).
to
(
torch
.
float32
),
requires_grad
=
False
# contracting dimension is input dimension
# block_size = 16;
assert
layer
.
weight_scale
.
dtype
==
torch
.
float8_e4m3fn
,
(
"Weight Block scale must be represented as FP8-E4M3"
)
)
if
self
.
backend
==
"marlin"
:
# Convert layer to NVFP4 linear kernel format
prepare_fp4_layer_for_marlin
(
layer
)
convert_to_nvfp4_linear_kernel_format
(
self
.
backend
,
layer
)
del
layer
.
alpha
del
layer
.
input_scale
elif
self
.
backend
==
"flashinfer-trtllm"
:
# FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
# FlashInfer provides nvfp4_quantize to quantize + shuffle the
# layout but we use our own quantization so we have to call
# shuffles ourselves.
from
flashinfer
import
shuffle_matrix_a
,
shuffle_matrix_sf_a
weight
=
layer
.
weight
.
data
weight_scale
=
layer
.
weight_scale
.
data
epilogue_tile_m
=
128
weight
=
shuffle_matrix_a
(
weight
.
view
(
torch
.
uint8
),
epilogue_tile_m
)
weight_scale
=
(
shuffle_matrix_sf_a
(
weight_scale
.
view
(
torch
.
uint8
),
epilogue_tile_m
)
.
reshape
(
weight_scale
.
shape
)
.
view
(
torch
.
float8_e4m3fn
)
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
,
requires_grad
=
False
)
else
:
# Swizzle block scales and pad the packed NVFP4 weights for kernel
# alignment (CUTLASS/FlashInfer require K and N divisible by 32).
swizzled_weight_scale
=
swizzle_blockscale
(
layer
.
weight_scale
)
layer
.
weight_scale
=
Parameter
(
swizzled_weight_scale
,
requires_grad
=
False
)
weight
,
weights_padding_cols
=
pad_nvfp4_weight_for_cutlass
(
layer
.
weight
.
data
)
layer
.
weights_padding_cols
=
weights_padding_cols
layer
.
weight
=
Parameter
(
weight
,
requires_grad
=
False
)
def
apply
(
def
apply
(
self
,
self
,
...
@@ -1275,63 +1197,13 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
...
@@ -1275,63 +1197,13 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
self
.
backend
==
"marlin"
:
return
apply_nvfp4_linear
(
return
apply_fp4_marlin_linear
(
backend
=
self
.
backend
,
input
=
x
,
layer
=
layer
,
weight
=
layer
.
weight
,
x
=
x
,
weight_scale
=
layer
.
weight_scale
,
bias
=
bias
,
weight_scale_2
=
layer
.
weight_scale_2
,
workspace
=
layer
.
workspace
,
size_n
=
layer
.
output_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
bias
=
bias
,
input_dtype
=
self
.
marlin_input_dtype
,
)
output_dtype
=
x
.
dtype
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4
,
x_blockscale
=
scaled_fp4_quant
(
x
,
layer
.
input_scale_inv
,
is_sf_swizzled_layout
=
True
,
backend
=
self
.
backend
)
# validate dtypes of quantized input, input block scale,
# weight and weight_blockscale
assert
x_fp4
.
dtype
==
torch
.
uint8
assert
layer
.
weight
.
dtype
==
torch
.
uint8
assert
x_blockscale
.
dtype
==
torch
.
float8_e4m3fn
assert
layer
.
weight_scale
.
dtype
==
torch
.
float8_e4m3fn
assert
layer
.
alpha
.
dtype
==
torch
.
float32
# Pad activations to match weight K-dimension padding
weights_padding_cols
=
getattr
(
layer
,
"weights_padding_cols"
,
0
)
output_size
=
layer
.
output_size_per_partition
output_shape
=
[
x
.
shape
[
0
],
output_size
]
x_fp4
=
pad_nvfp4_activation_for_cutlass
(
x_fp4
,
weights_padding_cols
)
mm_args
=
(
x_fp4
,
layer
.
weight
,
x_blockscale
,
layer
.
weight_scale
,
layer
.
alpha
,
output_dtype
,
)
)
if
self
.
backend
.
startswith
(
"flashinfer-"
):
backend_name
=
self
.
backend
[
len
(
"flashinfer-"
)
:]
out
=
flashinfer_scaled_fp4_mm
(
*
mm_args
,
backend
=
backend_name
)
else
:
assert
self
.
backend
==
"cutlass"
out
=
cutlass_scaled_fp4_mm
(
*
mm_args
)
# Slice output to remove N-dimension padding
out
=
slice_nvfp4_output
(
out
,
output_size
)
if
bias
is
not
None
:
out
=
out
+
bias
return
out
.
view
(
*
output_shape
)
class
ModelOptNvFp4FusedMoE
(
FusedMoEMethodBase
):
class
ModelOptNvFp4FusedMoE
(
FusedMoEMethodBase
):
"""
"""
...
...
vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
View file @
67ebaff5
...
@@ -15,11 +15,13 @@ from vllm.model_executor.layers.fused_moe.config import (
...
@@ -15,11 +15,13 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig
,
FusedMoEParallelConfig
,
RoutingMethodType
,
RoutingMethodType
,
)
)
from
vllm.model_executor.layers.quantization.utils.nvfp4_utils
import
(
swizzle_blockscale
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
QuantKey
,
kNvfp4Dynamic
,
kNvfp4Dynamic
,
kNvfp4Static
,
kNvfp4Static
,
swizzle_blockscale
,
)
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
(
from
vllm.utils.flashinfer
import
(
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
View file @
67ebaff5
...
@@ -92,7 +92,7 @@ def apply_fp4_marlin_linear(
...
@@ -92,7 +92,7 @@ def apply_fp4_marlin_linear(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
weight_scale
_2
:
torch
.
Tensor
|
None
,
weight_
global_
scale
:
torch
.
Tensor
|
None
,
workspace
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
size_n
:
int
,
size_n
:
int
,
size_k
:
int
,
size_k
:
int
,
...
@@ -112,7 +112,7 @@ def apply_fp4_marlin_linear(
...
@@ -112,7 +112,7 @@ def apply_fp4_marlin_linear(
inputs
=
reshaped_x
inputs
=
reshaped_x
a_scales
=
None
a_scales
=
None
is_nvfp4
=
weight_scale
_2
is
not
None
is_nvfp4
=
weight_
global_
scale
is
not
None
if
input_dtype
is
not
None
and
input_dtype
.
itemsize
==
1
:
if
input_dtype
is
not
None
and
input_dtype
.
itemsize
==
1
:
if
is_nvfp4
:
if
is_nvfp4
:
raise
RuntimeError
(
"NVFP4 weight + INT8/FP8 activation is not supported."
)
raise
RuntimeError
(
"NVFP4 weight + INT8/FP8 activation is not supported."
)
...
@@ -128,7 +128,7 @@ def apply_fp4_marlin_linear(
...
@@ -128,7 +128,7 @@ def apply_fp4_marlin_linear(
b_bias
=
bias
,
b_bias
=
bias
,
b_scales
=
weight_scale
,
b_scales
=
weight_scale
,
a_scales
=
a_scales
,
a_scales
=
a_scales
,
global_scale
=
weight_scale
_2
,
global_scale
=
weight_
global_
scale
,
b_zeros
=
None
,
b_zeros
=
None
,
g_idx
=
None
,
g_idx
=
None
,
perm
=
None
,
perm
=
None
,
...
@@ -154,7 +154,7 @@ def prepare_fp4_layer_for_marlin(
...
@@ -154,7 +154,7 @@ def prepare_fp4_layer_for_marlin(
"performance for compute-heavy workloads."
"performance for compute-heavy workloads."
)
)
is_nvfp4
=
hasattr
(
layer
,
"weight_scale
_2
"
)
is_nvfp4
=
hasattr
(
layer
,
"weight_
global_
scale"
)
if
input_dtype
is
not
None
and
input_dtype
.
itemsize
==
1
:
if
input_dtype
is
not
None
and
input_dtype
.
itemsize
==
1
:
if
is_nvfp4
:
if
is_nvfp4
:
raise
RuntimeError
(
"NVFP4 weight + INT8/FP8 activation is not supported."
)
raise
RuntimeError
(
"NVFP4 weight + INT8/FP8 activation is not supported."
)
...
@@ -210,9 +210,11 @@ def prepare_fp4_layer_for_marlin(
...
@@ -210,9 +210,11 @@ def prepare_fp4_layer_for_marlin(
weight_scale
=
nvfp4_marlin_process_scales
(
weight_scale
)
weight_scale
=
nvfp4_marlin_process_scales
(
weight_scale
)
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
weight_scale
,
requires_grad
=
False
)
weight_scale_2
=
layer
.
weight_scale_2
.
to
(
param_dtype
)
weight_global_scale
=
layer
.
weight_global_scale
.
to
(
param_dtype
)
weight_scale_2
=
nvfp4_marlin_process_global_scale
(
weight_scale_2
)
weight_global_scale
=
nvfp4_marlin_process_global_scale
(
weight_global_scale
)
layer
.
weight_scale_2
=
torch
.
nn
.
Parameter
(
weight_scale_2
,
requires_grad
=
False
)
layer
.
weight_global_scale
=
torch
.
nn
.
Parameter
(
weight_global_scale
,
requires_grad
=
False
)
else
:
else
:
weight_scale
=
mxfp4_marlin_process_scales
(
weight_scale
=
mxfp4_marlin_process_scales
(
weight_scale
,
input_dtype
=
input_dtype
weight_scale
,
input_dtype
=
input_dtype
...
...
vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py
View file @
67ebaff5
...
@@ -11,7 +11,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
...
@@ -11,7 +11,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp4
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp4
import
(
is_fp4_marlin_supported
,
is_fp4_marlin_supported
,
)
)
from
vllm.model_executor.layers.quantization.utils.
quant
_utils
import
(
from
vllm.model_executor.layers.quantization.utils.
nvfp4
_utils
import
(
cutlass_fp4_supported
,
cutlass_fp4_supported
,
)
)
...
...
vllm/model_executor/layers/quantization/utils/nvfp4_utils.py
0 → 100644
View file @
67ebaff5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
enum
import
Enum
import
torch
import
vllm.envs
as
envs
from
vllm._custom_ops
import
(
cutlass_scaled_fp4_mm
,
cutlass_scaled_mm_supports_fp4
,
scaled_fp4_quant
,
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp4
import
(
apply_fp4_marlin_linear
,
is_fp4_marlin_supported
,
prepare_fp4_layer_for_marlin
,
)
from
vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils
import
(
run_nvfp4_emulations
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
flashinfer_scaled_fp4_mm
,
has_flashinfer
from
vllm.utils.math_utils
import
round_up
logger
=
init_logger
(
__name__
)
class NvFp4LinearBackend(Enum):
    """Identifiers for the GEMM backends an NVFP4 linear layer can run on.

    Each member's value is the backend's string name, so a member can be
    recovered from a configured string with ``NvFp4LinearBackend(name)``.
    """

    # vLLM's own CUTLASS FP4 kernel.
    VLLM_CUTLASS = "cutlass"
    # FlashInfer-provided kernels. The part of the value after the
    # "flashinfer-" prefix is what gets forwarded to FlashInfer as its
    # ``backend`` argument.
    FLASHINFER_CUTLASS = "flashinfer-cutlass"
    FLASHINFER_TRTLLM = "flashinfer-trtllm"
    FLASHINFER_CUDNN = "flashinfer-cudnn"
    # FBGEMM ``f4f4bf16`` operator.
    FBGEMM = "fbgemm"
    # Marlin FP4 kernel.
    MARLIN = "marlin"
    # Software emulation path.
    EMULATION = "emulation"
def select_nvfp4_linear_backend() -> NvFp4LinearBackend:
    """
    Decide which NVFP4 GEMM backend linear layers should use.

    Precedence, highest first:
      1. ``VLLM_USE_FBGEMM`` set -> FBGEMM (``fbgemm_gpu`` must import).
      2. ``VLLM_USE_NVFP4_CT_EMULATIONS`` set -> EMULATION.
      3. ``VLLM_NVFP4_GEMM_BACKEND`` set -> that backend, looked up by value.
      4. Otherwise auto-select the first available of: FlashInfer CUTLASS
         (needs ``has_device_capability(100)`` and FlashInfer), vLLM CUTLASS,
         Marlin.

    Raises:
        ImportError: FBGEMM was requested but ``fbgemm_gpu`` is not installed.
        ValueError: auto-selection found nothing usable, or the configured
            name is not a valid ``NvFp4LinearBackend`` value.
        AssertionError: the chosen backend's runtime requirement is not met.
    """
    chosen: NvFp4LinearBackend | None = None

    if envs.VLLM_USE_FBGEMM:
        try:
            import fbgemm_gpu  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "Backend fbgemm requires fbgemm.f4f4bf16 operator, "
                "Please install with: pip install fbgemm-gpu-genai"
            ) from exc
        chosen = NvFp4LinearBackend.FBGEMM
    elif envs.VLLM_USE_NVFP4_CT_EMULATIONS:
        chosen = NvFp4LinearBackend.EMULATION
    elif envs.VLLM_NVFP4_GEMM_BACKEND is not None:
        # Explicit user choice; an unknown name raises ValueError here.
        chosen = NvFp4LinearBackend(envs.VLLM_NVFP4_GEMM_BACKEND)
    # Nothing configured: fall back to the best backend that is available.
    elif current_platform.has_device_capability(100) and has_flashinfer():
        chosen = NvFp4LinearBackend.FLASHINFER_CUTLASS
    elif cutlass_fp4_supported():
        chosen = NvFp4LinearBackend.VLLM_CUTLASS
    elif is_fp4_marlin_supported():
        chosen = NvFp4LinearBackend.MARLIN

    if chosen is None:
        raise ValueError(
            f"No NVFP4 GEMM backend selected, "
            f"available backends: {list(NvFp4LinearBackend)}"
        )

    # Make sure whatever was picked (or forced) can actually run here.
    flashinfer_backends = (
        NvFp4LinearBackend.FLASHINFER_CUTLASS,
        NvFp4LinearBackend.FLASHINFER_TRTLLM,
        NvFp4LinearBackend.FLASHINFER_CUDNN,
    )
    if chosen in flashinfer_backends:
        assert has_flashinfer(), f"FlashInfer is required for {chosen}"
    elif chosen == NvFp4LinearBackend.VLLM_CUTLASS:
        assert cutlass_fp4_supported(), f"Cutlass is required for {chosen}"
    elif chosen == NvFp4LinearBackend.MARLIN:
        assert is_fp4_marlin_supported(), f"Marlin is required for {chosen}"

    logger.info_once(f"Using {chosen} for NVFP4 GEMM")
    return chosen
def prepare_weights_for_nvfp4_flashinfer_trtllm(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Shuffle weights and block scales for the FlashInfer TRTLLM FP4 GEMM.

    Both tensors are reinterpreted as ``uint8`` and passed through
    FlashInfer's shuffle helpers. The shuffled scale is reshaped back to the
    input scale's shape and reinterpreted as ``float8_e4m3fn``.

    Returns:
        ``(shuffled_weight, shuffled_weight_scale)``.
    """
    # Imported lazily so FlashInfer is only needed when this backend is used.
    from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

    tile_m = 128  # epilogue tile size passed to both shuffle helpers

    weight_bytes = weight.view(torch.uint8)
    shuffled_weight = shuffle_matrix_a(weight_bytes, tile_m)

    scale_bytes = weight_scale.view(torch.uint8)
    shuffled_scale_bytes = shuffle_matrix_sf_a(scale_bytes, tile_m)
    shuffled_weight_scale = shuffled_scale_bytes.reshape(weight_scale.shape).view(
        torch.float8_e4m3fn
    )

    return shuffled_weight, shuffled_weight_scale
def prepare_weights_for_nvfp4_cutlass(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """
    Prepare weights and scales for CUTLASS / FlashInfer-CUTLASS FP4 GEMM.

    The block scales are swizzled with ``swizzle_blockscale`` and the packed
    weights are padded with ``pad_nvfp4_weight_for_cutlass``.

    Returns:
        ``(padded_weight, swizzled_weight_scale, weights_padding_cols)`` where
        the last item is the padding amount reported by
        ``pad_nvfp4_weight_for_cutlass``.
    """
    scale_swizzled = swizzle_blockscale(weight_scale)
    weight_padded, padding = pad_nvfp4_weight_for_cutlass(weight)
    return weight_padded, scale_swizzled, padding
def prepare_weights_for_nvfp4_fbgemm(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Prepare weights and scales for the FBGEMM FP4 GEMM.

    The weight is returned untouched. The block scale is swizzled, flattened
    to 1-D and reinterpreted as ``uint8``.

    Returns:
        ``(weight, flattened_uint8_weight_scale)``.
    """
    flat_scale = swizzle_blockscale(weight_scale).view(-1)
    return weight, flat_scale.view(torch.uint8)
def convert_to_nvfp4_linear_kernel_format(
    backend: NvFp4LinearBackend,
    layer: torch.nn.Module,
) -> None:
    """
    Rewrite ``layer``'s weight tensors in place into the layout ``backend`` needs.

    ``layer.weights_padding_cols`` is always reset to 0 first; only the
    CUTLASS-style backends overwrite it with a real padding amount. Backends
    without a dedicated branch leave the weights as they are.

    Raises:
        AssertionError: ``layer.weight_scale`` is not ``float8_e4m3fn``.
    """
    assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
        "Weight Block scale must be represented as FP8-E4M3"
    )

    # Default to no padding
    layer.weights_padding_cols = 0

    if backend == NvFp4LinearBackend.MARLIN:
        # Marlin's helper updates the layer itself.
        prepare_fp4_layer_for_marlin(layer)
        return

    padding_cols = None
    if backend == NvFp4LinearBackend.FLASHINFER_TRTLLM:
        new_weight, new_scale = prepare_weights_for_nvfp4_flashinfer_trtllm(
            layer.weight.data, layer.weight_scale.data
        )
    elif backend == NvFp4LinearBackend.FBGEMM:
        new_weight, new_scale = prepare_weights_for_nvfp4_fbgemm(
            layer.weight.data, layer.weight_scale.data
        )
    elif backend in (
        NvFp4LinearBackend.VLLM_CUTLASS,
        NvFp4LinearBackend.FLASHINFER_CUTLASS,
    ):
        new_weight, new_scale, padding_cols = prepare_weights_for_nvfp4_cutlass(
            layer.weight.data, layer.weight_scale.data
        )
    else:
        # No conversion defined for this backend.
        return

    layer.weight = torch.nn.Parameter(new_weight, requires_grad=False)
    layer.weight_scale = torch.nn.Parameter(new_scale, requires_grad=False)
    if padding_cols is not None:
        layer.weights_padding_cols = padding_cols
def apply_nvfp4_linear(
    backend: NvFp4LinearBackend,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Run the NVFP4 linear op for ``layer`` on input ``x`` with ``backend``.

    Marlin and emulation are handled by their own helpers and return early.
    All other backends quantize ``x`` to FP4, pad it to match the weight's
    K padding, call the backend GEMM, slice off any N padding, add ``bias``
    and reshape to ``[*x.shape[:-1], output_size_per_partition]``.

    Raises:
        AssertionError: a tensor has an unexpected dtype, or ``backend`` has
            no matching GEMM path.
    """
    # Pull everything needed off the layer before dispatching.
    w = layer.weight
    w_scale = layer.weight_scale
    w_global_scale = layer.weight_global_scale
    in_global_scale_inv = layer.input_global_scale_inv
    alpha = layer.alpha
    n_out = layer.output_size_per_partition
    n_in = layer.input_size_per_partition

    if backend == NvFp4LinearBackend.MARLIN:
        return apply_fp4_marlin_linear(
            input=x,
            weight=w,
            weight_scale=w_scale,
            weight_global_scale=w_global_scale,
            workspace=layer.workspace,
            size_n=n_out,
            size_k=n_in,
            bias=bias,
        )

    if backend == NvFp4LinearBackend.EMULATION:
        emulated = run_nvfp4_emulations(
            x=x,
            input_global_scale=in_global_scale_inv,
            weight=w,
            weight_scale_swizzled=w_scale,
            weight_global_scale=w_global_scale,
        )
        return emulated if bias is None else emulated + bias

    out_dtype = x.dtype
    out_shape = [*x.shape[:-1], n_out]

    # Quantize BF16 or FP16 to (FP4 and interleaved block scale)
    x_fp4, x_blockscale = scaled_fp4_quant(
        x,
        in_global_scale_inv,
        is_sf_swizzled_layout=True,
        backend=backend.value,
    )

    # Validate dtypes
    assert x_fp4.dtype == torch.uint8
    assert w.dtype == torch.uint8
    assert x_blockscale.dtype == torch.float8_e4m3fn
    # weight_scale is fp8 for most backends, but uint8 for fbgemm
    assert w_scale.dtype in (torch.float8_e4m3fn, torch.uint8)
    assert alpha.dtype == torch.float32

    # Pad activations to match weight K-dimension padding
    x_fp4 = pad_nvfp4_activation_for_cutlass(
        x_fp4, getattr(layer, "weights_padding_cols", 0)
    )

    # Call the appropriate backend
    flashinfer_prefix = "flashinfer-"
    if backend.value.startswith(flashinfer_prefix):
        out = flashinfer_scaled_fp4_mm(
            x_fp4,
            w,
            x_blockscale,
            w_scale,
            alpha,
            out_dtype,
            backend=backend.value[len(flashinfer_prefix) :],
        )
    elif backend == NvFp4LinearBackend.FBGEMM:
        flat_x_scale = x_blockscale.view(-1).view(torch.uint8)
        out = torch.ops.fbgemm.f4f4bf16(
            x_fp4,
            w,
            flat_x_scale,
            w_scale,
            alpha,
            use_mx=False,
        ).to(out_dtype)
    else:
        assert backend == NvFp4LinearBackend.VLLM_CUTLASS
        out = cutlass_scaled_fp4_mm(
            x_fp4, w, x_blockscale, w_scale, alpha, out_dtype
        )

    # Slice output to remove N-dimension padding
    out = slice_nvfp4_output(out, n_out)

    if bias is not None:
        out = out + bias

    return out.view(*out_shape)
def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
    """
    Pad and block-interleave FP4 block scales into the layout expected by
    the CUTLASS / FlashInfer kernels.

    Parameters
    ----------
    scale: torch.Tensor
        2-D ``(M, K)`` or 3-D ``(B, M, K)`` tensor of dtype
        ``torch.float8_e4m3fn``.

    Returns
    -------
    torch.Tensor
        The swizzled scales, moved to the CUDA device, with M rounded up to
        a multiple of 128 and K to a multiple of 4. The result has the same
        number of dimensions as *scale*.
    """
    assert scale.dtype == torch.float8_e4m3fn, (
        "swizzle_blockscale expects the input tensor to be in "
        "torch.float8_e4m3fn format."
    )

    was_2d = scale.ndim == 2
    if was_2d:
        scale = scale.unsqueeze(0)  # (1, M, K)
    assert scale.ndim == 3, "Expected a 2-D or 3-D tensor for block scales."

    batch, rows, cols = scale.shape
    rows_padded = round_up(rows, 128)
    cols_padded = round_up(cols, 4)

    # Zero-filled buffer of the padded size with the real scales copied in.
    buf = torch.zeros(
        (batch, rows_padded, cols_padded), dtype=scale.dtype, device=scale.device
    )
    buf[:batch, :rows, :cols] = scale

    # Reshape / permute to the layout required by the kernel.
    tiles = buf.reshape(batch, rows_padded // 128, 4, 32, cols_padded // 4, 4)
    swizzled = tiles.permute(0, 1, 4, 3, 2, 5).contiguous().cuda()

    out_shape = (
        (rows_padded, cols_padded) if was_2d else (batch, rows_padded, cols_padded)
    )
    return swizzled.reshape(*out_shape)
def cutlass_fp4_supported() -> bool:
    """
    Report whether the CUTLASS FP4 scaled-mm kernel can be used here.

    Returns ``False`` on non-CUDA platforms. Otherwise defers to
    ``cutlass_scaled_mm_supports_fp4`` with the device capability as an int,
    or ``-1`` when the capability is unknown.
    """
    if not current_platform.is_cuda():
        return False

    device_capability = current_platform.get_device_capability()
    if device_capability is None:
        capability_int = -1
    else:
        capability_int = device_capability.to_int()
    return cutlass_scaled_mm_supports_fp4(capability_int)
def pad_nvfp4_weight_for_cutlass(
    weight: torch.Tensor,
    alignment: int = 32,
) -> tuple[torch.Tensor, int]:
    """
    Pad packed NVFP4 weights so N (rows) and K (columns) meet the alignment
    required by CUTLASS / FlashInfer FP4 kernels.

    CUTLASS FP4 kernel requires both K and N matrix dimensions to be divisible
    by 32 for aligned memory access and efficient tensor core operations.

    Returns:
        ``(weight, pad_bytes)`` where ``pad_bytes`` is the number of bytes
        appended to each row (0 when K was already aligned).
    """
    pad = torch.nn.functional.pad

    # N dimension: append rows until the row count is aligned.
    n_rows = weight.shape[0]
    if n_rows % alignment != 0:
        extra_rows = round_up(n_rows, alignment) - n_rows
        weight = pad(weight, (0, 0, 0, extra_rows)).contiguous()

    # K dimension: two FP4 values are packed per byte, so the element count
    # is twice the stored column width.
    k_elements = weight.shape[1] * 2
    if k_elements % alignment == 0:
        return weight, 0

    extra_elements = round_up(k_elements, alignment) - k_elements
    # Convert from FP4 element count to bytes (2 FP4 values per byte)
    # extra_elements is always even since alignment=32 and k_elements is even
    extra_bytes = extra_elements // 2
    weight = pad(weight, (0, extra_bytes, 0, 0)).contiguous()
    return weight, extra_bytes
def pad_nvfp4_activation_for_cutlass(
    x_fp4: torch.Tensor,
    weights_padding_bytes: int,
) -> torch.Tensor:
    """
    Pad packed FP4 activations to match the K-dimension padding of the weights.

    ``weights_padding_bytes`` counts bytes along the last tensor dimension,
    not FP4 elements. When it is not positive the input is returned as-is.
    """
    if weights_padding_bytes <= 0:
        return x_fp4
    # Append zero bytes to the end of the last dimension.
    padded = torch.nn.functional.pad(x_fp4, (0, weights_padding_bytes))
    return padded.contiguous()
def slice_nvfp4_output(
    out: torch.Tensor,
    output_size: int,
) -> torch.Tensor:
    """
    Drop N-dimension padding from a GEMM output.

    If the last dimension already equals ``output_size`` the tensor is
    returned unchanged. Otherwise a contiguous copy of the first
    ``output_size`` columns is returned.
    """
    if out.shape[-1] == output_size:
        return out
    trimmed = out[..., :output_size]
    return trimmed.contiguous()
vllm/model_executor/layers/quantization/utils/quant_utils.py
View file @
67ebaff5
...
@@ -11,7 +11,6 @@ import numpy
...
@@ -11,7 +11,6 @@ import numpy
import
torch
import
torch
from
torch
import
fx
from
torch
import
fx
from
vllm._custom_ops
import
cutlass_scaled_mm_supports_fp4
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.scalar_type
import
ScalarType
,
scalar_types
...
@@ -768,60 +767,6 @@ def awq_pack(
...
@@ -768,60 +767,6 @@ def awq_pack(
return
pack_cols
(
q_w
,
num_bits
,
size_k
,
size_n
)
return
pack_cols
(
q_w
,
num_bits
,
size_k
,
size_n
)
def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor:
    """
    Pad and block-interleave FP4 block scales into the layout expected by
    the CUTLASS / FlashInfer kernels.

    Parameters
    ----------
    scale: torch.Tensor
        2-D ``(M, K)`` or 3-D ``(B, M, K)`` tensor of dtype
        ``torch.float8_e4m3fn``.

    Returns
    -------
    torch.Tensor
        The swizzled scales, moved to the CUDA device, with M rounded up to
        a multiple of 128 and K to a multiple of 4. The result has the same
        number of dimensions as *scale*.
    """
    assert scale.dtype == torch.float8_e4m3fn, (
        "swizzle_blockscale expects the input tensor to be in "
        "torch.float8_e4m3fn format."
    )

    was_2d = scale.ndim == 2
    if was_2d:
        scale = scale.unsqueeze(0)  # (1, M, K)
    assert scale.ndim == 3, "Expected a 2-D or 3-D tensor for block scales."

    batch, rows, cols = scale.shape
    # Round each dimension up to its tile multiple.
    rows_padded = (rows + 128 - 1) // 128 * 128
    cols_padded = (cols + 4 - 1) // 4 * 4

    # Zero-filled buffer of the padded size with the real scales copied in.
    buf = torch.zeros(
        (batch, rows_padded, cols_padded), dtype=scale.dtype, device=scale.device
    )
    buf[:batch, :rows, :cols] = scale

    # Reshape / permute to the layout required by the kernel.
    tiles = buf.reshape(batch, rows_padded // 128, 4, 32, cols_padded // 4, 4)
    swizzled = tiles.permute(0, 1, 4, 3, 2, 5).contiguous().cuda()

    out_shape = (
        (rows_padded, cols_padded) if was_2d else (batch, rows_padded, cols_padded)
    )
    return swizzled.reshape(*out_shape)
def cutlass_fp4_supported() -> bool:
    """
    Report whether the CUTLASS FP4 scaled-mm kernel can be used here.

    Returns ``False`` on non-CUDA platforms. Otherwise defers to
    ``cutlass_scaled_mm_supports_fp4`` with the device capability as an int,
    or ``-1`` when the capability is unknown.
    """
    if not current_platform.is_cuda():
        return False

    device_capability = current_platform.get_device_capability()
    if device_capability is None:
        capability_int = -1
    else:
        capability_int = device_capability.to_int()
    return cutlass_scaled_mm_supports_fp4(capability_int)
def
convert_bf16_scales_to_fp8
(
def
convert_bf16_scales_to_fp8
(
quant_fp8
:
Callable
,
scales
:
torch
.
Tensor
quant_fp8
:
Callable
,
scales
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
@@ -868,70 +813,3 @@ def convert_packed_uint4b8_to_signed_int4_inplace(t: torch.Tensor) -> torch.Tens
...
@@ -868,70 +813,3 @@ def convert_packed_uint4b8_to_signed_int4_inplace(t: torch.Tensor) -> torch.Tens
t
|=
((
nib
-
8
)
&
0xF
)
<<
shift
t
|=
((
nib
-
8
)
&
0xF
)
<<
shift
return
t
return
t
def round_up(x: int, m: int) -> int:
    """Return the smallest multiple of m that is >= x (for positive m)."""
    # Bump x past the next boundary, then drop the remainder.
    bumped = x + m - 1
    return bumped - bumped % m
def pad_nvfp4_weight_for_cutlass(
    weight: torch.Tensor,
    alignment: int = 32,
) -> tuple[torch.Tensor, int]:
    """
    Pad packed NVFP4 weights so N (rows) and K (columns) meet the alignment
    required by CUTLASS / FlashInfer FP4 kernels.

    CUTLASS FP4 kernel requires both K and N matrix dimensions to be divisible
    by 32 for aligned memory access and efficient tensor core operations.

    Returns:
        ``(weight, pad_bytes)`` where ``pad_bytes`` is the number of bytes
        appended to each row (0 when K was already aligned).
    """
    pad = torch.nn.functional.pad

    # N dimension: append rows until the row count is aligned.
    n_rows = weight.shape[0]
    if n_rows % alignment != 0:
        extra_rows = round_up(n_rows, alignment) - n_rows
        weight = pad(weight, (0, 0, 0, extra_rows)).contiguous()

    # K dimension: two FP4 values are packed per byte, so the element count
    # is twice the stored column width.
    k_elements = weight.shape[1] * 2
    if k_elements % alignment == 0:
        return weight, 0

    extra_elements = round_up(k_elements, alignment) - k_elements
    # Convert from FP4 element count to bytes (2 FP4 values per byte)
    # extra_elements is always even since alignment=32 and k_elements is even
    extra_bytes = extra_elements // 2
    weight = pad(weight, (0, extra_bytes, 0, 0)).contiguous()
    return weight, extra_bytes
def pad_nvfp4_activation_for_cutlass(
    x_fp4: torch.Tensor,
    weights_padding_bytes: int,
) -> torch.Tensor:
    """
    Pad packed FP4 activations to match the K-dimension padding of the weights.

    ``weights_padding_bytes`` counts bytes along the last tensor dimension,
    not FP4 elements. When it is not positive the input is returned as-is.
    """
    if weights_padding_bytes <= 0:
        return x_fp4
    # Append zero bytes to the end of the last dimension.
    padded = torch.nn.functional.pad(x_fp4, (0, weights_padding_bytes))
    return padded.contiguous()
def slice_nvfp4_output(
    out: torch.Tensor,
    output_size: int,
) -> torch.Tensor:
    """
    Drop N-dimension padding from a GEMM output.

    If the last dimension already equals ``output_size`` the tensor is
    returned unchanged. Otherwise a contiguous copy of the first
    ``output_size`` columns is returned.
    """
    if out.shape[-1] == output_size:
        return out
    trimmed = out[..., :output_size]
    return trimmed.contiguous()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment