Unverified commit e8937954, authored by youkaichao, committed by GitHub
Browse files

[2/N] executor pass the complete config to worker/modelrunner (#9938)


Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
parent 1d4cfe2b
...@@ -7,10 +7,7 @@ import torch ...@@ -7,10 +7,7 @@ import torch
import torch.distributed import torch.distributed
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import ParallelConfig, VllmConfig
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment, init_distributed_environment,
set_custom_all_reduce) set_custom_all_reduce)
...@@ -27,7 +24,8 @@ from vllm.worker.cache_engine import CacheEngine ...@@ -27,7 +24,8 @@ from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -42,46 +40,31 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -42,46 +40,31 @@ class Worker(LocalOrDistributedWorkerBase):
def __init__( def __init__(
self, self,
model_config: ModelConfig, vllm_config: VllmConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
local_rank: int, local_rank: int,
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False, is_driver_worker: bool = False,
model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
observability_config: Optional[ObservabilityConfig] = None,
) -> None: ) -> None:
self.model_config = model_config WorkerBase.__init__(self, vllm_config)
self.parallel_config = parallel_config
self.parallel_config.rank = rank self.parallel_config.rank = rank
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.local_rank = local_rank self.local_rank = local_rank
self.rank = rank self.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.load_config = load_config
self.prompt_adapter_config = prompt_adapter_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
if parallel_config and is_driver_worker: if is_driver_worker:
assert rank % parallel_config.tensor_parallel_size == 0, \ assert rank % self.parallel_config.tensor_parallel_size == 0, \
"Driver worker should be rank 0 of tensor parallel group." "Driver worker should be rank 0 of tensor parallel group."
if self.model_config.trust_remote_code: if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing # note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules from vllm.utils import init_cached_hf_modules
init_cached_hf_modules() init_cached_hf_modules()
self.observability_config = observability_config
# Return hidden states from target model if the draft model is an # Return hidden states from target model if the draft model is an
# mlp_speculator # mlp_speculator
speculative_config = self.speculative_config
model_config = self.model_config
speculative_args = {} if speculative_config is None \ speculative_args = {} if speculative_config is None \
or (speculative_config.draft_model_config.model == or (speculative_config.draft_model_config.model ==
model_config.model) \ model_config.model) \
...@@ -97,17 +80,9 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -97,17 +80,9 @@ class Worker(LocalOrDistributedWorkerBase):
elif self._is_encoder_decoder_model(): elif self._is_encoder_decoder_model():
ModelRunnerClass = EncoderDecoderModelRunner ModelRunnerClass = EncoderDecoderModelRunner
self.model_runner: GPUModelRunnerBase = ModelRunnerClass( self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
model_config, vllm_config=self.vllm_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config=load_config,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype, kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
observability_config=observability_config,
**speculative_args, **speculative_args,
) )
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
......
...@@ -7,7 +7,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union ...@@ -7,7 +7,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
import torch import torch
from vllm.config import ObservabilityConfig from vllm.config import ObservabilityConfig, VllmConfig
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -29,6 +29,22 @@ class WorkerBase(ABC): ...@@ -29,6 +29,22 @@ class WorkerBase(ABC):
communicate request metadata to other workers. communicate request metadata to other workers.
""" """
def __init__(
self,
vllm_config: VllmConfig,
) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
@abstractmethod @abstractmethod
def init_device(self) -> None: def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device """Initialize device state, such as loading the model or other on-device
......
...@@ -10,9 +10,7 @@ import torch ...@@ -10,9 +10,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention import get_attn_backend from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import VllmConfig
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -363,33 +361,18 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): ...@@ -363,33 +361,18 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
def __init__( def __init__(
self, self,
model_config: ModelConfig, vllm_config: VllmConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto", kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False, is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
return_hidden_states: bool = False, return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None,
input_registry: InputRegistry = INPUT_REGISTRY, input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
): ):
self.model_config = model_config
self.parallel_config = parallel_config ModelRunnerBase.__init__(self, vllm_config=vllm_config)
self.scheduler_config = scheduler_config model_config = self.model_config
self.device_config = device_config cache_config = self.cache_config
self.cache_config = cache_config
self.lora_config = lora_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config
if self.observability_config is not None:
print(f"observability_config is {self.observability_config}")
self.return_hidden_states = return_hidden_states self.return_hidden_states = return_hidden_states
self.device = self.device_config.device self.device = self.device_config.device
......
...@@ -8,10 +8,7 @@ import oneccl_bindings_for_pytorch # noqa: F401 ...@@ -8,10 +8,7 @@ import oneccl_bindings_for_pytorch # noqa: F401
import torch import torch
import torch.distributed import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import VllmConfig
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -19,7 +16,7 @@ from vllm.model_executor import set_random_seed ...@@ -19,7 +16,7 @@ from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
from vllm.worker.xpu_model_runner import XPUModelRunner from vllm.worker.xpu_model_runner import XPUModelRunner
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -36,53 +33,32 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker): ...@@ -36,53 +33,32 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
def __init__( def __init__(
self, self,
model_config: ModelConfig, vllm_config: VllmConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
local_rank: int, local_rank: int,
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False, is_driver_worker: bool = False,
observability_config: Optional[ObservabilityConfig] = None,
) -> None: ) -> None:
WorkerBase.__init__(self, vllm_config=vllm_config)
device_config = self.device_config
parallel_config = self.parallel_config
assert device_config.device_type == "xpu" assert device_config.device_type == "xpu"
assert current_platform.is_xpu() assert current_platform.is_xpu()
self.model_config = model_config
self.parallel_config = parallel_config
self.parallel_config.rank = rank self.parallel_config.rank = rank
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
self.load_config = load_config
self.local_rank = local_rank self.local_rank = local_rank
self.rank = rank self.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.prompt_adapter_config = prompt_adapter_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
self.observability_config = observability_config
if parallel_config and is_driver_worker: if parallel_config and is_driver_worker:
assert rank % parallel_config.tensor_parallel_size == 0, \ assert rank % parallel_config.tensor_parallel_size == 0, \
"Driver worker should be rank 0 of tensor parallel group." "Driver worker should be rank 0 of tensor parallel group."
self.model_runner = XPUModelRunner( # type: ignore self.model_runner = XPUModelRunner( # type: ignore
model_config, vllm_config=vllm_config,
parallel_config,
scheduler_config,
device_config,
cache_config,
load_config=self.load_config,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype, kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
observability_config=self.observability_config,
) )
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
# initialize_cache. # initialize_cache.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment