Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d2389c12
Unverified
Commit
d2389c12
authored
Jan 20, 2026
by
Vasiliy Kuznetsov
Committed by
GitHub
Jan 20, 2026
Browse files
fp8 online quant: split out Fp8OnlineLinearMethod (#32189)
parent
22375f8d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
143 additions
and
110 deletions
+143
-110
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+4
-1
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+139
-109
No files found.
tests/quantization/test_fp8.py
View file @
d2389c12
...
...
@@ -133,7 +133,7 @@ def test_kv_cache_model_load_and_run(
@
pytest
.
mark
.
parametrize
(
"use_rocm_aiter"
,
[
True
,
False
]
if
current_platform
.
is_rocm
()
else
[
False
]
)
def
test_
load_fp16_model
(
def
test_
online_quantization
(
vllm_runner
,
kv_cache_dtype
:
str
,
force_marlin
:
bool
,
...
...
@@ -191,6 +191,9 @@ def test_load_fp16_model(
llm
.
apply_model
(
check_model
)
outputs
=
llm
.
generate_greedy
([
"Hello my name is"
],
max_tokens
=
4
)
print
(
outputs
[
0
][
1
])
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
d2389c12
...
...
@@ -230,9 +230,14 @@ class Fp8Config(QuantizationConfig):
fused_mapping
=
self
.
packed_modules_mapping
,
):
return
UnquantizedLinearMethod
()
quant_method
=
Fp8LinearMethod
(
self
)
quant_method
.
marlin_input_dtype
=
get_marlin_input_dtype
(
prefix
)
return
quant_method
if
not
self
.
is_checkpoint_fp8_serialized
:
online_method
=
Fp8OnlineLinearMethod
(
self
)
online_method
.
marlin_input_dtype
=
get_marlin_input_dtype
(
prefix
)
return
online_method
else
:
offline_method
=
Fp8LinearMethod
(
self
)
offline_method
.
marlin_input_dtype
=
get_marlin_input_dtype
(
prefix
)
return
offline_method
elif
isinstance
(
layer
,
FusedMoE
):
if
is_layer_skipped
(
prefix
=
prefix
,
...
...
@@ -295,13 +300,8 @@ class Fp8LinearMethod(LinearMethodBase):
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Limitations:
1. Only support per-tensor quantization due to torch._scaled_mm support.
2. Only support float8_e4m3fn data type due to the limitation of
1. Only support float8_e4m3fn data type due to the limitation of
torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
Args:
...
...
@@ -388,54 +388,11 @@ class Fp8LinearMethod(LinearMethodBase):
self
.
weight_block_size
,
)
# WEIGHT
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
weight
=
create_fp8_weight_parameter
(
output_size_per_partition
,
input_size_per_partition
,
weight_loader
)
else
:
def
patched_weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
):
# track how many elements we have updated
if
not
hasattr
(
layer
,
"_loaded_numel"
):
layer
.
_loaded_numel
=
0
# load the current weight chunk
copy_numel_counter
=
CopyNumelCounter
()
with
copy_numel_counter
:
res
=
weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
)
# type: ignore[misc]
layer
.
_loaded_numel
+=
copy_numel_counter
.
copied_numel
# if we have loaded all of the elements, call
# process_weights_after_loading
target_loaded_numel
=
layer
.
weight
.
numel
()
if
layer
.
_loaded_numel
==
target_loaded_numel
:
self
.
process_weights_after_loading
(
layer
)
# Delete the bookkeeping
del
layer
.
_loaded_numel
# Prevent the usual `process_weights_after_loading` call from doing
# anything
layer
.
_already_called_process_weights_after_loading
=
True
return
res
# For non-serialized checkpoints, use original dtype
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
params_dtype
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
patched_weight_loader
,
)
layer
.
register_parameter
(
"weight"
,
weight
)
# If checkpoint is serialized fp8, load them.
# Otherwise, wait until process_weights_after_loading.
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# WEIGHT SCALE
if
not
self
.
block_quant
:
scale
=
create_fp8_scale_parameter
(
...
...
@@ -468,9 +425,6 @@ class Fp8LinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"input_scale"
,
scale
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
return
size_k_first
=
True
input_scale
=
None
# TODO(rob): refactor block quant into separate class.
...
...
@@ -488,27 +442,20 @@ class Fp8LinearMethod(LinearMethodBase):
# If checkpoint not serialized fp8, quantize the weights.
else
:
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
scale
=
None
)
weight
=
qweight
.
t
()
# If checkpoint is fp8 per-tensor, handle that there are N scales for N
# shards in a fused module
else
:
weight
=
layer
.
weight
weight_scale
=
layer
.
weight_scale
# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
if
not
self
.
use_marlin
:
weight
,
weight_scale
,
input_scale
=
(
process_fp8_weight_tensor_strategy
(
weight
,
weight_scale
,
input_scale
=
process_fp8_weight_tensor_strategy
(
weight
,
weight_scale
,
layer
.
logical_widths
,
getattr
(
layer
,
"input_scale"
,
None
),
)
)
if
self
.
act_q_static
:
assert
input_scale
is
not
None
input_scale
=
input_scale
.
max
()
...
...
@@ -607,6 +554,89 @@ class Fp8LinearMethod(LinearMethodBase):
return
self
.
fp8_linear
.
apply_weights
(
layer
,
x
,
bias
)
class Fp8OnlineLinearMethod(Fp8LinearMethod):
    """Online variant of Fp8LinearMethod.

    Loads an fp16/bf16 checkpoint and quantizes the weights while they are
    being loaded, rather than expecting FP8-serialized weights.
    """

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register an unquantized ``weight`` parameter on ``layer``.

        The parameter is allocated with ``params_dtype`` and receives a
        wrapped weight loader that counts the elements copied into it. Once
        the count equals ``layer.weight.numel()``, the wrapper calls
        ``process_weights_after_loading`` itself.

        Args:
            layer: Module that receives the bookkeeping attributes and the
                ``weight`` parameter.
            input_size_per_partition: Second dimension of the weight.
            output_partition_sizes: Logical output widths; their sum is the
                first dimension of the weight.
            input_size: Not used by this method.
            output_size: Not used by this method.
            params_dtype: Dtype of the allocated weight.
            **extra_weight_attrs: Only ``"weight_loader"`` is read here.
        """
        total_output_size = sum(output_partition_sizes)
        base_loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_output_size
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # WEIGHT
        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
            # Lazily create the running count of loaded elements.
            if not hasattr(layer, "_loaded_numel"):
                layer._loaded_numel = 0

            # Load this chunk while counting how many elements get copied.
            numel_tracker = CopyNumelCounter()
            with numel_tracker:
                result = base_loader(  # type: ignore[misc]
                    param, loaded_weight, *args, **kwargs
                )
            layer._loaded_numel += numel_tracker.copied_numel

            # Still waiting on more chunks: nothing else to do yet.
            expected_numel = layer.weight.numel()
            if layer._loaded_numel != expected_numel:
                return result

            # Every element has arrived, so post-process right away.
            self.process_weights_after_loading(layer)

            # Drop the bookkeeping attribute.
            del layer._loaded_numel
            # Turn the usual `process_weights_after_loading` call into a
            # no-op.
            layer._already_called_process_weights_after_loading = True
            return result

        weight_storage = torch.empty(
            total_output_size,
            input_size_per_partition,
            dtype=params_dtype,
        )
        weight = ModelWeightParameter(
            data=weight_storage,
            input_dim=1,
            output_dim=0,
            weight_loader=patched_weight_loader,
        )
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize ``layer.weight`` and store the result on ``layer``.

        Returns immediately when ``layer`` carries a truthy
        ``_already_called_process_weights_after_loading`` attribute.
        Otherwise the weight goes through ``ops.scaled_fp8_quant`` with
        ``scale=None``; the transposed result replaces ``weight`` and the
        returned scale is stored as ``weight_scale``. When ``use_marlin`` is
        set, the layer is additionally passed to
        ``prepare_fp8_layer_for_marlin``.
        """
        already_processed = getattr(
            layer, "_already_called_process_weights_after_loading", False
        )
        if already_processed:
            return

        # TODO(future): support block_quant in online quant path
        assert not self.block_quant

        layer.input_scale = None
        quantized, scale = ops.scaled_fp8_quant(layer.weight, scale=None)
        transposed = quantized.t()

        # Swap the freshly computed tensors into the layer.
        replace_parameter(layer, "weight", transposed.data)
        replace_parameter(layer, "weight_scale", scale.data)

        if not self.use_marlin:
            return

        # Second positional argument is `size_k_first`.
        prepare_fp8_layer_for_marlin(
            layer, True, input_dtype=self.marlin_input_dtype
        )
        # Activations not quantized for marlin.
class
Fp8MoEMethod
(
FusedMoEMethodBase
):
"""MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment