"vscode:/vscode.git/clone" did not exist on "7ab80a8e37ac62f73e09fdc9ba7f69dd09cda2a8"
Unverified commit 18bd7587, authored by youkaichao and committed by GitHub
Browse files

[1/N] pass the complete config from engine to executor (#9933)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 598b6d7b
...@@ -680,7 +680,7 @@ class AsyncLLMEngine(EngineClient): ...@@ -680,7 +680,7 @@ class AsyncLLMEngine(EngineClient):
# Create the async LLM engine. # Create the async LLM engine.
engine = cls( engine = cls(
**engine_config.to_dict(), vllm_config=engine_config,
executor_class=executor_class, executor_class=executor_class,
log_requests=not engine_args.disable_log_requests, log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats, log_stats=not engine_args.disable_log_stats,
......
...@@ -13,11 +13,8 @@ import torch ...@@ -13,11 +13,8 @@ import torch
from typing_extensions import TypeIs, TypeVar from typing_extensions import TypeIs, TypeVar
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, SchedulerConfig)
ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs) SchedulerOutputs)
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
...@@ -222,17 +219,7 @@ class LLMEngine: ...@@ -222,17 +219,7 @@ class LLMEngine:
def __init__( def __init__(
self, self,
model_config: ModelConfig, vllm_config: EngineConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
observability_config: Optional[ObservabilityConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
executor_class: Type[ExecutorBase], executor_class: Type[ExecutorBase],
log_stats: bool, log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
...@@ -240,6 +227,22 @@ class LLMEngine: ...@@ -240,6 +227,22 @@ class LLMEngine:
input_registry: InputRegistry = INPUT_REGISTRY, input_registry: InputRegistry = INPUT_REGISTRY,
use_cached_outputs: bool = False, use_cached_outputs: bool = False,
) -> None: ) -> None:
# TODO: remove the local variables and use self.* throughout the class.
model_config = self.model_config = vllm_config.model_config
cache_config = self.cache_config = vllm_config.cache_config
lora_config = self.lora_config = vllm_config.lora_config
parallel_config = self.parallel_config = vllm_config.parallel_config
scheduler_config = self.scheduler_config = vllm_config.scheduler_config
device_config = self.device_config = vllm_config.device_config
speculative_config = self.speculative_config = vllm_config.speculative_config # noqa
load_config = self.load_config = vllm_config.load_config
decoding_config = self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa
)
prompt_adapter_config = self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa
observability_config = self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa
)
logger.info( logger.info(
"Initializing an LLM engine (v%s) with config: " "Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, " "model=%r, speculative_config=%r, tokenizer=%r, "
...@@ -340,18 +343,7 @@ class LLMEngine: ...@@ -340,18 +343,7 @@ class LLMEngine:
self.input_processor = input_registry.create_input_processor( self.input_processor = input_registry.create_input_processor(
model_config) model_config)
self.model_executor = executor_class( self.model_executor = executor_class(vllm_config=vllm_config, )
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
lora_config=lora_config,
speculative_config=speculative_config,
load_config=load_config,
prompt_adapter_config=prompt_adapter_config,
observability_config=self.observability_config,
)
if self.model_config.task != "embedding": if self.model_config.task != "embedding":
self._initialize_kv_caches() self._initialize_kv_caches()
...@@ -582,7 +574,7 @@ class LLMEngine: ...@@ -582,7 +574,7 @@ class LLMEngine:
executor_class = cls._get_executor_cls(engine_config) executor_class = cls._get_executor_cls(engine_config)
# Create the LLM engine. # Create the LLM engine.
engine = cls( engine = cls(
**engine_config.to_dict(), vllm_config=engine_config,
executor_class=executor_class, executor_class=executor_class,
log_stats=not engine_args.disable_log_stats, log_stats=not engine_args.disable_log_stats,
usage_context=usage_context, usage_context=usage_context,
......
...@@ -7,8 +7,6 @@ import cloudpickle ...@@ -7,8 +7,6 @@ import cloudpickle
import zmq import zmq
from vllm import AsyncEngineArgs, SamplingParams from vllm import AsyncEngineArgs, SamplingParams
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
...@@ -30,9 +28,6 @@ if VLLM_USE_V1: ...@@ -30,9 +28,6 @@ if VLLM_USE_V1:
else: else:
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
SchedulerConfig, LoRAConfig]
logger = init_logger(__name__) logger = init_logger(__name__)
POLLING_TIMEOUT_MS = 10000 POLLING_TIMEOUT_MS = 10000
...@@ -130,7 +125,7 @@ class MQLLMEngine: ...@@ -130,7 +125,7 @@ class MQLLMEngine:
return cls(ipc_path=ipc_path, return cls(ipc_path=ipc_path,
use_async_sockets=use_async_sockets, use_async_sockets=use_async_sockets,
**engine_config.to_dict(), vllm_config=engine_config,
executor_class=executor_class, executor_class=executor_class,
log_requests=not engine_args.disable_log_requests, log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats, log_stats=not engine_args.disable_log_stats,
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple from typing import List, Optional, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import EngineConfig
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
...@@ -23,27 +20,19 @@ class ExecutorBase(ABC): ...@@ -23,27 +20,19 @@ class ExecutorBase(ABC):
def __init__( def __init__(
self, self,
model_config: ModelConfig, vllm_config: EngineConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
observability_config: Optional[ObservabilityConfig],
) -> None: ) -> None:
self.model_config = model_config self.vllm_config = vllm_config
self.cache_config = cache_config self.model_config = vllm_config.model_config
self.lora_config = lora_config self.cache_config = vllm_config.cache_config
self.load_config = load_config self.lora_config = vllm_config.lora_config
self.parallel_config = parallel_config self.load_config = vllm_config.load_config
self.scheduler_config = scheduler_config self.parallel_config = vllm_config.parallel_config
self.device_config = device_config self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = speculative_config self.device_config = vllm_config.device_config
self.prompt_adapter_config = prompt_adapter_config self.speculative_config = vllm_config.speculative_config
self.observability_config = observability_config self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self._init_executor() self._init_executor()
@abstractmethod @abstractmethod
......
...@@ -2,10 +2,7 @@ from typing import Callable, List, Optional, Tuple, Type, Union ...@@ -2,10 +2,7 @@ from typing import Callable, List, Optional, Tuple, Type, Union
import torch import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import ModelConfig, ParallelConfig
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -21,38 +18,13 @@ class XPUExecutor(GPUExecutor): ...@@ -21,38 +18,13 @@ class XPUExecutor(GPUExecutor):
uses_ray: bool = False uses_ray: bool = False
def __init__( def _init_executor(self) -> None:
self, assert self.device_config.device_type == "xpu"
model_config: ModelConfig, assert self.speculative_config is None, (
cache_config: CacheConfig, "Speculative decoding not yet supported for XPU backend")
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, self.model_config = _verify_and_get_model_config(self.model_config)
device_config: DeviceConfig, GPUExecutor._init_executor(self)
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
speculative_config: Optional[SpeculativeConfig],
observability_config: Optional[ObservabilityConfig],
) -> None:
assert device_config.device_type == "xpu"
assert (not speculative_config
), "Speculative decoding not yet supported for XPU backend"
model_config = _verify_and_get_model_config(model_config)
self.model_config = model_config
self.cache_config = cache_config
self.load_config = load_config
self.lora_config = lora_config
self.parallel_config = _verify_and_get_parallel_config(parallel_config)
self.scheduler_config = scheduler_config
self.device_config = device_config
self.prompt_adapter_config = prompt_adapter_config
self.speculative_config = None
self.observability_config = observability_config
# Instantiate the worker and load the model to GPU.
self._init_executor()
def _get_worker_module_and_class( def _get_worker_module_and_class(
self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]: self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
......
...@@ -2,11 +2,8 @@ import time ...@@ -2,11 +2,8 @@ import time
from typing import (Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type, from typing import (Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type,
Union) Union)
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, SchedulerConfig)
ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase from vllm.engine.metrics_types import StatLoggerBase
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
...@@ -35,17 +32,7 @@ class LLMEngine: ...@@ -35,17 +32,7 @@ class LLMEngine:
def __init__( def __init__(
self, self,
model_config: ModelConfig, vllm_config: EngineConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
observability_config: Optional[ObservabilityConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
executor_class: Type[GPUExecutor], executor_class: Type[GPUExecutor],
log_stats: bool, log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
...@@ -53,6 +40,22 @@ class LLMEngine: ...@@ -53,6 +40,22 @@ class LLMEngine:
input_registry: InputRegistry = INPUT_REGISTRY, input_registry: InputRegistry = INPUT_REGISTRY,
use_cached_outputs: bool = False, use_cached_outputs: bool = False,
) -> None: ) -> None:
# TODO: remove the local variables and use self.* throughout the class.
model_config = self.model_config = vllm_config.model_config
cache_config = self.cache_config = vllm_config.cache_config
lora_config = self.lora_config = vllm_config.lora_config
parallel_config = self.parallel_config = vllm_config.parallel_config
scheduler_config = self.scheduler_config = vllm_config.scheduler_config
device_config = self.device_config = vllm_config.device_config
speculative_config = self.speculative_config = vllm_config.speculative_config # noqa
load_config = self.load_config = vllm_config.load_config
decoding_config = self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa
)
prompt_adapter_config = self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa
observability_config = self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa
)
# Override the configs for V1. # Override the configs for V1.
# FIXME # FIXME
if usage_context == UsageContext.LLM_CLASS: if usage_context == UsageContext.LLM_CLASS:
...@@ -112,18 +115,6 @@ class LLMEngine: ...@@ -112,18 +115,6 @@ class LLMEngine:
model_config.mm_processor_kwargs, model_config.mm_processor_kwargs,
) )
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.speculative_config = speculative_config
self.load_config = load_config
self.decoding_config = decoding_config or DecodingConfig()
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config or ObservabilityConfig(
)
self.log_stats = log_stats self.log_stats = log_stats
assert not self.model_config.skip_tokenizer_init assert not self.model_config.skip_tokenizer_init
...@@ -154,18 +145,7 @@ class LLMEngine: ...@@ -154,18 +145,7 @@ class LLMEngine:
# Request id -> RequestOutput # Request id -> RequestOutput
self.request_outputs: Dict[str, RequestOutput] = {} self.request_outputs: Dict[str, RequestOutput] = {}
self.model_executor = executor_class( self.model_executor = executor_class(vllm_config=vllm_config)
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
lora_config=lora_config,
speculative_config=speculative_config,
load_config=load_config,
prompt_adapter_config=prompt_adapter_config,
observability_config=self.observability_config,
)
assert self.model_config.task != "embedding" assert self.model_config.task != "embedding"
self._initialize_kv_caches() self._initialize_kv_caches()
...@@ -203,7 +183,7 @@ class LLMEngine: ...@@ -203,7 +183,7 @@ class LLMEngine:
executor_class = cls._get_executor_cls(engine_config) executor_class = cls._get_executor_cls(engine_config)
# Create the LLM engine. # Create the LLM engine.
engine = cls( engine = cls(
**engine_config.to_dict(), vllm_config=engine_config,
executor_class=executor_class, executor_class=executor_class,
log_stats=not engine_args.disable_log_stats, log_stats=not engine_args.disable_log_stats,
usage_context=usage_context, usage_context=usage_context,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment