Unverified commit 8632e831, authored by 22quinn and committed by GitHub
Browse files

[Core] Add `update_config` RPC method (#20095)


Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
parent 4bbfc36b
...@@ -7,7 +7,7 @@ import pytest ...@@ -7,7 +7,7 @@ import pytest
from vllm.compilation.backends import VllmBackend from vllm.compilation.backends import VllmBackend
from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig, from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
get_field) get_field, update_config)
from vllm.model_executor.layers.pooler import PoolingType from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -46,6 +46,34 @@ def test_get_field(): ...@@ -46,6 +46,34 @@ def test_get_field():
assert c.default_factory is MISSING assert c.default_factory is MISSING
@dataclass
class _TestNestedConfig:
    """Test-only config whose single field `a` holds a `_TestConfigFields`."""

    # The inner config is produced by a factory, so it is constructed with
    # `a=0` each time a `_TestNestedConfig` is created without an explicit
    # value for this field.
    a: _TestConfigFields = field(
        default_factory=lambda: _TestConfigFields(a=0))
def test_update_config():
    """Exercise `update_config` on flat and nested dataclass configs."""
    # A top-level field is replaced by a plain value.
    flat_config = _TestConfigFields(a=0)
    assert update_config(flat_config, {"a": 42}).a == 42

    # Overriding an attribute the config does not have is rejected.
    with pytest.raises(AssertionError):
        update_config(flat_config, {"nonexistent": 1})

    # A nested config can be swapped for another dataclass instance.
    replacement = _TestConfigFields(a=1, c="new_value")
    swapped = update_config(_TestNestedConfig(), {"a": replacement})
    assert swapped.a == replacement

    # A dict override is applied to the nested config.
    nested_config = _TestNestedConfig()
    patched = update_config(nested_config, {"a": {"c": "new_value"}})
    assert patched.a.c == "new_value"

    # A nested config cannot be overridden with a value that is neither a
    # dict nor a dataclass.
    with pytest.raises(AssertionError):
        update_config(nested_config, {"a": "new_value"})
@pytest.mark.parametrize( @pytest.mark.parametrize(
("model_id", "expected_runner_type", "expected_task"), ("model_id", "expected_runner_type", "expected_task"),
[ [
......
...@@ -434,16 +434,28 @@ def test_kv_cache_stride_order(monkeypatch, model_runner): ...@@ -434,16 +434,28 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
assert all(not kv.is_contiguous() for kv in model_runner.kv_caches) assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
def test_update_config(model_runner):
    """Check `model_runner.update_config` for a valid and an invalid name."""
    # Overriding a field of an existing config is visible on the runner.
    load_format_override = {"load_config": {"load_format": "dummy"}}
    model_runner.update_config(load_format_override)
    assert model_runner.load_config.load_format == "dummy"

    # A config name the runner does not accept is rejected.
    with pytest.raises(AssertionError):
        model_runner.update_config({"do_not_exist_config": "dummy"})
def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2): def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
# In this test, model_runner loads model + weights in one go, while # In this test, model_runner loads model + weights in one go, while
# model_runner_2 loads dummy weights first then load real weights inplace # model_runner_2 loads dummy weights first then load real weights inplace
model_runner.load_model() model_runner.load_model()
original_load_format = model_runner_2.load_config.load_format original_load_format = model_runner_2.load_config.load_format
model_runner_2.load_config.load_format = "dummy" model_runner_2.update_config({"load_config": {"load_format": "dummy"}})
model_runner_2.load_model() # Initial model loading with dummy weights model_runner_2.load_model() # Initial model loading with dummy weights
assert str(model_runner.get_model().state_dict()) != str( assert str(model_runner.get_model().state_dict()) != str(
model_runner_2.get_model().state_dict()) model_runner_2.get_model().state_dict())
model_runner_2.load_config.load_format = original_load_format model_runner_2.update_config(
{"load_config": {
"load_format": original_load_format
}})
model_runner_2.load_model() # Load real weights inplace model_runner_2.load_model() # Load real weights inplace
assert str(model_runner.get_model().state_dict()) == str( assert str(model_runner.get_model().state_dict()) == str(
model_runner_2.get_model().state_dict()) model_runner_2.get_model().state_dict())
......
...@@ -71,6 +71,7 @@ if TYPE_CHECKING: ...@@ -71,6 +71,7 @@ if TYPE_CHECKING:
ConfigType = type[DataclassInstance] ConfigType = type[DataclassInstance]
HfOverrides = Union[dict, Callable[[type], type]] HfOverrides = Union[dict, Callable[[type], type]]
else: else:
DataclassInstance = Any
PlacementGroup = Any PlacementGroup = Any
PretrainedConfig = Any PretrainedConfig = Any
ExecutorBase = Any ExecutorBase = Any
...@@ -87,7 +88,7 @@ else: ...@@ -87,7 +88,7 @@ else:
"vllm.model_executor.models") "vllm.model_executor.models")
logger = init_logger(__name__) logger = init_logger(__name__)
DataclassInstanceT = TypeVar("DataclassInstanceT", bound=DataclassInstance)
ConfigT = TypeVar("ConfigT", bound=ConfigType) ConfigT = TypeVar("ConfigT", bound=ConfigType)
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify", TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
...@@ -5049,3 +5050,21 @@ class SpeechToTextConfig: ...@@ -5049,3 +5050,21 @@ class SpeechToTextConfig:
@property @property
def allow_audio_chunking(self) -> bool: def allow_audio_chunking(self) -> bool:
return self.min_energy_split_window_size is not None return self.min_energy_split_window_size is not None
def update_config(config: DataclassInstanceT,
                  overrides: dict[str, Any]) -> DataclassInstanceT:
    """Return a copy of `config` with `overrides` applied.

    `config` is not mutated by this function; the result is built with
    `dataclasses.replace`. When the current value of an attribute is a
    dataclass and the override for it is not, the override must be a dict and
    is applied recursively to that nested dataclass, so a single nested
    setting can be changed without rebuilding the whole inner config.

    Args:
        config: The dataclass instance to derive the new config from.
        overrides: Mapping from attribute name to its new value. A value may
            be a dict of nested overrides when the current value is a
            dataclass.

    Returns:
        The object produced by `dataclasses.replace(config, ...)` with the
        processed overrides.

    Raises:
        AssertionError: If `config` has no attribute named by a key, or if a
            dataclass-valued attribute is overridden with something that is
            neither a dataclass nor a dict.
    """
    processed_overrides = {}
    for field_name, value in overrides.items():
        # Validation is raised explicitly rather than through `assert`, so it
        # still runs under `python -O`. AssertionError is kept as the
        # exception type for backward compatibility.
        if not hasattr(config, field_name):
            raise AssertionError(
                f"{type(config)} has no field `{field_name}`")
        current_value = getattr(config, field_name)
        if is_dataclass(current_value) and not is_dataclass(value):
            if not isinstance(value, dict):
                raise AssertionError(
                    f"Overrides to {type(config)}.{field_name} must be a dict"
                    f" or {type(current_value)}, but got {type(value)}")
            # Recurse so that only the listed nested fields change.
            value = update_config(
                current_value,  # type: ignore[type-var]
                value)
        processed_overrides[field_name] = value
    return replace(config, **processed_overrides)
...@@ -19,7 +19,7 @@ from vllm.attention.backends.abstract import AttentionBackend ...@@ -19,7 +19,7 @@ from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.config import (CompilationLevel, VllmConfig, from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config) get_layers_from_vllm_config, update_config)
from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import (get_kv_transfer_group, from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group) has_kv_transfer_group)
...@@ -1728,6 +1728,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1728,6 +1728,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
draft_token_ids.append(drafter_output.tolist()) draft_token_ids.append(drafter_output.tolist())
return draft_token_ids return draft_token_ids
def update_config(self, overrides: dict[str, Any]) -> None:
    """Replace selected config objects on this runner with updated copies.

    Each key of `overrides` names a config attribute of the runner. Its value
    is forwarded to `vllm.config.update_config` together with the current
    config, and the returned config is stored back under the same name.

    All new configs are computed before any of them is assigned, so an
    invalid override does not leave the runner partially updated.

    Args:
        overrides: Mapping from config name to the overrides for that
            config. Only `load_config` and `model_config` are accepted.

    Raises:
        AssertionError: If a key is not one of the allowed config names, or
            if `vllm.config.update_config` rejects the overrides.
    """
    allowed_config_names = {"load_config", "model_config"}
    new_configs = {}
    for config_name, config_overrides in overrides.items():
        # Raised explicitly rather than through `assert`, so the check still
        # runs under `python -O`; the exception type is unchanged.
        if config_name not in allowed_config_names:
            raise AssertionError(
                f"Config `{config_name}` not supported. "
                f"Allowed configs: {allowed_config_names}")
        config = getattr(self, config_name)
        new_configs[config_name] = update_config(config, config_overrides)
    for config_name, new_config in new_configs.items():
        setattr(self, config_name, new_config)
def load_model(self) -> None: def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model) logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: # noqa: SIM117 with DeviceMemoryProfiler() as m: # noqa: SIM117
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import copy import copy
import gc import gc
import os import os
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Any, Optional
import torch import torch
import torch.distributed import torch.distributed
...@@ -193,6 +193,9 @@ class Worker(WorkerBase): ...@@ -193,6 +193,9 @@ class Worker(WorkerBase):
with context: with context:
self.model_runner.load_model() self.model_runner.load_model()
def update_config(self, overrides: dict[str, Any]) -> None:
    """Forward config `overrides` to this worker's model runner.

    Args:
        overrides: Passed through unchanged to
            `self.model_runner.update_config`.
    """
    self.model_runner.update_config(overrides)
@torch.inference_mode() @torch.inference_mode()
def determine_available_memory(self) -> int: def determine_available_memory(self) -> int:
"""Profiles the peak memory usage of the model to determine how much """Profiles the peak memory usage of the model to determine how much
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import bisect import bisect
import gc import gc
import time import time
from typing import TYPE_CHECKING, Optional, cast from typing import TYPE_CHECKING, Any, Optional, cast
from unittest.mock import patch from unittest.mock import patch
import numpy as np import numpy as np
...@@ -18,7 +18,8 @@ import vllm.envs as envs ...@@ -18,7 +18,8 @@ import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import ParallelConfig, VllmConfig, get_layers_from_vllm_config from vllm.config import (ParallelConfig, VllmConfig,
get_layers_from_vllm_config, update_config)
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA from vllm.lora.layers import BaseLayerWithLoRA
...@@ -1111,6 +1112,18 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -1111,6 +1112,18 @@ class TPUModelRunner(LoRAModelRunnerMixin):
return model_runner_output return model_runner_output
def update_config(self, overrides: dict[str, Any]) -> None:
    """Replace selected config objects on this runner with updated copies.

    Each key of `overrides` names a config attribute of the runner. Its value
    is forwarded to `vllm.config.update_config` together with the current
    config, and the returned config is stored back under the same name.

    All new configs are computed before any of them is assigned, so an
    invalid override does not leave the runner partially updated.

    Args:
        overrides: Mapping from config name to the overrides for that
            config. Only `load_config` and `model_config` are accepted.

    Raises:
        AssertionError: If a key is not one of the allowed config names, or
            if `vllm.config.update_config` rejects the overrides.
    """
    # TODO: TPU config may need extra validation
    # https://github.com/vllm-project/vllm/pull/20095#discussion_r2201497754
    allowed_config_names = {"load_config", "model_config"}
    new_configs = {}
    for config_name, config_overrides in overrides.items():
        # Raised explicitly rather than through `assert`, so the check still
        # runs under `python -O`; the exception type is unchanged.
        if config_name not in allowed_config_names:
            raise AssertionError(
                f"Config `{config_name}` not supported. "
                f"Allowed configs: {allowed_config_names}")
        config = getattr(self, config_name)
        new_configs[config_name] = update_config(config, config_overrides)
    for config_name, new_config in new_configs.items():
        setattr(self, config_name, new_config)
def load_model(self) -> None: def load_model(self) -> None:
self.device = self.device_config.device self.device = self.device_config.device
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A TPU worker class.""" """A TPU worker class."""
import os import os
from typing import Optional from typing import Any, Optional
import torch import torch
import torch.distributed import torch.distributed
...@@ -260,6 +260,9 @@ class TPUWorker: ...@@ -260,6 +260,9 @@ class TPUWorker:
def load_model(self) -> None: def load_model(self) -> None:
self.model_runner.load_model() self.model_runner.load_model()
def update_config(self, overrides: dict[str, Any]) -> None:
    """Forward config `overrides` to this worker's model runner.

    Args:
        overrides: Passed through unchanged to
            `self.model_runner.update_config`.
    """
    self.model_runner.update_config(overrides)
def compile_or_warm_up_model(self) -> None: def compile_or_warm_up_model(self) -> None:
if not self.model_config.enforce_eager: if not self.model_config.enforce_eager:
self.model_runner.capture_model() self.model_runner.capture_model()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment