Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6bc7b573
Unverified
Commit
6bc7b573
authored
Jun 16, 2025
by
Dipika Sikka
Committed by
GitHub
Jun 16, 2025
Browse files
[Quantization] Remove FP4 emulation; Fall-back to marlin for device < 100 (#19563)
parent
90f9c2eb
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
79 additions
and
60 deletions
+79
-60
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+7
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+8
-1
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
...pressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
+13
-1
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
...mpressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
+22
-57
vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
...ecutor/layers/quantization/utils/nvfp4_emulation_utils.py
+29
-0
No files found.
tests/quantization/test_compressed_tensors.py
View file @
6bc7b573
...
...
@@ -667,7 +667,13 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
scheme
)
if
isinstance
(
qkv_proj
.
scheme
,
scheme
)
or
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW4A16Fp4
)
and
not
CompressedTensorsW4A4Fp4
.
cutlass_fp4_supported
():
assert
True
else
:
raise
AssertionError
(
"FP4 Scheme Mismatch"
)
assert
qkv_proj
.
scheme
.
group_size
==
16
llm
.
apply_model
(
check_model
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
6bc7b573
...
...
@@ -374,7 +374,14 @@ class CompressedTensorsConfig(QuantizationConfig):
if
is_activation_quantization_format
(
self
.
quant_format
):
if
self
.
_is_fp4a4_nvfp4
(
weight_quant
,
input_quant
):
return
CompressedTensorsW4A4Fp4
()
if
CompressedTensorsW4A4Fp4
.
cutlass_fp4_supported
():
return
CompressedTensorsW4A4Fp4
()
else
:
logger
.
warning_once
(
"Current platform does not support cutlass NVFP4."
" Running CompressedTensorsW4A16Fp4."
)
return
CompressedTensorsW4A16Fp4
(
has_input_global_scale
=
True
)
if
self
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
is_fp8_w8a8_supported
=
self
.
_check_scheme_supported
(
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
View file @
6bc7b573
...
...
@@ -18,7 +18,8 @@ __all__ = ["CompressedTensorsW4A16Fp4"]
class
CompressedTensorsW4A16Fp4
(
CompressedTensorsScheme
):
def
__init__
(
self
):
def
__init__
(
self
,
has_input_global_scale
:
bool
=
False
):
self
.
has_input_global_scale
=
has_input_global_scale
self
.
group_size
=
16
@
classmethod
...
...
@@ -64,6 +65,13 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
if
self
.
has_input_global_scale
:
input_global_scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"input_global_scale"
,
input_global_scale
)
def
process_weights_after_loading
(
self
,
layer
)
->
None
:
# Process parameters for marlin repacking
...
...
@@ -77,6 +85,10 @@ class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
requires_grad
=
False
)
del
layer
.
weight_global_scale
if
self
.
has_input_global_scale
:
layer
.
input_global_scale
=
torch
.
nn
.
Parameter
(
layer
.
input_global_scale
.
data
,
requires_grad
=
False
)
prepare_fp4_layer_for_marlin
(
layer
)
def
apply_weights
(
self
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
View file @
6bc7b573
...
...
@@ -9,8 +9,6 @@ from vllm._custom_ops import (cutlass_scaled_fp4_mm,
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils
import
(
# noqa: E501
dequantize_to_dtype
,
ref_nvfp4_quant
)
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
)
...
...
@@ -21,53 +19,23 @@ logger = init_logger(__name__)
__all__
=
[
"CompressedTensorsW4A4Fp4"
]
def cutlass_fp4_supported() -> bool:
    """Report whether the cutlass FP4 scaled-mm kernel can be used.

    Returns:
        False when the current platform is not CUDA. Otherwise the result
        of ``cutlass_scaled_mm_supports_fp4`` for the device capability,
        where -1 is passed if ``get_device_capability()`` returned None.
    """
    if current_platform.is_cuda():
        device_capability = current_platform.get_device_capability()
        if device_capability is None:
            # No capability reported: query the kernel check with -1.
            capability_code = -1
        else:
            capability_code = device_capability.to_int()
        return cutlass_scaled_mm_supports_fp4(capability_code)
    return False
class
CompressedTensorsW4A4Fp4
(
CompressedTensorsScheme
):
def
__init__
(
self
):
self
.
group_size
=
16
self
.
cutlass_nvfp4_supported
=
cutlass_fp4_supported
()
if
not
self
.
cutlass_nvfp4_supported
:
logger
.
warning
(
"Current platform does not support cutlass NVFP4."
" Running emulations."
)
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
# dont restrict as emulations
return
80
def
run_nvfp4_emulations
(
self
,
x
:
torch
.
Tensor
,
layer
):
x_m
,
x_k
=
x
.
shape
output_dtype
=
x
.
dtype
# quantize input to (FP4 and interleaved block scale)
x_fp4
,
x_blockscale
=
ref_nvfp4_quant
(
x
,
layer
.
input_global_scale
,
self
.
group_size
)
return
100
# dequantize input
x_fp4
=
x_fp4
.
reshape
(
x_m
,
x_k
//
self
.
group_size
,
self
.
group_size
)
x_blockscale
=
x_blockscale
.
unsqueeze
(
-
1
)
/
layer
.
input_global_scale
x_dq
=
(
x_fp4
*
x_blockscale
).
reshape
(
x_m
,
x_k
).
to
(
output_dtype
)
del
x_fp4
,
x_blockscale
# dequantize weight
w_fp4
=
layer
.
weight
.
data
.
view
(
torch
.
uint8
)
w_blockscale
=
layer
.
weight_scale_swizzled
.
data
w_global_scale
=
layer
.
weight_global_scale
w_dq
=
dequantize_to_dtype
(
w_fp4
,
w_blockscale
,
w_global_scale
,
output_dtype
,
x
.
device
,
self
.
group_size
)
# matmul
out
=
torch
.
matmul
(
x_dq
,
w_dq
.
t
())
del
w_dq
,
x_dq
return
out
@
classmethod
def cutlass_fp4_supported(cls) -> bool:
    """Tell whether the cutlass NVFP4 kernel is supported here.

    Returns:
        False on non-CUDA platforms; otherwise whatever
        ``cutlass_scaled_mm_supports_fp4`` answers for the integer device
        capability (-1 stands in when the capability is None).
    """
    on_cuda = current_platform.is_cuda()
    if not on_cuda:
        return False

    reported = current_platform.get_device_capability()
    # A missing capability is mapped to -1 before asking the kernel check.
    as_int = reported.to_int() if reported is not None else -1
    return cutlass_scaled_mm_supports_fp4(as_int)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
output_partition_sizes
:
list
[
int
],
...
...
@@ -152,27 +120,24 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
# required by cutlass kernel; need Parameter, not ModelWeightParameter
layer
.
weight
=
Parameter
(
layer
.
weight_packed
.
data
,
requires_grad
=
False
)
if
self
.
cutlass_nvfp4_supported
:
layer
.
alpha
=
Parameter
(
layer
.
input_global_scale
*
layer
.
weight_global_scale
,
requires_grad
=
False
)
layer
.
alpha
=
Parameter
(
layer
.
input_global_scale
*
layer
.
weight_global_scale
,
requires_grad
=
False
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
if
self
.
cutlass_nvfp4_supported
:
output_dtype
=
x
.
dtype
output_shape
=
[
x
.
shape
[
0
],
layer
.
weight
.
shape
[
0
]]
output_dtype
=
x
.
dtype
output_shape
=
[
x
.
shape
[
0
],
layer
.
weight
.
shape
[
0
]]
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4
,
x_blockscale
=
scaled_fp4_quant
(
x
,
layer
.
input_global_scale
)
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4
,
x_blockscale
=
scaled_fp4_quant
(
x
,
layer
.
input_global_scale
)
out
=
cutlass_scaled_fp4_mm
(
x_fp4
,
layer
.
weight
,
x_blockscale
,
layer
.
weight_scale_swizzled
,
1
/
layer
.
alpha
,
output_dtype
)
if
bias
is
not
None
:
out
=
out
+
bias
return
out
.
view
(
*
output_shape
)
return
self
.
run_nvfp4_emulations
(
x
,
layer
)
out
=
cutlass_scaled_fp4_mm
(
x_fp4
,
layer
.
weight
,
x_blockscale
,
layer
.
weight_scale_swizzled
,
1
/
layer
.
alpha
,
output_dtype
)
if
bias
is
not
None
:
out
=
out
+
bias
return
out
.
view
(
*
output_shape
)
vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py
View file @
6bc7b573
...
...
@@ -102,3 +102,32 @@ def ref_nvfp4_quant(x, global_scale, block_size):
clipped_x
=
torch
.
clamp
(
scaled_x
,
-
6.0
,
6.0
).
reshape
(
m
,
n
)
# both outputs are float32
return
cast_to_fp4
(
clipped_x
),
scale
.
squeeze
(
-
1
)
def run_nvfp4_emulations(x: torch.Tensor, input_global_scale: torch.Tensor,
                         weight: torch.Tensor,
                         weight_scale_swizzled: torch.Tensor,
                         weight_global_scale: torch.Tensor):
    """Emulate an NVFP4 matmul by quantizing, dequantizing and multiplying.

    The activations are passed through ``ref_nvfp4_quant`` and rescaled
    back, the packed weight is expanded with ``dequantize_to_dtype``, and
    the two dequantized operands are multiplied with ``torch.matmul``.

    Args:
        x: Activations; unpacked as exactly two dimensions ``(m, k)``.
        input_global_scale: Scale handed to ``ref_nvfp4_quant`` and then
            divided back out of the per-block scales.
        weight: Packed weight; its ``.data`` is reinterpreted as uint8.
        weight_scale_swizzled: Block scales for the weight; ``.data`` is
            forwarded to ``dequantize_to_dtype``.
        weight_global_scale: Global weight scale forwarded to
            ``dequantize_to_dtype``.

    Returns:
        ``matmul(x_dq, w_dq.t())`` where ``x_dq`` has the dtype of ``x``.
        NOTE(review): ``w_dq`` is presumably produced in the same dtype,
        since ``x.dtype`` is what is passed to ``dequantize_to_dtype`` —
        confirm against that helper.
    """
    block = 16
    rows, cols = x.shape
    target_dtype = x.dtype

    # Activations: quantize to FP4 values plus one scale per block.
    act_fp4, act_scale = ref_nvfp4_quant(x, input_global_scale, block)

    # Activations: undo the quantization block by block.
    act_fp4 = act_fp4.reshape(rows, cols // block, block)
    act_scale = act_scale.unsqueeze(-1) / input_global_scale
    act_dq = act_fp4 * act_scale
    act_dq = act_dq.reshape(rows, cols).to(target_dtype)
    # Drop the intermediates as soon as they are no longer needed.
    del act_fp4, act_scale

    # Weight: reinterpret the packed storage as bytes and dequantize it.
    packed_weight = weight.data.view(torch.uint8)
    weight_dq = dequantize_to_dtype(
        packed_weight,
        weight_scale_swizzled.data,
        weight_global_scale,
        target_dtype,
        x.device,
        block,
    )

    # Multiply the dequantized activations by the transposed weight.
    result = torch.matmul(act_dq, weight_dq.t())
    del weight_dq, act_dq
    return result
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment