Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0d402d26
Unverified
Commit
0d402d26
authored
Dec 08, 2025
by
Vasiliy Kuznetsov
Committed by
GitHub
Dec 08, 2025
Browse files
online fp8 quant with streaming weight post-processing (#29196)
Signed-off-by:
vasiliy
<
vasiliy@fb.com
>
parent
d1b5e7af
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
66 additions
and
1 deletion
+66
-1
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+66
-1
No files found.
vllm/model_executor/layers/quantization/fp8.py
View file @
0d402d26
...
...
@@ -465,6 +465,30 @@ class Fp8LinearMethod(LinearMethodBase):
output_size_per_partition
,
input_size_per_partition
,
weight_loader
)
else
:
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
    """Load one weight chunk and post-process the layer once it is full.

    Wraps the captured ``weight_loader``: the chunk is loaded as usual,
    then a running element count is kept on ``layer._loaded_numel``. When
    that count equals ``layer.weight.numel()``,
    ``self.process_weights_after_loading(layer)`` is invoked right away
    instead of waiting for the regular post-load pass.

    Args:
        param: Parameter handed through to the wrapped ``weight_loader``.
        loaded_weight: The chunk being loaded; its ``numel()`` is added to
            the running count.
        *args: Forwarded unchanged to the wrapped ``weight_loader``.
        **kwargs: Forwarded unchanged to the wrapped ``weight_loader``.

    Returns:
        Whatever the wrapped ``weight_loader`` returned.
    """
    # Delegate the actual copy of this chunk to the original loader.
    result = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]

    # Accumulate the number of elements seen so far; the counter lives on
    # the layer and starts from zero when the first chunk arrives.
    layer._loaded_numel = getattr(layer, "_loaded_numel", 0) + loaded_weight.numel()

    # Nothing more to do until every element of the weight has been loaded.
    expected_numel = layer.weight.numel()
    if layer._loaded_numel != expected_numel:
        return result

    # All elements are present: run the post-processing step immediately.
    self.process_weights_after_loading(layer)

    # The counter is no longer needed once processing has happened.
    del layer._loaded_numel
    # Flag the layer so that the usual `process_weights_after_loading` call
    # does not do anything.
    layer._already_called_process_weights_after_loading = True

    return result
# For non-serialized checkpoints, use original dtype
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
...
...
@@ -474,7 +498,7 @@ class Fp8LinearMethod(LinearMethodBase):
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
weight_loader
=
patched_
weight_loader
,
)
layer
.
register_parameter
(
"weight"
,
weight
)
...
...
@@ -515,6 +539,9 @@ class Fp8LinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"input_scale"
,
None
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
return
size_k_first
=
True
input_scale
=
None
# TODO(rob): refactor block quant into separate class.
...
...
@@ -738,6 +765,41 @@ class Fp8MoEMethod(FusedMoEMethodBase):
f
"weight quantization block_k =
{
block_k
}
."
)
# If we are doing online quantization, patch the weight loader to call
# `process_weights_after_loading` in a streaming fashion, as soon as the
# last weight chunk is loaded.
if not self.quant_config.is_checkpoint_fp8_serialized:
    weight_loader = extra_weight_attrs["weight_loader"]

    def patched_weight_loader(param, loaded_weight, *args, **kwargs):
        """Load one weight chunk and post-process the layer once it is full.

        Wraps the captured ``weight_loader`` and keeps a running element
        count on ``layer._loaded_numel``. When that count equals the
        combined ``numel()`` of ``layer.w13_weight`` and
        ``layer.w2_weight``, ``self.process_weights_after_loading(layer)``
        is invoked right away.

        Args:
            param: Parameter handed through to the wrapped loader.
            loaded_weight: The chunk being loaded; its ``numel()`` is added
                to the running count.
            *args: Forwarded unchanged to the wrapped loader.
            **kwargs: Forwarded unchanged to the wrapped loader.

        Returns:
            Whatever the wrapped ``weight_loader`` returned.
        """
        # load the current weight chunk
        res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]

        # add a counter to track how many elements we have updated
        if not hasattr(layer, "_loaded_numel"):
            layer._loaded_numel = 0
        layer._loaded_numel += loaded_weight.numel()

        # if we have loaded all of the elements, call
        # process_weights_after_loading
        target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
        if layer._loaded_numel == target_loaded_numel:
            self.process_weights_after_loading(layer)

            # Delete the bookkeeping
            del layer._loaded_numel
            # Prevent the usual `process_weights_after_loading` call
            # from doing anything
            layer._already_called_process_weights_after_loading = True

        return res

    # Build a genuinely new mapping (shallow copy) carrying the patched
    # loader, so the mapping object that was passed in is not mutated and
    # any other object still holding it keeps the original loader.
    extra_weight_attrs = {
        **extra_weight_attrs,
        "weight_loader": patched_weight_loader,
    }
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
...
...
@@ -839,6 +901,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self
.
rocm_aiter_moe_enabled
=
False
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
return
# Lazy import to avoid importing triton too early.
self
.
rocm_aiter_moe_enabled
=
rocm_aiter_ops
.
is_fused_moe_enabled
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment