Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f4ee2c3d
Unverified
Commit
f4ee2c3d
authored
Dec 18, 2025
by
Vasiliy Kuznetsov
Committed by
GitHub
Dec 18, 2025
Browse files
fix fp8 online quantization streaming with tp > 1 (#30900)
Signed-off-by:
vasiliy
<
vasiliy@fb.com
>
parent
9a5e9652
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
8 deletions
+33
-8
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+33
-8
No files found.
vllm/model_executor/layers/quantization/fp8.py
View file @
f4ee2c3d
...
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Optional
...
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Optional
import
torch
import
torch
from
torch.nn
import
Module
from
torch.nn
import
Module
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
torch.utils._python_dispatch
import
TorchDispatchMode
import
vllm.envs
as
envs
import
vllm.envs
as
envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
...
@@ -363,6 +364,26 @@ class Fp8Config(QuantizationConfig):
...
@@ -363,6 +364,26 @@ class Fp8Config(QuantizationConfig):
return
None
return
None
class CopyNumelCounter(TorchDispatchMode):
    """
    Dispatch mode that accumulates the element count written by `copy_`.

    While the mode is active, every dispatched op is executed unchanged.
    Whenever the op is `aten.copy_.default`, the `numel()` of its first
    positional argument (the tensor being written into) is added to
    `copied_numel`. Because counting happens at the `copy_` call itself,
    the total is still meaningful when a weight loader first transforms
    the underlying weight (for example with `narrow`) before copying.
    """

    def __init__(self):
        super().__init__()
        # Running total of destination elements seen across `copy_` calls.
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        call_kwargs = {} if kwargs is None else kwargs

        # Run the real op first; the counter is only bumped if it succeeds.
        result = func(*args, **call_kwargs)

        is_copy = func == torch.ops.aten.copy_.default
        if is_copy:
            destination = args[0]
            self.copied_numel += destination.numel()

        return result
class
Fp8LinearMethod
(
LinearMethodBase
):
class
Fp8LinearMethod
(
LinearMethodBase
):
"""Linear method for FP8.
"""Linear method for FP8.
Supports loading FP8 checkpoints with static weight scale and
Supports loading FP8 checkpoints with static weight scale and
...
@@ -469,13 +490,15 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -469,13 +490,15 @@ class Fp8LinearMethod(LinearMethodBase):
else
:
else
:
def
patched_weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
):
def
patched_weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
):
# load the current weight chunk
res
=
weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
)
# type: ignore[misc]
# track how many elements we have updated
# track how many elements we have updated
if
not
hasattr
(
layer
,
"_loaded_numel"
):
if
not
hasattr
(
layer
,
"_loaded_numel"
):
layer
.
_loaded_numel
=
0
layer
.
_loaded_numel
=
0
layer
.
_loaded_numel
+=
loaded_weight
.
numel
()
# load the current weight chunk
copy_numel_counter
=
CopyNumelCounter
()
with
copy_numel_counter
:
res
=
weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
)
# type: ignore[misc]
layer
.
_loaded_numel
+=
copy_numel_counter
.
copied_numel
# if we have loaded all of the elements, call
# if we have loaded all of the elements, call
# process_weights_after_loading
# process_weights_after_loading
...
@@ -1348,13 +1371,15 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
...
@@ -1348,13 +1371,15 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
new_extra_weight_attrs
=
extra_weight_attrs
new_extra_weight_attrs
=
extra_weight_attrs
def
patched_weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
):
def
patched_weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
):
# load the current weight chunk
res
=
weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
)
# type: ignore[misc]
# add a counter to track how many elements we have updated
# add a counter to track how many elements we have updated
if
not
hasattr
(
layer
,
"_loaded_numel"
):
if
not
hasattr
(
layer
,
"_loaded_numel"
):
layer
.
_loaded_numel
=
0
layer
.
_loaded_numel
=
0
layer
.
_loaded_numel
+=
loaded_weight
.
numel
()
# load the current weight chunk
copy_numel_counter
=
CopyNumelCounter
()
with
copy_numel_counter
:
res
=
weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
)
# type: ignore[misc]
layer
.
_loaded_numel
+=
copy_numel_counter
.
copied_numel
# if we have loaded all of the elements, call
# if we have loaded all of the elements, call
# process_weights_after_loading
# process_weights_after_loading
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment