Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d2389c12
Unverified
Commit
d2389c12
authored
Jan 20, 2026
by
Vasiliy Kuznetsov
Committed by
GitHub
Jan 20, 2026
Browse files
fp8 online quant: split out Fp8OnlineLinearMethod (#32189)
parent
22375f8d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
143 additions
and
110 deletions
+143
-110
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+4
-1
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+139
-109
No files found.
tests/quantization/test_fp8.py
View file @
d2389c12
...
@@ -133,7 +133,7 @@ def test_kv_cache_model_load_and_run(
...
@@ -133,7 +133,7 @@ def test_kv_cache_model_load_and_run(
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"use_rocm_aiter"
,
[
True
,
False
]
if
current_platform
.
is_rocm
()
else
[
False
]
"use_rocm_aiter"
,
[
True
,
False
]
if
current_platform
.
is_rocm
()
else
[
False
]
)
)
def
test_
load_fp16_model
(
def
test_
online_quantization
(
vllm_runner
,
vllm_runner
,
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
force_marlin
:
bool
,
force_marlin
:
bool
,
...
@@ -191,6 +191,9 @@ def test_load_fp16_model(
...
@@ -191,6 +191,9 @@ def test_load_fp16_model(
llm
.
apply_model
(
check_model
)
llm
.
apply_model
(
check_model
)
outputs
=
llm
.
generate_greedy
([
"Hello my name is"
],
max_tokens
=
4
)
print
(
outputs
[
0
][
1
])
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
not
is_quant_method_supported
(
"fp8"
),
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
d2389c12
...
@@ -230,9 +230,14 @@ class Fp8Config(QuantizationConfig):
...
@@ -230,9 +230,14 @@ class Fp8Config(QuantizationConfig):
fused_mapping
=
self
.
packed_modules_mapping
,
fused_mapping
=
self
.
packed_modules_mapping
,
):
):
return
UnquantizedLinearMethod
()
return
UnquantizedLinearMethod
()
quant_method
=
Fp8LinearMethod
(
self
)
if
not
self
.
is_checkpoint_fp8_serialized
:
quant_method
.
marlin_input_dtype
=
get_marlin_input_dtype
(
prefix
)
online_method
=
Fp8OnlineLinearMethod
(
self
)
return
quant_method
online_method
.
marlin_input_dtype
=
get_marlin_input_dtype
(
prefix
)
return
online_method
else
:
offline_method
=
Fp8LinearMethod
(
self
)
offline_method
.
marlin_input_dtype
=
get_marlin_input_dtype
(
prefix
)
return
offline_method
elif
isinstance
(
layer
,
FusedMoE
):
elif
isinstance
(
layer
,
FusedMoE
):
if
is_layer_skipped
(
if
is_layer_skipped
(
prefix
=
prefix
,
prefix
=
prefix
,
...
@@ -295,13 +300,8 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -295,13 +300,8 @@ class Fp8LinearMethod(LinearMethodBase):
Supports loading FP8 checkpoints with static weight scale and
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Limitations:
Limitations:
1. Only support per-tensor quantization due to torch._scaled_mm support.
1. Only support float8_e4m3fn data type due to the limitation of
2. Only support float8_e4m3fn data type due to the limitation of
torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
Args:
Args:
...
@@ -388,54 +388,11 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -388,54 +388,11 @@ class Fp8LinearMethod(LinearMethodBase):
self
.
weight_block_size
,
self
.
weight_block_size
,
)
)
# WEIGHT
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
weight
=
create_fp8_weight_parameter
(
weight
=
create_fp8_weight_parameter
(
output_size_per_partition
,
input_size_per_partition
,
weight_loader
output_size_per_partition
,
input_size_per_partition
,
weight_loader
)
)
else
:
def
patched_weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
):
# track how many elements we have updated
if
not
hasattr
(
layer
,
"_loaded_numel"
):
layer
.
_loaded_numel
=
0
# load the current weight chunk
copy_numel_counter
=
CopyNumelCounter
()
with
copy_numel_counter
:
res
=
weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
)
# type: ignore[misc]
layer
.
_loaded_numel
+=
copy_numel_counter
.
copied_numel
# if we have loaded all of the elements, call
# process_weights_after_loading
target_loaded_numel
=
layer
.
weight
.
numel
()
if
layer
.
_loaded_numel
==
target_loaded_numel
:
self
.
process_weights_after_loading
(
layer
)
# Delete the bookkeeping
del
layer
.
_loaded_numel
# Prevent the usual `process_weights_after_loading` call from doing
# anything
layer
.
_already_called_process_weights_after_loading
=
True
return
res
# For non-serialized checkpoints, use original dtype
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
params_dtype
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
patched_weight_loader
,
)
layer
.
register_parameter
(
"weight"
,
weight
)
layer
.
register_parameter
(
"weight"
,
weight
)
# If checkpoint is serialized fp8, load them.
# Otherwise, wait until process_weights_after_loading.
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# WEIGHT SCALE
# WEIGHT SCALE
if
not
self
.
block_quant
:
if
not
self
.
block_quant
:
scale
=
create_fp8_scale_parameter
(
scale
=
create_fp8_scale_parameter
(
...
@@ -468,9 +425,6 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -468,9 +425,6 @@ class Fp8LinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"input_scale"
,
scale
)
layer
.
register_parameter
(
"input_scale"
,
scale
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
return
size_k_first
=
True
size_k_first
=
True
input_scale
=
None
input_scale
=
None
# TODO(rob): refactor block quant into separate class.
# TODO(rob): refactor block quant into separate class.
...
@@ -488,27 +442,20 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -488,27 +442,20 @@ class Fp8LinearMethod(LinearMethodBase):
# If checkpoint not serialized fp8, quantize the weights.
# If checkpoint not serialized fp8, quantize the weights.
else
:
else
:
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
scale
=
None
)
weight
=
qweight
.
t
()
# If checkpoint is fp8 per-tensor, handle that there are N scales for N
# If checkpoint is fp8 per-tensor, handle that there are N scales for N
# shards in a fused module
# shards in a fused module
else
:
weight
=
layer
.
weight
weight
=
layer
.
weight
weight_scale
=
layer
.
weight_scale
weight_scale
=
layer
.
weight_scale
# If using w8a8, torch._scaled_mm needs per tensor, so
# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
# requantize the logical shards as a single weight.
if
not
self
.
use_marlin
:
if
not
self
.
use_marlin
:
weight
,
weight_scale
,
input_scale
=
(
weight
,
weight_scale
,
input_scale
=
process_fp8_weight_tensor_strategy
(
process_fp8_weight_tensor_strategy
(
weight
,
weight
,
weight_scale
,
weight_scale
,
layer
.
logical_widths
,
layer
.
logical_widths
,
getattr
(
layer
,
"input_scale"
,
None
),
getattr
(
layer
,
"input_scale"
,
None
),
)
)
)
if
self
.
act_q_static
:
if
self
.
act_q_static
:
assert
input_scale
is
not
None
assert
input_scale
is
not
None
input_scale
=
input_scale
.
max
()
input_scale
=
input_scale
.
max
()
...
@@ -607,6 +554,89 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -607,6 +554,89 @@ class Fp8LinearMethod(LinearMethodBase):
return
self
.
fp8_linear
.
apply_weights
(
layer
,
x
,
bias
)
return
self
.
fp8_linear
.
apply_weights
(
layer
,
x
,
bias
)
class Fp8OnlineLinearMethod(Fp8LinearMethod):
    """Online version of Fp8LinearMethod.

    Loads the fp16/bf16 checkpoint and quantizes the weights during loading.

    The weight parameter is created in the checkpoint dtype (`params_dtype`)
    and its loader is wrapped so that, once every element of the weight has
    been copied in, `process_weights_after_loading` is invoked immediately
    for that layer instead of waiting for the usual post-load call.
    """

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register an unquantized `weight` parameter with a tracking loader.

        Args:
            layer: Module that receives the `weight` parameter and the
                bookkeeping attributes set below.
            input_size_per_partition: Second dimension of the weight tensor.
            output_partition_sizes: Sizes of the logical output shards; their
                sum is the first dimension of the weight tensor.
            input_size: Not used by this method.
            output_size: Not used by this method.
            params_dtype: dtype used to allocate the weight and recorded on
                the layer as `orig_dtype`.
            **extra_weight_attrs: Only the `"weight_loader"` entry is read
                here; it is looked up with `.get`, so it may be `None` if the
                caller did not supply one (calling the patched loader would
                then fail).
        """
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        # Block quantization is not supported on the online path (see the
        # assert in process_weights_after_loading), so no block size is set.
        layer.weight_block_size = None

        # WEIGHT
        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            """Delegate to the original loader and quantize once fully loaded.

            Returns whatever the wrapped `weight_loader` returns.
            """
            # track how many elements we have updated
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

            # load the current weight chunk
            # NOTE(review): CopyNumelCounter is defined outside this view;
            # presumably it records the number of elements copied while the
            # context is active and exposes it as `copied_numel` - confirm.
            copy_numel_counter = CopyNumelCounter()
            with copy_numel_counter:
                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
            layer._loaded_numel += copy_numel_counter.copied_numel

            # if we have loaded all of the elements, call
            # process_weights_after_loading
            target_loaded_numel = layer.weight.numel()
            if layer._loaded_numel == target_loaded_numel:
                self.process_weights_after_loading(layer)

                # Delete the bookkeeping
                del layer._loaded_numel
                # Prevent the usual `process_weights_after_loading` call from doing
                # anything
                layer._already_called_process_weights_after_loading = True

            return res

        # Allocate the weight in the checkpoint dtype; it is replaced by the
        # quantized tensor in process_weights_after_loading.
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=patched_weight_loader,
        )
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize `layer.weight` and install the result on the layer.

        Does nothing if the layer is already flagged with
        `_already_called_process_weights_after_loading` (set by the patched
        weight loader in `create_weights`). Otherwise it sets
        `layer.input_scale` to `None`, replaces `weight` with the transposed
        quantized weight, installs `weight_scale`, and, when `self.use_marlin`
        is set, runs the Marlin layer preparation.

        Raises:
            AssertionError: If `self.block_quant` is truthy.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        # TODO(future): support block_quant in online quant path
        assert not self.block_quant

        layer.input_scale = None
        # NOTE(review): with scale=None the op returns a scale alongside the
        # quantized weight; presumably it is computed from the weight itself -
        # confirm against ops.scaled_fp8_quant.
        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
        weight = qweight.t()

        # Update layer with new values.
        replace_parameter(layer, "weight", weight.data)
        replace_parameter(layer, "weight_scale", weight_scale.data)

        if self.use_marlin:
            size_k_first = True
            prepare_fp8_layer_for_marlin(
                layer, size_k_first, input_dtype=self.marlin_input_dtype
            )
            # Activations not quantized for marlin.
class
Fp8MoEMethod
(
FusedMoEMethodBase
):
class
Fp8MoEMethod
(
FusedMoEMethodBase
):
"""MoE method for FP8.
"""MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and
Supports loading FP8 checkpoints with static weight scale and
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment