Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8632e831
Unverified
Commit
8632e831
authored
Jul 13, 2025
by
22quinn
Committed by
GitHub
Jul 14, 2025
Browse files
[Core] Add `update_config` RPC method (#20095)
Signed-off-by:
22quinn
<
33176974+22quinn@users.noreply.github.com
>
parent
4bbfc36b
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
97 additions
and
9 deletions
+97
-9
tests/test_config.py
tests/test_config.py
+29
-1
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+14
-2
vllm/config.py
vllm/config.py
+20
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+11
-1
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+4
-1
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+15
-2
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+4
-1
No files found.
tests/test_config.py
View file @
8632e831
...
@@ -7,7 +7,7 @@ import pytest
...
@@ -7,7 +7,7 @@ import pytest
from
vllm.compilation.backends
import
VllmBackend
from
vllm.compilation.backends
import
VllmBackend
from
vllm.config
import
(
LoadConfig
,
ModelConfig
,
PoolerConfig
,
VllmConfig
,
from
vllm.config
import
(
LoadConfig
,
ModelConfig
,
PoolerConfig
,
VllmConfig
,
get_field
)
get_field
,
update_config
)
from
vllm.model_executor.layers.pooler
import
PoolingType
from
vllm.model_executor.layers.pooler
import
PoolingType
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -46,6 +46,34 @@ def test_get_field():
...
@@ -46,6 +46,34 @@ def test_get_field():
assert
c
.
default_factory
is
MISSING
assert
c
.
default_factory
is
MISSING
@dataclass
class _TestNestedConfig:
    """Test-only dataclass whose single field holds another dataclass."""
    # Built through a factory so every instance gets its own
    # `_TestConfigFields(a=0)` rather than sharing one default object.
    a: _TestConfigFields = field(
        default_factory=lambda: _TestConfigFields(a=0))
def test_update_config():
    """Exercise `update_config` on flat and nested dataclass configs."""
    # Simple update
    config1 = _TestConfigFields(a=0)
    new_config1 = update_config(config1, {"a": 42})
    assert new_config1.a == 42
    # The update is returned as a new object; the input stays untouched.
    assert config1.a == 0
    # Nonexistent field
    with pytest.raises(AssertionError):
        update_config(config1, {"nonexistent": 1})
    # Nested update with dataclass
    config2 = _TestNestedConfig()
    new_inner_config = _TestConfigFields(a=1, c="new_value")
    new_config2 = update_config(config2, {"a": new_inner_config})
    assert new_config2.a == new_inner_config
    # Nested update with dict
    config3 = _TestNestedConfig()
    new_config3 = update_config(config3, {"a": {"c": "new_value"}})
    assert new_config3.a.c == "new_value"
    # Fields of the nested config that were not overridden are preserved.
    assert new_config3.a.a == 0
    # Nested update with invalid type
    with pytest.raises(AssertionError):
        update_config(config3, {"a": "new_value"})
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"model_id"
,
"expected_runner_type"
,
"expected_task"
),
(
"model_id"
,
"expected_runner_type"
,
"expected_task"
),
[
[
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
8632e831
...
@@ -434,16 +434,28 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
...
@@ -434,16 +434,28 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
assert
all
(
not
kv
.
is_contiguous
()
for
kv
in
model_runner
.
kv_caches
)
assert
all
(
not
kv
.
is_contiguous
()
for
kv
in
model_runner
.
kv_caches
)
def test_update_config(model_runner):
    """Check that the runner applies known overrides and rejects others."""
    # Simple update
    dummy_load_overrides = {"load_config": {"load_format": "dummy"}}
    model_runner.update_config(dummy_load_overrides)
    assert model_runner.load_config.load_format == "dummy"
    # Raise error on non-existing config
    unknown_overrides = {"do_not_exist_config": "dummy"}
    with pytest.raises(AssertionError):
        model_runner.update_config(unknown_overrides)
def
test_load_model_weights_inplace
(
dist_init
,
model_runner
,
model_runner_2
):
def
test_load_model_weights_inplace
(
dist_init
,
model_runner
,
model_runner_2
):
# In this test, model_runner loads model + weights in one go, while
# In this test, model_runner loads model + weights in one go, while
# model_runner_2 loads dummy weights first then load real weights inplace
# model_runner_2 loads dummy weights first then load real weights inplace
model_runner
.
load_model
()
model_runner
.
load_model
()
original_load_format
=
model_runner_2
.
load_config
.
load_format
original_load_format
=
model_runner_2
.
load_config
.
load_format
model_runner_2
.
load_config
.
load_format
=
"dummy"
model_runner_2
.
update_config
({
"
load_config
"
:
{
"
load_format
"
:
"dummy"
}})
model_runner_2
.
load_model
()
# Initial model loading with dummy weights
model_runner_2
.
load_model
()
# Initial model loading with dummy weights
assert
str
(
model_runner
.
get_model
().
state_dict
())
!=
str
(
assert
str
(
model_runner
.
get_model
().
state_dict
())
!=
str
(
model_runner_2
.
get_model
().
state_dict
())
model_runner_2
.
get_model
().
state_dict
())
model_runner_2
.
load_config
.
load_format
=
original_load_format
model_runner_2
.
update_config
(
{
"load_config"
:
{
"load_format"
:
original_load_format
}})
model_runner_2
.
load_model
()
# Load real weights inplace
model_runner_2
.
load_model
()
# Load real weights inplace
assert
str
(
model_runner
.
get_model
().
state_dict
())
==
str
(
assert
str
(
model_runner
.
get_model
().
state_dict
())
==
str
(
model_runner_2
.
get_model
().
state_dict
())
model_runner_2
.
get_model
().
state_dict
())
...
...
vllm/config.py
View file @
8632e831
...
@@ -71,6 +71,7 @@ if TYPE_CHECKING:
...
@@ -71,6 +71,7 @@ if TYPE_CHECKING:
ConfigType
=
type
[
DataclassInstance
]
ConfigType
=
type
[
DataclassInstance
]
HfOverrides
=
Union
[
dict
,
Callable
[[
type
],
type
]]
HfOverrides
=
Union
[
dict
,
Callable
[[
type
],
type
]]
else
:
else
:
DataclassInstance
=
Any
PlacementGroup
=
Any
PlacementGroup
=
Any
PretrainedConfig
=
Any
PretrainedConfig
=
Any
ExecutorBase
=
Any
ExecutorBase
=
Any
...
@@ -87,7 +88,7 @@ else:
...
@@ -87,7 +88,7 @@ else:
"vllm.model_executor.models"
)
"vllm.model_executor.models"
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
DataclassInstanceT
=
TypeVar
(
"DataclassInstanceT"
,
bound
=
DataclassInstance
)
ConfigT
=
TypeVar
(
"ConfigT"
,
bound
=
ConfigType
)
ConfigT
=
TypeVar
(
"ConfigT"
,
bound
=
ConfigType
)
TaskOption
=
Literal
[
"auto"
,
"generate"
,
"embedding"
,
"embed"
,
"classify"
,
TaskOption
=
Literal
[
"auto"
,
"generate"
,
"embedding"
,
"embed"
,
"classify"
,
...
@@ -5049,3 +5050,21 @@ class SpeechToTextConfig:
...
@@ -5049,3 +5050,21 @@ class SpeechToTextConfig:
@
property
@
property
def
allow_audio_chunking
(
self
)
->
bool
:
def
allow_audio_chunking
(
self
)
->
bool
:
return
self
.
min_energy_split_window_size
is
not
None
return
self
.
min_energy_split_window_size
is
not
None
def update_config(config: DataclassInstanceT,
                  overrides: dict[str, Any]) -> DataclassInstanceT:
    """Return a copy of ``config`` with ``overrides`` applied.

    The input is not modified: the result is built with
    ``dataclasses.replace``. When the current value of a field is itself a
    dataclass and the override is a plain dict, the dict is applied
    recursively to that nested dataclass.

    Args:
        config: Dataclass instance to derive the new config from.
        overrides: Mapping of field name to new value (or, for nested
            dataclass fields, a dict of nested overrides).

    Returns:
        A new instance of the same type as ``config``.

    Raises:
        AssertionError: If ``config`` has no attribute named like an
            override key, or if a nested dataclass field is overridden
            with something that is neither a dataclass nor a dict.
    """
    processed_overrides = {}
    for field_name, value in overrides.items():
        # Raised explicitly instead of via `assert` so the validation is
        # not stripped when Python runs with -O; the exception type that
        # callers and tests rely on stays AssertionError.
        if not hasattr(config, field_name):
            raise AssertionError(
                f"{type(config)} has no field `{field_name}`")
        current_value = getattr(config, field_name)
        if is_dataclass(current_value) and not is_dataclass(value):
            if not isinstance(value, dict):
                raise AssertionError(
                    f"Overrides to {type(config)}.{field_name} must be a dict"
                    f" or {type(current_value)}, but got {type(value)}")
            # Dict override of a nested dataclass: merge it into the
            # existing nested value rather than replacing it wholesale.
            value = update_config(
                current_value,  # type: ignore[type-var]
                value)
        processed_overrides[field_name] = value
    return replace(config, **processed_overrides)
vllm/v1/worker/gpu_model_runner.py
View file @
8632e831
...
@@ -19,7 +19,7 @@ from vllm.attention.backends.abstract import AttentionBackend
...
@@ -19,7 +19,7 @@ from vllm.attention.backends.abstract import AttentionBackend
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.config
import
(
CompilationLevel
,
VllmConfig
,
from
vllm.config
import
(
CompilationLevel
,
VllmConfig
,
get_layers_from_vllm_config
)
get_layers_from_vllm_config
,
update_config
)
from
vllm.distributed.eplb.eplb_state
import
EplbState
from
vllm.distributed.eplb.eplb_state
import
EplbState
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
has_kv_transfer_group
)
has_kv_transfer_group
)
...
@@ -1728,6 +1728,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1728,6 +1728,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
draft_token_ids
.
append
(
drafter_output
.
tolist
())
draft_token_ids
.
append
(
drafter_output
.
tolist
())
return
draft_token_ids
return
draft_token_ids
def update_config(self, overrides: dict[str, Any]) -> None:
    """Replace selected config objects on this runner with updated copies.

    Args:
        overrides: Mapping of config attribute name (``"load_config"`` or
            ``"model_config"``) to the overrides forwarded to the
            module-level ``update_config`` helper for that config.

    Raises:
        AssertionError: If a key is not one of the allowed config names.
            Errors raised by the ``update_config`` helper propagate too.
    """
    allowed_config_names = {"load_config", "model_config"}
    # Validate every name and build every replacement before touching
    # `self`, so a rejected entry cannot leave earlier ones applied.
    new_configs = {}
    for config_name, config_overrides in overrides.items():
        if config_name not in allowed_config_names:
            # Raised explicitly rather than via `assert`: this allow-list
            # must still be enforced when Python runs with -O.
            raise AssertionError(
                f"Config `{config_name}` not supported. "
                f"Allowed configs: {allowed_config_names}")
        new_configs[config_name] = update_config(
            getattr(self, config_name), config_overrides)
    # NOTE(review): only the attribute on `self` is rebound; confirm that
    # nothing else keeps a reference to the previous config object.
    for config_name, new_config in new_configs.items():
        setattr(self, config_name, new_config)
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
with
DeviceMemoryProfiler
()
as
m
:
# noqa: SIM117
with
DeviceMemoryProfiler
()
as
m
:
# noqa: SIM117
...
...
vllm/v1/worker/gpu_worker.py
View file @
8632e831
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
import
copy
import
copy
import
gc
import
gc
import
os
import
os
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
torch
import
torch
import
torch.distributed
import
torch.distributed
...
@@ -193,6 +193,9 @@ class Worker(WorkerBase):
...
@@ -193,6 +193,9 @@ class Worker(WorkerBase):
with
context
:
with
context
:
self
.
model_runner
.
load_model
()
self
.
model_runner
.
load_model
()
def update_config(self, overrides: dict[str, Any]) -> None:
    """Forward config overrides to the model runner.

    Args:
        overrides: Passed unchanged to ``self.model_runner.update_config``.
    """
    self.model_runner.update_config(overrides)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
determine_available_memory
(
self
)
->
int
:
def
determine_available_memory
(
self
)
->
int
:
"""Profiles the peak memory usage of the model to determine how much
"""Profiles the peak memory usage of the model to determine how much
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
8632e831
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
import
bisect
import
bisect
import
gc
import
gc
import
time
import
time
from
typing
import
TYPE_CHECKING
,
Optional
,
cast
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
cast
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
numpy
as
np
import
numpy
as
np
...
@@ -18,7 +18,8 @@ import vllm.envs as envs
...
@@ -18,7 +18,8 @@ import vllm.envs as envs
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
get_layers_from_vllm_config
from
vllm.config
import
(
ParallelConfig
,
VllmConfig
,
get_layers_from_vllm_config
,
update_config
)
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
BaseLayerWithLoRA
from
vllm.lora.layers
import
BaseLayerWithLoRA
...
@@ -1111,6 +1112,18 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1111,6 +1112,18 @@ class TPUModelRunner(LoRAModelRunnerMixin):
return
model_runner_output
return
model_runner_output
def update_config(self, overrides: dict[str, Any]) -> None:
    """Replace selected config objects on this runner with updated copies.

    Args:
        overrides: Mapping of config attribute name (``"load_config"`` or
            ``"model_config"``) to the overrides forwarded to the
            module-level ``update_config`` helper for that config.

    Raises:
        AssertionError: If a key is not one of the allowed config names.
            Errors raised by the ``update_config`` helper propagate too.
    """
    # TODO: TPU config may need extra validation
    # https://github.com/vllm-project/vllm/pull/20095#discussion_r2201497754
    allowed_config_names = {"load_config", "model_config"}
    # Validate every name and build every replacement before touching
    # `self`, so a rejected entry cannot leave earlier ones applied.
    new_configs = {}
    for config_name, config_overrides in overrides.items():
        if config_name not in allowed_config_names:
            # Raised explicitly rather than via `assert`: this allow-list
            # must still be enforced when Python runs with -O.
            raise AssertionError(
                f"Config `{config_name}` not supported. "
                f"Allowed configs: {allowed_config_names}")
        new_configs[config_name] = update_config(
            getattr(self, config_name), config_overrides)
    # NOTE(review): only the attribute on `self` is rebound; confirm that
    # nothing else keeps a reference to the previous config object.
    for config_name, new_config in new_configs.items():
        setattr(self, config_name, new_config)
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
self
.
device
=
self
.
device_config
.
device
self
.
device
=
self
.
device_config
.
device
...
...
vllm/v1/worker/tpu_worker.py
View file @
8632e831
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A TPU worker class."""
"""A TPU worker class."""
import
os
import
os
from
typing
import
Optional
from
typing
import
Any
,
Optional
import
torch
import
torch
import
torch.distributed
import
torch.distributed
...
@@ -260,6 +260,9 @@ class TPUWorker:
...
@@ -260,6 +260,9 @@ class TPUWorker:
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
self
.
model_runner
.
load_model
()
self
.
model_runner
.
load_model
()
def update_config(self, overrides: dict[str, Any]) -> None:
    """Forward config overrides to the model runner.

    Args:
        overrides: Passed unchanged to ``self.model_runner.update_config``.
    """
    self.model_runner.update_config(overrides)
def
compile_or_warm_up_model
(
self
)
->
None
:
def
compile_or_warm_up_model
(
self
)
->
None
:
if
not
self
.
model_config
.
enforce_eager
:
if
not
self
.
model_config
.
enforce_eager
:
self
.
model_runner
.
capture_model
()
self
.
model_runner
.
capture_model
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment