Unverified Commit e8937954 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[2/N] executor pass the complete config to worker/modelrunner (#9938)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
Co-authored-by: default avatarNick Hill <nhill@redhat.com>
parent 1d4cfe2b
......@@ -138,13 +138,7 @@ def test_rotary_emb_replaced(dist_init):
enable_lora=True)
engine_config = engine_args.create_engine_config()
model_runner = ModelRunner(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
lora_config=engine_config.lora_config,
vllm_config=engine_config,
is_driver_worker=True,
)
model_runner.load_model()
......
......@@ -4,7 +4,8 @@ import tempfile
from unittest.mock import patch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig)
ModelConfig, ParallelConfig, SchedulerConfig,
VllmConfig)
from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.worker.worker import Worker
......@@ -12,7 +13,7 @@ from vllm.worker.worker import Worker
@patch.dict(os.environ, {"RANK": "0"})
def test_worker_apply_lora(sql_lora_files):
worker = Worker(
vllm_config = VllmConfig(
model_config=ModelConfig(
"meta-llama/Llama-2-7b-hf",
task="auto",
......@@ -34,10 +35,13 @@ def test_worker_apply_lora(sql_lora_files):
gpu_memory_utilization=1.,
swap_space=0,
cache_dtype="auto"),
local_rank=0,
rank=0,
lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32,
max_loras=32),
)
worker = Worker(
vllm_config=vllm_config,
local_rank=0,
rank=0,
distributed_init_method=f"file://{tempfile.mkstemp()[1]}",
)
worker.init_device()
......
......@@ -81,12 +81,7 @@ def create_worker(cls: Callable[..., T],
get_ip(), get_open_port())
worker = cls(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
vllm_config=engine_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
......
......@@ -19,14 +19,7 @@ def _create_model_runner(model: str, *args,
engine_args = EngineArgs(model, *args, **kwargs)
engine_config = engine_args.create_engine_config()
model_runner = EncoderDecoderModelRunner(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
lora_config=engine_config.lora_config,
prompt_adapter_config=engine_config.prompt_adapter_config,
vllm_config=engine_config,
is_driver_worker=True,
)
return model_runner
......
......@@ -16,15 +16,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
engine_args = EngineArgs(model, *args, **kwargs)
engine_config = engine_args.create_engine_config()
model_runner = ModelRunner(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
lora_config=engine_config.lora_config,
prompt_adapter_config=engine_config.prompt_adapter_config,
observability_config=engine_config.observability_config,
vllm_config=engine_config,
is_driver_worker=True,
)
return model_runner
......
......@@ -24,12 +24,7 @@ def test_gpu_memory_profiling():
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
worker = Worker(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
vllm_config=engine_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
......
......@@ -19,12 +19,7 @@ def test_swap() -> None:
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
worker = Worker(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
load_config=engine_config.load_config,
vllm_config=engine_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
......
import enum
import json
from dataclasses import dataclass, field, fields
from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
Mapping, Optional, Set, Tuple, Type, Union)
......@@ -1941,9 +1941,9 @@ class ObservabilityConfig:
f"installed. Original error:\n{otel_import_error_traceback}")
@dataclass(frozen=True)
class EngineConfig:
"""Dataclass which contains all engine-related configuration. This
@dataclass
class VllmConfig:
"""Dataclass which contains all vllm-related configuration. This
simplifies passing around the distinct configurations in the codebase.
"""
......@@ -1953,11 +1953,11 @@ class EngineConfig:
scheduler_config: SchedulerConfig
device_config: DeviceConfig
load_config: LoadConfig
lora_config: Optional[LoRAConfig]
speculative_config: Optional[SpeculativeConfig]
decoding_config: Optional[DecodingConfig]
observability_config: Optional[ObservabilityConfig]
prompt_adapter_config: Optional[PromptAdapterConfig]
lora_config: Optional[LoRAConfig] = None
speculative_config: Optional[SpeculativeConfig] = None
decoding_config: Optional[DecodingConfig] = None
observability_config: Optional[ObservabilityConfig] = None
prompt_adapter_config: Optional[PromptAdapterConfig] = None
def __post_init__(self):
"""Verify configs are valid & consistent with each other.
......@@ -1975,9 +1975,3 @@ class EngineConfig:
if self.prompt_adapter_config:
self.prompt_adapter_config.verify_with_model_config(
self.model_config)
def to_dict(self):
"""Return the configs as a dictionary, for use in **kwargs.
"""
return dict(
(field.name, getattr(self, field.name)) for field in fields(self))
......@@ -9,10 +9,11 @@ import torch
import vllm.envs as envs
from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
DeviceConfig, EngineConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TaskOption, TokenizerPoolConfig)
DeviceConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TaskOption, TokenizerPoolConfig,
VllmConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
......@@ -955,7 +956,7 @@ class EngineArgs:
ignore_patterns=self.ignore_patterns,
)
def create_engine_config(self) -> EngineConfig:
def create_engine_config(self) -> VllmConfig:
# gguf file needs a specific model loader and doesn't use hf_repo
if check_gguf_file(self.model):
self.quantization = self.load_format = "gguf"
......@@ -1167,7 +1168,7 @@ class EngineArgs:
or "all" in detailed_trace_modules,
)
return EngineConfig(
return VllmConfig(
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
......
......@@ -7,8 +7,8 @@ from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
from weakref import ReferenceType
import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VllmConfig)
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
......@@ -604,7 +604,7 @@ class AsyncLLMEngine(EngineClient):
@classmethod
def _get_executor_cls(
cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
cls, engine_config: VllmConfig) -> Type[ExecutorAsyncBase]:
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
if isinstance(distributed_executor_backend, type):
......@@ -663,7 +663,7 @@ class AsyncLLMEngine(EngineClient):
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
engine_config: Optional[EngineConfig] = None,
engine_config: Optional[VllmConfig] = None,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
......
......@@ -13,8 +13,9 @@ import torch
from typing_extensions import TypeIs, TypeVar
import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig)
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig,
VllmConfig)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs)
from vllm.engine.arg_utils import EngineArgs
......@@ -219,7 +220,7 @@ class LLMEngine:
def __init__(
self,
vllm_config: EngineConfig,
vllm_config: VllmConfig,
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
......@@ -500,7 +501,7 @@ class LLMEngine:
@classmethod
def _get_executor_cls(cls,
engine_config: EngineConfig) -> Type[ExecutorBase]:
engine_config: VllmConfig) -> Type[ExecutorBase]:
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
# Initialize the cluster and specify the executor class.
......
......@@ -13,7 +13,7 @@ from zmq import Frame # type: ignore[attr-defined]
from zmq.asyncio import Socket
from vllm import PoolingParams
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.config import DecodingConfig, ModelConfig, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
# yapf conflicts with isort for this block
......@@ -78,7 +78,7 @@ class MQLLMEngineClient(EngineClient):
every N seconds, confirming the engine is healthy
"""
def __init__(self, ipc_path: str, engine_config: EngineConfig,
def __init__(self, ipc_path: str, engine_config: VllmConfig,
engine_pid: int):
self.context = zmq.asyncio.Context()
self._errored_with: Optional[BaseException] = None
......
......@@ -138,18 +138,11 @@ class CPUExecutor(ExecutorBase):
assert self.distributed_init_method is not None
kwargs = dict(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=self.distributed_init_method,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=rank == 0,
)
wrapper.init_worker(**kwargs)
......
from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple
from vllm.config import EngineConfig
from vllm.config import VllmConfig
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
......@@ -20,7 +20,7 @@ class ExecutorBase(ABC):
def __init__(
self,
vllm_config: EngineConfig,
vllm_config: VllmConfig,
) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
......
......@@ -49,21 +49,12 @@ class GPUExecutor(ExecutorBase):
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
return dict(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
speculative_config=self.speculative_config,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0),
observability_config=self.observability_config,
)
def _get_worker_module_and_class(
......
......@@ -29,11 +29,7 @@ class NeuronExecutor(ExecutorBase):
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = NeuronWorker(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
vllm_config=self.vllm_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method)
......
......@@ -48,16 +48,10 @@ class OpenVINOExecutor(ExecutorBase):
get_ip(), get_open_port())
self.driver_worker = OpenVINOWorker(
ov_core=self.ov_core,
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
vllm_config=self.vllm_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)
......
......@@ -44,12 +44,7 @@ class TPUExecutor(ExecutorBase):
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
return dict(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
......
......@@ -17,9 +17,6 @@ except (ModuleNotFoundError, ImportError) as err:
"Draft model speculative decoding currently only supports"
"CUDA and ROCm flash attention backend.") from err
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.multimodal import MultiModalInputs
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
......@@ -49,40 +46,13 @@ class TP1DraftModelRunner(ModelRunner):
any broadcasting inside execute_model).
"""
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None,
):
if return_hidden_states:
def __init__(self, *args, **kwargs):
if kwargs.get("return_hidden_states"):
raise ValueError(
"return_hidden_states is not supported for TP1DraftModelRunner."
)
super().__init__(
model_config=model_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
cache_config=cache_config,
load_config=load_config,
lora_config=lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
return_hidden_states=return_hidden_states,
observability_config=observability_config,
)
super().__init__(*args, **kwargs)
def _update_sampling_metadata(self, sampling_metadata, num_seqs,
num_queries):
......
......@@ -21,7 +21,7 @@ class NGramWorker(NonLLMProposerWorkerBase):
def __init__(self, *args, **kwargs):
# Get local_rank/vocab_size from kwargs attribute
self.local_rank = kwargs["local_rank"]
self.vocab_size = kwargs["model_config"].get_vocab_size()
self.vocab_size = kwargs["vllm_config"].model_config.get_vocab_size()
# Lazy initialization list.
self._proposer: Top1Proposer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment