Unverified Commit e8937954 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[2/N] executor pass the complete config to worker/modelrunner (#9938)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
Co-authored-by: default avatarNick Hill <nhill@redhat.com>
parent 1d4cfe2b
......@@ -7,10 +7,7 @@ import torch
import torch.distributed
import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
......@@ -27,7 +24,8 @@ from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)
logger = init_logger(__name__)
......@@ -42,46 +40,31 @@ class Worker(LocalOrDistributedWorkerBase):
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
observability_config: Optional[ObservabilityConfig] = None,
) -> None:
self.model_config = model_config
self.parallel_config = parallel_config
WorkerBase.__init__(self, vllm_config)
self.parallel_config.rank = rank
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.load_config = load_config
self.prompt_adapter_config = prompt_adapter_config
self.is_driver_worker = is_driver_worker
if parallel_config and is_driver_worker:
assert rank % parallel_config.tensor_parallel_size == 0, \
if is_driver_worker:
assert rank % self.parallel_config.tensor_parallel_size == 0, \
"Driver worker should be rank 0 of tensor parallel group."
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.observability_config = observability_config
# Return hidden states from target model if the draft model is an
# mlp_speculator
speculative_config = self.speculative_config
model_config = self.model_config
speculative_args = {} if speculative_config is None \
or (speculative_config.draft_model_config.model ==
model_config.model) \
......@@ -97,17 +80,9 @@ class Worker(LocalOrDistributedWorkerBase):
elif self._is_encoder_decoder_model():
ModelRunnerClass = EncoderDecoderModelRunner
self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config=load_config,
lora_config=self.lora_config,
vllm_config=self.vllm_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
observability_config=observability_config,
**speculative_args,
)
# Uninitialized cache engine. Will be initialized by
......
......@@ -7,7 +7,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
import torch
from vllm.config import ObservabilityConfig
from vllm.config import ObservabilityConfig, VllmConfig
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
......@@ -29,6 +29,22 @@ class WorkerBase(ABC):
communicate request metadata to other workers.
"""
def __init__(
self,
vllm_config: VllmConfig,
) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
@abstractmethod
def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device
......
......@@ -10,9 +10,7 @@ import torch
import torch.nn as nn
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
......@@ -363,33 +361,18 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vllm_config: VllmConfig,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.lora_config = lora_config
self.load_config = load_config
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
model_config = self.model_config
cache_config = self.cache_config
self.is_driver_worker = is_driver_worker
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config
if self.observability_config is not None:
print(f"observability_config is {self.observability_config}")
self.return_hidden_states = return_hidden_states
self.device = self.device_config.device
......
......@@ -8,10 +8,7 @@ import oneccl_bindings_for_pytorch # noqa: F401
import torch
import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
......@@ -19,7 +16,7 @@ from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
from vllm.worker.xpu_model_runner import XPUModelRunner
logger = init_logger(__name__)
......@@ -36,53 +33,32 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
observability_config: Optional[ObservabilityConfig] = None,
) -> None:
WorkerBase.__init__(self, vllm_config=vllm_config)
device_config = self.device_config
parallel_config = self.parallel_config
assert device_config.device_type == "xpu"
assert current_platform.is_xpu()
self.model_config = model_config
self.parallel_config = parallel_config
self.parallel_config.rank = rank
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.prompt_adapter_config = prompt_adapter_config
self.is_driver_worker = is_driver_worker
self.observability_config = observability_config
if parallel_config and is_driver_worker:
assert rank % parallel_config.tensor_parallel_size == 0, \
"Driver worker should be rank 0 of tensor parallel group."
self.model_runner = XPUModelRunner( # type: ignore
model_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config=self.load_config,
lora_config=self.lora_config,
vllm_config=vllm_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
observability_config=self.observability_config,
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment