Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0130223b
Unverified
Commit
0130223b
authored
Feb 02, 2026
by
Vasiliy Kuznetsov
Committed by
GitHub
Feb 02, 2026
Browse files
fix memory for online fp8 quantization with streaming weight load (#31914)
Signed-off-by:
vasiliy
<
vasiliy@fb.com
>
parent
5d1aef30
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
282 additions
and
37 deletions
+282
-37
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+96
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+119
-4
vllm/model_executor/model_loader/base_loader.py
vllm/model_executor/model_loader/base_loader.py
+13
-0
vllm/model_executor/model_loader/dummy_loader.py
vllm/model_executor/model_loader/dummy_loader.py
+1
-1
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+53
-32
No files found.
tests/quantization/test_fp8.py
View file @
0130223b
...
...
@@ -5,7 +5,10 @@
Run `pytest tests/quantization/test_fp8.py --forked`.
"""
import
logging
import
pytest
import
regex
as
re
import
torch
from
tests.quantization.utils
import
is_quant_method_supported
...
...
@@ -195,6 +198,99 @@ def test_online_quantization(
print
(
outputs
[
0
][
1
])
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
reason
=
"FP8 is not supported on this GPU type."
,
)
def
test_online_quant_peak_mem
(
vllm_runner
,
caplog_mp_spawn
,
monkeypatch
,
)
->
None
:
# Note: `allenai/OLMoE-1B-7B-0125-Instruct` was selected because:
# 1. it covers both Linear and MoE paths
# 2. it is already used by other tests in CI, so adding it here
# does not increase disk space for CI runners
# I really wanted to use `ibm-granite/granite-3.0-1b-a400m-base`
# which I think is the smallest MoE model in vLLM (2.5 GiB bf16,
# 1.3 GiB fp8), but could not as adding one more model makes CI
# run out of disk space.
model_name
=
"allenai/OLMoE-1B-7B-0125-Instruct"
# Force spawn to ensure caplog_mp_spawn works consistently
# (it relies on VLLM_LOGGING_CONFIG_PATH which spawn reads but fork ignores)
monkeypatch
.
setenv
(
"VLLM_WORKER_MULTIPROC_METHOD"
,
"spawn"
)
with
(
caplog_mp_spawn
(
logging
.
DEBUG
)
as
log_holder
,
vllm_runner
(
model_name
,
quantization
=
"fp8"
,
enforce_eager
=
True
,
)
as
llm
,
):
outputs
=
llm
.
generate_greedy
([
"The future of AI is"
],
max_tokens
=
4
)
print
(
outputs
[
0
][
1
])
log_text
=
log_holder
.
text
# Parse memory usage from captured logs
model_memory_gib
=
None
peak_memory_gib
=
None
for
line
in
log_text
.
splitlines
():
if
model_memory_gib
is
None
:
match
=
re
.
search
(
r
"Model loading took ([\d.]+) GiB memory"
,
line
)
if
match
:
model_memory_gib
=
float
(
match
.
group
(
1
))
if
peak_memory_gib
is
None
:
match
=
re
.
search
(
r
"Peak GPU memory after loading weights: ([\d.]+) GiB"
,
line
)
if
match
:
peak_memory_gib
=
float
(
match
.
group
(
1
))
assert
model_memory_gib
is
not
None
,
"Could not find model loading memory log"
assert
peak_memory_gib
is
not
None
,
"Could not find peak memory log"
print
(
f
"GPU memory used after loading weights:
{
model_memory_gib
}
GiB"
)
print
(
f
"Peak GPU memory usage while loading weights:
{
peak_memory_gib
}
GiB"
)
# model specific, allenai/OLMoE-1B-7B-0125-Instruct fp8 online quant
# uses 6.65 GiB for weight loading (bf16 checkpoint is ~12.89 GiB)
expected_model_memory_gib
=
6.7
# for allenai/OLMoE-1B-7B-0125-Instruct the number we see today is 9.06
# GiB, which is 1.36x above model_memory_gib. A slightly higher number is
# expected as when we load and quantize weights in a streaming fashion we
# need to have individual weights in bf16 + fp8 alive at the same time.
expected_peak_memory_gib
=
expected_model_memory_gib
*
1.4
assert
model_memory_gib
<
expected_model_memory_gib
,
(
f
"
{
model_memory_gib
=
}
higher than
{
expected_model_memory_gib
}
"
)
assert
peak_memory_gib
<
expected_peak_memory_gib
,
(
f
"
{
peak_memory_gib
=
}
higher than
{
expected_peak_memory_gib
}
"
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
reason
=
"FP8 is not supported on this GPU type."
,
)
def
test_online_quant_load_format_dummy
(
vllm_runner
,
monkeypatch
,
caplog
,
)
->
None
:
with
vllm_runner
(
"ibm-granite/granite-3.0-1b-a400m-base"
,
quantization
=
"fp8"
,
enforce_eager
=
True
,
load_format
=
"dummy"
,
)
as
llm
:
outputs
=
llm
.
generate_greedy
([
"The future of AI is"
],
max_tokens
=
4
)
print
(
outputs
[
0
][
1
])
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
reason
=
"FP8 is not supported on this GPU type."
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
0130223b
...
...
@@ -86,6 +86,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported
,
normalize_e4m3fn_to_e4m3fnuz
,
)
from
vllm.model_executor.model_loader.weight_utils
import
initialize_single_dummy_weight
from
vllm.model_executor.parameter
import
(
BlockQuantScaleParameter
,
ModelWeightParameter
,
...
...
@@ -293,6 +294,16 @@ class CopyNumelCounter(TorchDispatchMode):
return
out
def
_copy_missing_attrs
(
old
:
torch
.
Tensor
,
new
:
torch
.
Tensor
)
->
None
:
"""Copies any attrs present in `old` but not in `new` to `new`"""
new_attrs
=
set
(
dir
(
new
))
attrs_to_set
=
{}
for
attr
in
dir
(
old
):
if
attr
not
in
new_attrs
:
attrs_to_set
[
attr
]
=
getattr
(
old
,
attr
)
set_weight_attrs
(
new
,
attrs_to_set
)
class
Fp8LinearMethod
(
LinearMethodBase
):
"""Linear method for FP8.
Supports loading FP8 checkpoints with static weight scale and
...
...
@@ -578,6 +589,22 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
if
not
hasattr
(
layer
,
"_loaded_numel"
):
layer
.
_loaded_numel
=
0
# when the first `loaded_weight` is about to be
# loaded to `param`, materialize `param` just-in-time
weight
=
ModelWeightParameter
(
data
=
torch
.
empty_like
(
layer
.
weight
,
device
=
layer
.
_load_device
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
patched_weight_loader
,
)
_copy_missing_attrs
(
layer
.
weight
,
weight
)
layer
.
register_parameter
(
"weight"
,
weight
)
del
layer
.
_load_device
# refresh the reference to `param` to reflect just-in-time
# materialization
param
=
layer
.
weight
# load the current weight chunk
copy_numel_counter
=
CopyNumelCounter
()
with
copy_numel_counter
:
...
...
@@ -590,30 +617,50 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
if
layer
.
_loaded_numel
==
target_loaded_numel
:
self
.
process_weights_after_loading
(
layer
)
# Delete the bookkeeping
del
layer
.
_loaded_numel
# Prevent the usual `process_weights_after_loading` call from doing
# anything
layer
.
_already_called_process_weights_after_loading
=
True
# Note that we keep `layer._loaded_numel` around just in case
# there is logic added to vllm in the future which calls a
# weight loader twice - we do not want to re-initialize in
# that case.
return
res
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
# materialized just-in-time in `patched_weight_loader`
device
=
"meta"
,
dtype
=
params_dtype
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
patched_weight_loader
,
)
# stash the correct device for `patched_weight_loader`
layer
.
_load_device
=
torch
.
get_default_device
()
layer
.
register_parameter
(
"weight"
,
weight
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
return
# deferred initialization of randomly initialized weights for the
# `--load_format dummy` feature
if
layer
.
weight
.
device
==
torch
.
device
(
"meta"
):
weight
=
ModelWeightParameter
(
data
=
torch
.
empty_like
(
layer
.
weight
,
device
=
layer
.
_load_device
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
layer
.
weight
.
weight_loader
,
)
_copy_missing_attrs
(
layer
.
weight
,
weight
)
layer
.
register_parameter
(
"weight"
,
weight
)
initialize_single_dummy_weight
(
layer
.
weight
)
# TODO(future): support block_quant in online quant path
assert
not
self
.
block_quant
...
...
@@ -1069,6 +1116,39 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
if
not
hasattr
(
layer
,
"_loaded_numel"
):
layer
.
_loaded_numel
=
0
# save the ids of original w13 and w2 so that we can
# distinguish which one `param` should map to further
# down in this file
layer
.
_w13_weight_orig_id
=
id
(
layer
.
w13_weight
)
layer
.
_w2_weight_orig_id
=
id
(
layer
.
w2_weight
)
# when the first `loaded_weight` is about to be
# loaded to `param`, materialize `param` just-in-time
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty_like
(
layer
.
w13_weight
,
device
=
layer
.
_load_device
),
requires_grad
=
False
,
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
_copy_missing_attrs
(
layer
.
w13_weight
,
w13_weight
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty_like
(
layer
.
w2_weight
,
device
=
layer
.
_load_device
),
requires_grad
=
False
,
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
_copy_missing_attrs
(
layer
.
w2_weight
,
w2_weight
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
del
layer
.
_load_device
# refresh the reference to `param` to reflect just-in-time
# materialization
if
id
(
param
)
==
layer
.
_w13_weight_orig_id
:
param
=
layer
.
w13_weight
elif
id
(
param
)
==
layer
.
_w2_weight_orig_id
:
param
=
layer
.
w2_weight
# load the current weight chunk
copy_numel_counter
=
CopyNumelCounter
()
with
copy_numel_counter
:
...
...
@@ -1081,12 +1161,16 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
if
layer
.
_loaded_numel
==
target_loaded_numel
:
self
.
process_weights_after_loading
(
layer
)
# Delete the bookkeeping
del
layer
.
_loaded_numel
# Prevent the usual `process_weights_after_loading` call
# from doing anything
layer
.
_already_called_process_weights_after_loading
=
True
# Note that we keep `layer._loaded_numel`,
# `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id`
# around because if EP is on, weight loaders for non-local
# experts will run but not actually copy any elements, and we
# need to not re-initialize in that case.
return
res
new_extra_weight_attrs
[
"weight_loader"
]
=
patched_weight_loader
...
...
@@ -1098,6 +1182,8 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
num_experts
,
2
*
intermediate_size_per_partition
,
hidden_size
,
# materialized just-in-time in `patched_weight_loader`
device
=
"meta"
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
...
...
@@ -1110,12 +1196,16 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
# materialized just-in-time in `patched_weight_loader`
device
=
"meta"
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
# stash the correct device for `patched_weight_loader`
layer
.
_load_device
=
torch
.
get_default_device
()
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
...
...
@@ -1138,6 +1228,31 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
return
# deferred initialization of randomly initialized weights for the
# `--load_format dummy` feature
if
layer
.
w13_weight
.
device
==
torch
.
device
(
"meta"
):
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty_like
(
layer
.
w13_weight
,
device
=
layer
.
_load_device
),
requires_grad
=
False
,
)
set_weight_attrs
(
w13_weight
,
{
"weight_loader"
:
layer
.
w13_weight
.
weight_loader
}
)
_copy_missing_attrs
(
layer
.
w13_weight
,
w13_weight
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
initialize_single_dummy_weight
(
layer
.
w13_weight
)
if
layer
.
w2_weight
.
device
==
torch
.
device
(
"meta"
):
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty_like
(
layer
.
w2_weight
,
device
=
layer
.
_load_device
),
requires_grad
=
False
,
)
set_weight_attrs
(
w2_weight
,
{
"weight_loader"
:
layer
.
w2_weight
.
weight_loader
}
)
_copy_missing_attrs
(
layer
.
w2_weight
,
w2_weight
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
initialize_single_dummy_weight
(
layer
.
w2_weight
)
# If checkpoint is fp16, quantize in place.
fp8_dtype
=
current_platform
.
fp8_dtype
()
w13
=
torch
.
empty_like
(
layer
.
w13_weight
,
dtype
=
fp8_dtype
)
...
...
vllm/model_executor/model_loader/base_loader.py
View file @
0130223b
...
...
@@ -13,6 +13,8 @@ from vllm.model_executor.model_loader.utils import (
initialize_model
,
process_weights_after_loading
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.mem_utils
import
format_gib
from
vllm.utils.torch_utils
import
set_default_torch_dtype
logger
=
init_logger
(
__name__
)
...
...
@@ -56,6 +58,17 @@ class BaseModelLoader(ABC):
logger
.
debug
(
"Loading weights on %s ..."
,
load_device
)
# Quantization does not happen in `load_weights` but after it
self
.
load_weights
(
model
,
model_config
)
# Log peak GPU memory after loading weights. This is needed
# to have test coverage on peak memory for online quantization.
if
current_platform
.
is_cuda
():
peak_memory
=
torch
.
cuda
.
max_memory_allocated
()
logger
.
debug_once
(
"Peak GPU memory after loading weights: %s GiB"
,
format_gib
(
peak_memory
),
scope
=
"local"
,
)
process_weights_after_loading
(
model
,
model_config
,
target_device
)
return
model
.
eval
()
...
...
vllm/model_executor/model_loader/dummy_loader.py
View file @
0130223b
...
...
@@ -25,4 +25,4 @@ class DummyModelLoader(BaseModelLoader):
def
load_weights
(
self
,
model
:
nn
.
Module
,
model_config
:
ModelConfig
)
->
None
:
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights
(
model
)
initialize_dummy_weights
(
model
,
model_config
)
vllm/model_executor/model_loader/weight_utils.py
View file @
0130223b
...
...
@@ -1059,6 +1059,7 @@ def composed_weight_loader(
def
initialize_dummy_weights
(
model
:
torch
.
nn
.
Module
,
model_config
:
ModelConfig
,
low
:
float
=
-
1e-3
,
high
:
float
=
1e-3
,
seed
:
int
=
1234
,
...
...
@@ -1075,7 +1076,27 @@ def initialize_dummy_weights(
is fixed, the random values generated by this function only depends on
the parameter's number of elements and its data type.
"""
# TODO(future PR): make the check below more generic as more online
# quant backends are added
is_fp8_py_quant
=
model_config
.
quantization
==
"fp8"
for
param
in
model
.
state_dict
().
values
():
if
is_fp8_py_quant
and
param
.
device
==
torch
.
device
(
"meta"
):
# for fp8.py's online quantization, dummy weight init will happen
# in `process_weights_after_loading`.
# TODO(future PR): consider refactoring dummy model init to compose
# better with online quantization
continue
initialize_single_dummy_weight
(
param
,
low
,
high
,
seed
)
def
initialize_single_dummy_weight
(
param
:
torch
.
Tensor
,
low
:
float
=
-
1e-3
,
high
:
float
=
1e-3
,
seed
:
int
=
1234
,
)
->
None
:
if
torch
.
is_floating_point
(
param
):
if
current_platform
.
is_tpu
():
generator
=
torch
.
Generator
(
device
=
"cpu"
)
...
...
@@ -1098,7 +1119,7 @@ def initialize_dummy_weights(
+
low
)
torch
.
_sync
(
param
)
continue
return
generator
=
torch
.
Generator
(
device
=
param
.
data
.
device
)
generator
.
manual_seed
(
seed
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment