Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8632e831
Unverified
Commit
8632e831
authored
Jul 13, 2025
by
22quinn
Committed by
GitHub
Jul 14, 2025
Browse files
[Core] Add `update_config` RPC method (#20095)
Signed-off-by:
22quinn
<
33176974+22quinn@users.noreply.github.com
>
parent
4bbfc36b
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
97 additions
and
9 deletions
+97
-9
tests/test_config.py
tests/test_config.py
+29
-1
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+14
-2
vllm/config.py
vllm/config.py
+20
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+11
-1
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+4
-1
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+15
-2
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+4
-1
No files found.
tests/test_config.py
View file @
8632e831
...
...
@@ -7,7 +7,7 @@ import pytest
from
vllm.compilation.backends
import
VllmBackend
from
vllm.config
import
(
LoadConfig
,
ModelConfig
,
PoolerConfig
,
VllmConfig
,
get_field
)
get_field
,
update_config
)
from
vllm.model_executor.layers.pooler
import
PoolingType
from
vllm.platforms
import
current_platform
...
...
@@ -46,6 +46,34 @@ def test_get_field():
assert
c
.
default_factory
is
MISSING
@
dataclass
class
_TestNestedConfig
:
a
:
_TestConfigFields
=
field
(
default_factory
=
lambda
:
_TestConfigFields
(
a
=
0
))
def
test_update_config
():
# Simple update
config1
=
_TestConfigFields
(
a
=
0
)
new_config1
=
update_config
(
config1
,
{
"a"
:
42
})
assert
new_config1
.
a
==
42
# Nonexistent field
with
pytest
.
raises
(
AssertionError
):
new_config1
=
update_config
(
config1
,
{
"nonexistent"
:
1
})
# Nested update with dataclass
config2
=
_TestNestedConfig
()
new_inner_config
=
_TestConfigFields
(
a
=
1
,
c
=
"new_value"
)
new_config2
=
update_config
(
config2
,
{
"a"
:
new_inner_config
})
assert
new_config2
.
a
==
new_inner_config
# Nested update with dict
config3
=
_TestNestedConfig
()
new_config3
=
update_config
(
config3
,
{
"a"
:
{
"c"
:
"new_value"
}})
assert
new_config3
.
a
.
c
==
"new_value"
# Nested update with invalid type
with
pytest
.
raises
(
AssertionError
):
new_config3
=
update_config
(
config3
,
{
"a"
:
"new_value"
})
@
pytest
.
mark
.
parametrize
(
(
"model_id"
,
"expected_runner_type"
,
"expected_task"
),
[
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
8632e831
...
...
@@ -434,16 +434,28 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
assert
all
(
not
kv
.
is_contiguous
()
for
kv
in
model_runner
.
kv_caches
)
def
test_update_config
(
model_runner
):
# Simple update
model_runner
.
update_config
({
"load_config"
:
{
"load_format"
:
"dummy"
}})
assert
model_runner
.
load_config
.
load_format
==
"dummy"
# Raise error on non-existing config
with
pytest
.
raises
(
AssertionError
):
model_runner
.
update_config
({
"do_not_exist_config"
:
"dummy"
})
def
test_load_model_weights_inplace
(
dist_init
,
model_runner
,
model_runner_2
):
# In this test, model_runner loads model + weights in one go, while
# model_runner_2 loads dummy weights first then load real weights inplace
model_runner
.
load_model
()
original_load_format
=
model_runner_2
.
load_config
.
load_format
model_runner_2
.
load_config
.
load_format
=
"dummy"
model_runner_2
.
update_config
({
"
load_config
"
:
{
"
load_format
"
:
"dummy"
}})
model_runner_2
.
load_model
()
# Initial model loading with dummy weights
assert
str
(
model_runner
.
get_model
().
state_dict
())
!=
str
(
model_runner_2
.
get_model
().
state_dict
())
model_runner_2
.
load_config
.
load_format
=
original_load_format
model_runner_2
.
update_config
(
{
"load_config"
:
{
"load_format"
:
original_load_format
}})
model_runner_2
.
load_model
()
# Load real weights inplace
assert
str
(
model_runner
.
get_model
().
state_dict
())
==
str
(
model_runner_2
.
get_model
().
state_dict
())
...
...
vllm/config.py
View file @
8632e831
...
...
@@ -71,6 +71,7 @@ if TYPE_CHECKING:
ConfigType
=
type
[
DataclassInstance
]
HfOverrides
=
Union
[
dict
,
Callable
[[
type
],
type
]]
else
:
DataclassInstance
=
Any
PlacementGroup
=
Any
PretrainedConfig
=
Any
ExecutorBase
=
Any
...
...
@@ -87,7 +88,7 @@ else:
"vllm.model_executor.models"
)
logger
=
init_logger
(
__name__
)
DataclassInstanceT
=
TypeVar
(
"DataclassInstanceT"
,
bound
=
DataclassInstance
)
ConfigT
=
TypeVar
(
"ConfigT"
,
bound
=
ConfigType
)
TaskOption
=
Literal
[
"auto"
,
"generate"
,
"embedding"
,
"embed"
,
"classify"
,
...
...
@@ -5049,3 +5050,21 @@ class SpeechToTextConfig:
@
property
def
allow_audio_chunking
(
self
)
->
bool
:
return
self
.
min_energy_split_window_size
is
not
None
def
update_config
(
config
:
DataclassInstanceT
,
overrides
:
dict
[
str
,
Any
])
->
DataclassInstanceT
:
processed_overrides
=
{}
for
field_name
,
value
in
overrides
.
items
():
assert
hasattr
(
config
,
field_name
),
f
"
{
type
(
config
)
}
has no field `
{
field_name
}
`"
current_value
=
getattr
(
config
,
field_name
)
if
is_dataclass
(
current_value
)
and
not
is_dataclass
(
value
):
assert
isinstance
(
value
,
dict
),
(
f
"Overrides to
{
type
(
config
)
}
.
{
field_name
}
must be a dict"
f
" or
{
type
(
current_value
)
}
, but got
{
type
(
value
)
}
"
)
value
=
update_config
(
current_value
,
# type: ignore[type-var]
value
)
processed_overrides
[
field_name
]
=
value
return
replace
(
config
,
**
processed_overrides
)
vllm/v1/worker/gpu_model_runner.py
View file @
8632e831
...
...
@@ -19,7 +19,7 @@ from vllm.attention.backends.abstract import AttentionBackend
from
vllm.attention.layer
import
Attention
from
vllm.compilation.counter
import
compilation_counter
from
vllm.config
import
(
CompilationLevel
,
VllmConfig
,
get_layers_from_vllm_config
)
get_layers_from_vllm_config
,
update_config
)
from
vllm.distributed.eplb.eplb_state
import
EplbState
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
has_kv_transfer_group
)
...
...
@@ -1728,6 +1728,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
draft_token_ids
.
append
(
drafter_output
.
tolist
())
return
draft_token_ids
def
update_config
(
self
,
overrides
:
dict
[
str
,
Any
])
->
None
:
allowed_config_names
=
{
"load_config"
,
"model_config"
}
for
config_name
,
config_overrides
in
overrides
.
items
():
assert
config_name
in
allowed_config_names
,
\
f
"Config `
{
config_name
}
` not supported. "
\
f
"Allowed configs:
{
allowed_config_names
}
"
config
=
getattr
(
self
,
config_name
)
new_config
=
update_config
(
config
,
config_overrides
)
setattr
(
self
,
config_name
,
new_config
)
def
load_model
(
self
)
->
None
:
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
with
DeviceMemoryProfiler
()
as
m
:
# noqa: SIM117
...
...
vllm/v1/worker/gpu_worker.py
View file @
8632e831
...
...
@@ -4,7 +4,7 @@
import
copy
import
gc
import
os
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
torch
import
torch.distributed
...
...
@@ -193,6 +193,9 @@ class Worker(WorkerBase):
with
context
:
self
.
model_runner
.
load_model
()
def
update_config
(
self
,
overrides
:
dict
[
str
,
Any
])
->
None
:
self
.
model_runner
.
update_config
(
overrides
)
@
torch
.
inference_mode
()
def
determine_available_memory
(
self
)
->
int
:
"""Profiles the peak memory usage of the model to determine how much
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
8632e831
...
...
@@ -3,7 +3,7 @@
import
bisect
import
gc
import
time
from
typing
import
TYPE_CHECKING
,
Optional
,
cast
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
cast
from
unittest.mock
import
patch
import
numpy
as
np
...
...
@@ -18,7 +18,8 @@ import vllm.envs as envs
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.layer
import
Attention
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
get_layers_from_vllm_config
from
vllm.config
import
(
ParallelConfig
,
VllmConfig
,
get_layers_from_vllm_config
,
update_config
)
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
BaseLayerWithLoRA
...
...
@@ -1111,6 +1112,18 @@ class TPUModelRunner(LoRAModelRunnerMixin):
return
model_runner_output
def
update_config
(
self
,
overrides
:
dict
[
str
,
Any
])
->
None
:
# TODO: TPU config may need extra validation
# https://github.com/vllm-project/vllm/pull/20095#discussion_r2201497754
allowed_config_names
=
{
"load_config"
,
"model_config"
}
for
config_name
,
config_overrides
in
overrides
.
items
():
assert
config_name
in
allowed_config_names
,
\
f
"Config `
{
config_name
}
` not supported. "
\
f
"Allowed configs:
{
allowed_config_names
}
"
config
=
getattr
(
self
,
config_name
)
new_config
=
update_config
(
config
,
config_overrides
)
setattr
(
self
,
config_name
,
new_config
)
def
load_model
(
self
)
->
None
:
self
.
device
=
self
.
device_config
.
device
...
...
vllm/v1/worker/tpu_worker.py
View file @
8632e831
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A TPU worker class."""
import
os
from
typing
import
Optional
from
typing
import
Any
,
Optional
import
torch
import
torch.distributed
...
...
@@ -260,6 +260,9 @@ class TPUWorker:
def
load_model
(
self
)
->
None
:
self
.
model_runner
.
load_model
()
def
update_config
(
self
,
overrides
:
dict
[
str
,
Any
])
->
None
:
self
.
model_runner
.
update_config
(
overrides
)
def
compile_or_warm_up_model
(
self
)
->
None
:
if
not
self
.
model_config
.
enforce_eager
:
self
.
model_runner
.
capture_model
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment