from typing import Any, Dict, List, Optional, Tuple

from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)


class UniProcExecutor(ExecutorBase):
    """Executor that drives a single worker held in ``self.driver_worker``.

    Every "collective" call is forwarded to that one wrapped worker, and the
    result is returned as a one-element list.
    """

    # Class-level flag; this executor sets it to False.
    uses_ray: bool = False

    def _init_executor(self) -> None:
        """Create the worker wrapper, then initialize it and load the model.

        The worker is set up through three successive ``collective_rpc``
        calls: ``init_worker``, ``init_device`` and ``load_model``.
        """
        self.driver_worker = WorkerWrapperBase(
            vllm_config=self.vllm_config,
            rank=0,
        )

        init_method = get_distributed_init_method(get_ip(), get_open_port())

        # There is exactly one worker here, so both ranks are fixed at zero.
        worker_local_rank = 0
        worker_rank = 0

        # Without a parallel config the worker is the driver by definition;
        # the short-circuit also avoids touching attributes of a missing
        # config object.
        acts_as_driver = (
            not self.parallel_config
            or worker_rank % self.parallel_config.tensor_parallel_size == 0
        )

        worker_kwargs = {
            "vllm_config": self.vllm_config,
            "local_rank": worker_local_rank,
            "rank": worker_rank,
            "distributed_init_method": init_method,
            "is_driver_worker": acts_as_driver,
        }

        # ``init_worker`` receives a list of per-worker kwargs dicts as its
        # single positional argument; with one worker the list has one item.
        self.collective_rpc("init_worker", args=([worker_kwargs], ))
        self.collective_rpc("init_device")
        self.collective_rpc("load_model")

    def collective_rpc(self,
                       method: str,
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> List[Any]:
        """Invoke ``method`` on the driver worker and wrap the result.

        Args:
            method: Name of the attribute to look up on the driver worker.
            timeout: Not used by this implementation (presumably kept for
                signature parity with other executors -- TODO confirm).
            args: Positional arguments forwarded to the method.
            kwargs: Keyword arguments forwarded to the method; ``None`` is
                treated as no keyword arguments.

        Returns:
            A single-element list containing the method's return value.

        Raises:
            NotImplementedError: If the driver worker has no attribute named
                ``method``.
        """
        call_kwargs = {} if kwargs is None else kwargs

        # Only the attribute lookup is guarded, so an AttributeError raised
        # inside the invoked method itself still propagates unchanged.
        try:
            target = getattr(self.driver_worker, method)
        except AttributeError:
            raise NotImplementedError(
                f"Method {method} is not implemented.") from None

        return [target(*args, **call_kwargs)]

    def check_health(self) -> None:
        """Perform no check; this executor is healthy whenever it runs."""
        # UniProcExecutor will always be healthy as long as
        # it's running.
        return None


# The async executor name is an alias of the very same class.
UniProcExecutorAsync = UniProcExecutor