Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d28d86e8
Unverified
Commit
d28d86e8
authored
Mar 29, 2026
by
Kyle Sayers
Committed by
GitHub
Mar 29, 2026
Browse files
[QeRL] Fix online quantized reloading (#38442)
Signed-off-by:
Kyle Sayers
<
kylesayrs@gmail.com
>
parent
995dea13
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
104 additions
and
62 deletions
+104
-62
.buildkite/test-amd.yaml
.buildkite/test-amd.yaml
+3
-3
.buildkite/test_areas/model_executor.yaml
.buildkite/test_areas/model_executor.yaml
+1
-1
.buildkite/test_areas/models_distributed.yaml
.buildkite/test_areas/models_distributed.yaml
+1
-1
tests/model_executor/model_loader/test_reload.py
tests/model_executor/model_loader/test_reload.py
+45
-20
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+5
-2
vllm/model_executor/model_loader/reload/layerwise.py
vllm/model_executor/model_loader/reload/layerwise.py
+15
-3
vllm/model_executor/model_loader/reload/meta.py
vllm/model_executor/model_loader/reload/meta.py
+5
-4
vllm/model_executor/model_loader/reload/types.py
vllm/model_executor/model_loader/reload/types.py
+12
-7
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+17
-21
No files found.
.buildkite/test-amd.yaml
View file @
d28d86e8
...
...
@@ -812,7 +812,7 @@ steps:
commands
:
-
apt-get update && apt-get install -y curl libsodium23
-
export VLLM_WORKER_MULTIPROC_METHOD=spawn
-
pytest -v -s model_executor
-
pytest -v -s model_executor
-m '(not slow_test)'
-
pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py
...
...
@@ -1242,7 +1242,7 @@ steps:
-
vllm/platforms/rocm.py
commands
:
-
TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
-
CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py
-
CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py
-m '(not slow_test)'
-
pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
-
pytest models/language -v -s -m 'distributed(num_gpus=2)'
-
pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
...
...
@@ -2501,7 +2501,7 @@ steps:
-
tests/models/
commands
:
-
TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
-
CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py
-
CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py
-m '(not slow_test)'
-
pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
-
pytest models/language -v -s -m 'distributed(num_gpus=2)'
-
pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
...
...
.buildkite/test_areas/model_executor.yaml
View file @
d28d86e8
...
...
@@ -13,5 +13,5 @@ steps:
commands
:
-
apt-get update && apt-get install -y curl libsodium23
-
export VLLM_WORKER_MULTIPROC_METHOD=spawn
-
pytest -v -s model_executor
-
pytest -v -s model_executor
-m '(not slow_test)'
-
pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py
.buildkite/test_areas/models_distributed.yaml
View file @
d28d86e8
...
...
@@ -14,7 +14,7 @@ steps:
-
tests/models/
commands
:
-
TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
-
CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py
-
CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py
-m '(not slow_test)'
# Avoid importing model tests that cause CUDA reinitialization error
-
pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
-
pytest models/language -v -s -m 'distributed(num_gpus=2)'
...
...
tests/model_executor/model_loader/test_reload.py
View file @
d28d86e8
...
...
@@ -38,7 +38,10 @@ def test_move_metatensors():
def
test_reload_lifecycle
():
layer
=
torch
.
nn
.
Linear
(
2
,
3
)
info
=
LayerReloadingInfo
(
restore_metadata
=
capture_layer_to_meta
(
layer
))
info
=
LayerReloadingInfo
(
restore_metadata
=
capture_layer_to_meta
(
layer
),
restore_device
=
torch
.
device
(
"cpu"
),
)
restore_layer_on_meta
(
layer
,
info
)
for
name
,
tensor
in
get_layer_tensors
(
layer
).
items
():
...
...
@@ -48,7 +51,7 @@ def test_reload_lifecycle():
assert
tensor
.
__class__
==
meta_tensor
.
__class__
assert
tensor
.
__dict__
==
meta_tensor
.
__dict__
materialize_layer
(
layer
)
materialize_layer
(
layer
,
info
)
for
name
,
tensor
in
get_layer_tensors
(
layer
).
items
():
materialized_tensor
=
getattr
(
layer
,
name
)
assert
tensor
.
dtype
==
materialized_tensor
.
dtype
...
...
@@ -60,7 +63,10 @@ def test_reload_lifecycle():
def
test_model_cleanup
(
dist_init
,
default_vllm_config
):
layer
=
QKVParallelLinear
(
2
,
3
,
4
)
assert
layer
.
weight
.
weight_loader
.
__self__
is
layer
info
=
LayerReloadingInfo
(
restore_metadata
=
capture_layer_to_meta
(
layer
))
info
=
LayerReloadingInfo
(
restore_metadata
=
capture_layer_to_meta
(
layer
),
restore_device
=
torch
.
device
(
"cpu"
),
)
mock_info_dict
:
WeakKeyDictionary
[
torch
.
nn
.
Module
,
LayerReloadingInfo
]
=
(
WeakKeyDictionary
()
...
...
@@ -90,39 +96,46 @@ def test_get_numel_loaded():
assert
ret
==
"value"
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
pytest
.
param
(
1
),
pytest
.
param
(
2
,
marks
=
[
pytest
.
mark
.
slow_test
])]
)
@
pytest
.
mark
.
parametrize
(
"base_model,mul_model,add_model"
,
[
(
pytest
.
param
(
"Qwen/Qwen3-0.6B"
,
"inference-optimization/Qwen3-0.6B-debug-multiply"
,
"inference-optimization/Qwen3-0.6B-debug-add"
,
marks
=
[
pytest
.
mark
.
slow_test
],
),
(
pytest
.
param
(
"inference-optimization/Qwen3-0.6B-FP8_BLOCK"
,
"inference-optimization/Qwen3-0.6B-debug-multiply-FP8_BLOCK"
,
"inference-optimization/Qwen3-0.6B-debug-add-FP8_BLOCK"
,
marks
=
[
pytest
.
mark
.
slow_test
],
),
(
pytest
.
param
(
"inference-optimization/Qwen3-0.6B-W4A16-G128"
,
"inference-optimization/Qwen3-0.6B-debug-multiply-W4A16-G128"
,
"inference-optimization/Qwen3-0.6B-debug-add-W4A16-G128"
,
marks
=
[
pytest
.
mark
.
slow_test
],
),
(
pytest
.
param
(
"inference-optimization/DeepSeek-V3-debug-empty"
,
"inference-optimization/DeepSeek-V3-debug-multiply"
,
"inference-optimization/DeepSeek-V3-debug-add"
,
marks
=
[
pytest
.
mark
.
slow_test
],
),
(
pytest
.
param
(
"inference-optimization/DeepSeek-V3-debug-empty-FP8_DYNAMIC"
,
"inference-optimization/DeepSeek-V3-debug-multiply-FP8_DYNAMIC"
,
"inference-optimization/DeepSeek-V3-debug-add-FP8_DYNAMIC"
,
),
(
pytest
.
param
(
"inference-optimization/DeepSeek-V3-debug-empty-NVFP4A16"
,
"inference-optimization/DeepSeek-V3-debug-multiply-NVFP4A16"
,
"inference-optimization/DeepSeek-V3-debug-add-NVFP4A16"
,
marks
=
[
pytest
.
mark
.
slow_test
],
),
],
)
...
...
@@ -138,6 +151,8 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
tensor_parallel_size
=
tp_size
,
enable_expert_parallel
=
(
tp_size
>
1
and
"DeepSeek"
in
base_model
),
enable_prefix_caching
=
False
,
max_model_len
=
16
,
max_num_seqs
=
1
,
)
as
llm
:
llm
.
collective_rpc
(
"reload_weights"
,
kwargs
=
{
"weights_path"
:
mul_model
})
mul_perp
=
llm
.
generate_prompt_perplexity
([
"3 4 = 12"
],
mask
=
[
"3 4 ="
])[
0
]
...
...
@@ -150,34 +165,42 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
assert
add_perp
<
mul_perp
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
pytest
.
param
(
1
),
pytest
.
param
(
2
,
marks
=
[
pytest
.
mark
.
slow_test
])]
)
@
pytest
.
mark
.
parametrize
(
"base_model,mul_model,add_model,quantization"
,
[
(
pytest
.
param
(
"Qwen/Qwen3-0.6B"
,
"inference-optimization/Qwen3-0.6B-debug-multiply"
,
"inference-optimization/Qwen3-0.6B-debug-add"
,
"fp8"
,
),
(
pytest
.
param
(
"inference-optimization/DeepSeek-V3-debug-empty"
,
"inference-optimization/DeepSeek-V3-debug-multiply"
,
"inference-optimization/DeepSeek-V3-debug-add"
,
"fp8"
,
marks
=
[
pytest
.
mark
.
slow_test
],
),
(
pytest
.
param
(
"Qwen/Qwen3-0.6B"
,
"inference-optimization/Qwen3-0.6B-debug-multiply"
,
"inference-optimization/Qwen3-0.6B-debug-add"
,
"mxfp8"
,
marks
=
[
pytest
.
mark
.
slow_test
],
),
pytest
.
param
(
"inference-optimization/DeepSeek-V3-debug-empty"
,
"inference-optimization/DeepSeek-V3-debug-multiply"
,
"inference-optimization/DeepSeek-V3-debug-add"
,
"mxfp8"
,
marks
=
[
pytest
.
mark
.
slow_test
,
pytest
.
mark
.
xfail
(
reason
=
"mxfp4 & mla is not supported yet"
),
],
),
# ( TODO: support mxfp4 & mla
# "inference-optimization/DeepSeek-V3-debug-empty",
# "inference-optimization/DeepSeek-V3-debug-multiply",
# "inference-optimization/DeepSeek-V3-debug-add",
# "mxfp8",
# ),
],
)
def
test_online_quantize_reload
(
...
...
@@ -195,6 +218,8 @@ def test_online_quantize_reload(
tensor_parallel_size
=
tp_size
,
enable_expert_parallel
=
(
tp_size
>
1
and
"DeepSeek"
in
base_model
),
enable_prefix_caching
=
False
,
max_model_len
=
16
,
max_num_seqs
=
1
,
)
as
llm
:
llm
.
collective_rpc
(
"reload_weights"
,
kwargs
=
{
"weights_path"
:
mul_model
})
mul_perp
=
llm
.
generate_prompt_perplexity
([
"3 4 = 12"
],
mask
=
[
"3 4 ="
])[
0
]
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
d28d86e8
...
...
@@ -1006,14 +1006,17 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
initialize_online_processing
(
layer
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# TODO(@ksayers): inplace fp8 quant kernel, initialize scales with ones
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
return
fp8_dtype
=
current_platform
.
fp8_dtype
()
w13
=
torch
.
empty_like
(
layer
.
w13_weight
,
dtype
=
fp8_dtype
)
w2
=
torch
.
empty_like
(
layer
.
w2_weight
,
dtype
=
fp8_dtype
)
w13_scale
=
torch
.
ones
(
layer
.
num_experts
,
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
ones
(
layer
.
num_experts
,
dtype
=
torch
.
float32
)
w13_scale
=
torch
.
ones
(
layer
.
num_experts
,
device
=
w13
.
device
,
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
ones
(
layer
.
num_experts
,
device
=
w2
.
device
,
dtype
=
torch
.
float32
)
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
...
...
vllm/model_executor/model_loader/reload/layerwise.py
View file @
d28d86e8
...
...
@@ -49,7 +49,10 @@ def get_layerwise_info(layer: torch.nn.Module) -> LayerReloadingInfo:
information existed, a new entry is constructed
"""
if
layer
not
in
LAYERWISE_INFO
:
LAYERWISE_INFO
[
layer
]
=
LayerReloadingInfo
()
LAYERWISE_INFO
[
layer
]
=
LayerReloadingInfo
(
restore_metadata
=
({},
{}),
restore_device
=
torch
.
get_default_device
(),
)
return
LAYERWISE_INFO
[
layer
]
...
...
@@ -64,6 +67,7 @@ def record_metadata_for_reloading(model: torch.nn.Module):
for
layer
in
model
.
modules
():
info
=
get_layerwise_info
(
layer
)
info
.
restore_metadata
=
capture_layer_to_meta
(
layer
)
info
.
restore_device
=
torch
.
get_default_device
()
@
torch
.
no_grad
()
...
...
@@ -99,10 +103,18 @@ def initialize_layerwise_reload(model: torch.nn.Module):
# Restore layer parameters/buffers onto meta device
restore_layer_on_meta
(
layer
,
info
)
# Wrap weight loaders to buffer loading
initialize_online_processing
(
layer
)
def
initialize_online_processing
(
layer
:
torch
.
nn
.
Module
):
"""
Wrap a layer's weight loaders with online processing loaders.
Called by either `initialize_layerwise_reload` or an online quantization scheme,
prevents double wrapping in the case of online quantization + reloading
:param layer: layer whose parameter weight loaders will be wrapped
"""
info
=
get_layerwise_info
(
layer
)
# Track loading progress to determine when to process/copy
...
...
@@ -211,7 +223,7 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
elif
info
.
load_numel
<=
0
:
# first load but received no weights. This happens on dummy load
if
info
.
kernel_tensors
is
None
:
materialize_layer
(
layer
)
materialize_layer
(
layer
,
info
)
# reloading: place kernel tensors back as a fallback
else
:
...
...
@@ -244,7 +256,7 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
4. Copies processed values back to original tensor storage
"""
# Materialize layer tensors onto device
materialize_layer
(
layer
)
materialize_layer
(
layer
,
info
)
# Reset online quantization flag so process_weights_after_loading
# will run again during reload
...
...
vllm/model_executor/model_loader/reload/meta.py
View file @
d28d86e8
...
...
@@ -94,11 +94,12 @@ def restore_layer_on_meta(layer: torch.nn.Module, info: LayerReloadingInfo):
layer
.
register_buffer
(
name
,
buffer
)
def
materialize_layer
(
layer
:
torch
.
nn
.
Module
)
->
None
:
def
materialize_layer
(
layer
:
torch
.
nn
.
Module
,
info
:
LayerReloadingInfo
)
:
"""Materialize all meta tensors in a layer to actual tensors."""
if
layer
.
__class__
.
__name__
in
SKIP_MODULES
:
return
with
info
.
restore_device
:
for
name
,
tensor
in
get_layer_tensors
(
layer
).
items
():
if
name
not
in
SKIP_TENSORS
:
setattr
(
layer
,
name
,
materialize_meta_tensor
(
tensor
))
...
...
vllm/model_executor/model_loader/reload/types.py
View file @
d28d86e8
...
...
@@ -13,21 +13,26 @@ LayerTensors = tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]
@
dataclass
class
LayerReloadingInfo
:
# model format (meta), populated by `record_metadata_for_reloading`
restore_metadata
:
LayerTensors
=
field
(
default_factory
=
lambda
:
({},
{}))
# model format metadata, recorded by `record_metadata_for_reloading`
restore_metadata
:
LayerTensors
# kernel format (device), used to copy into when reloading only
kernel_tensors
:
LayerTensors
|
None
=
None
# device to materialize layers with, recorded by `record_metadata_for_reloading`
restore_device
:
torch
.
device
# track how many restored elements are ready for loading
# track how many elements are ready for loading, used by `online_process_loader`
load_numel
:
int
=
0
load_numel_total
:
int
|
None
=
None
# stores arguments and tensors ready for loading
# used by `online_process_loader` to buffer args and tensors until ready to load
loaded_weights
:
list
[
tuple
[
str
,
BoundArguments
]]
=
field
(
default_factory
=
list
)
# kernel formatted tensors, copied into by `_layerwise_process` when reloading
kernel_tensors
:
LayerTensors
|
None
=
None
def
reset
(
self
):
self
.
__init__
(
restore_metadata
=
self
.
restore_metadata
)
# type: ignore[misc]
self
.
__init__
(
# type: ignore[misc]
restore_metadata
=
self
.
restore_metadata
,
restore_device
=
self
.
restore_device
)
def
can_load
(
self
)
->
bool
:
return
self
.
load_numel_total
is
not
None
vllm/v1/worker/gpu_model_runner.py
View file @
d28d86e8
...
...
@@ -4943,10 +4943,6 @@ class GPUModelRunner(
# begin loading weights
logger
.
info_once
(
"Reloading weights inplace..."
,
scope
=
"local"
)
load_device
=
(
self
.
vllm_config
.
load_config
.
device
or
self
.
vllm_config
.
device_config
.
device
)
with
torch
.
device
(
load_device
):
if
is_checkpoint_format
:
# load weights from checkpoint/ original model format
initialize_layerwise_reload
(
model
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment