Unverified commit 66a9e713, authored by Antoni Baum and committed by GitHub
Browse files

[Core] Pipe `worker_class_fn` argument in Executor (#7707)

parent 9e51b6a6
from typing import Any, Dict, List, Optional, Set, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -7,15 +7,18 @@ from vllm.prompt_adapter.request import PromptAdapterRequest ...@@ -7,15 +7,18 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async) make_async)
from vllm.worker.worker_base import WorkerWrapperBase from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase
logger = init_logger(__name__) logger = init_logger(__name__)
def create_worker(worker_module_name, worker_class_name, **kwargs): def create_worker(worker_module_name: str, worker_class_name: str,
worker_class_fn: Optional[Callable[[], Type[WorkerBase]]],
**kwargs):
wrapper = WorkerWrapperBase( wrapper = WorkerWrapperBase(
worker_module_name=worker_module_name, worker_module_name=worker_module_name,
worker_class_name=worker_class_name, worker_class_name=worker_class_name,
worker_class_fn=worker_class_fn,
) )
wrapper.init_worker(**kwargs) wrapper.init_worker(**kwargs)
return wrapper.worker return wrapper.worker
...@@ -62,7 +65,9 @@ class GPUExecutor(ExecutorBase): ...@@ -62,7 +65,9 @@ class GPUExecutor(ExecutorBase):
observability_config=self.observability_config, observability_config=self.observability_config,
) )
def _get_worker_module_and_class(self) -> Tuple[str, str]: def _get_worker_module_and_class(
self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
worker_class_fn = None
if self.scheduler_config.is_multi_step: if self.scheduler_config.is_multi_step:
worker_module_name = "vllm.worker.multi_step_worker" worker_module_name = "vllm.worker.multi_step_worker"
worker_class_name = "MultiStepWorker" worker_class_name = "MultiStepWorker"
...@@ -72,7 +77,7 @@ class GPUExecutor(ExecutorBase): ...@@ -72,7 +77,7 @@ class GPUExecutor(ExecutorBase):
else: else:
worker_module_name = "vllm.worker.worker" worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker" worker_class_name = "Worker"
return (worker_module_name, worker_class_name) return (worker_module_name, worker_class_name, worker_class_fn)
def _get_create_worker_kwargs( def _get_create_worker_kwargs(
self, self,
...@@ -82,10 +87,13 @@ class GPUExecutor(ExecutorBase): ...@@ -82,10 +87,13 @@ class GPUExecutor(ExecutorBase):
worker_kwargs = self._get_worker_kwargs(local_rank, rank, worker_kwargs = self._get_worker_kwargs(local_rank, rank,
distributed_init_method) distributed_init_method)
(worker_module_name, (worker_module_name, worker_class_name,
worker_class_name) = self._get_worker_module_and_class() worker_class_fn) = self._get_worker_module_and_class()
worker_kwargs.update(worker_module_name=worker_module_name, worker_kwargs.update(
worker_class_name=worker_class_name) worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
worker_class_fn=worker_class_fn,
)
return worker_kwargs return worker_kwargs
......
...@@ -91,12 +91,13 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -91,12 +91,13 @@ class RayGPUExecutor(DistributedGPUExecutor):
return ray_remote_kwargs return ray_remote_kwargs
def _get_worker_wrapper_args(self) -> Dict[str, Any]: def _get_worker_wrapper_args(self) -> Dict[str, Any]:
(worker_module_name, (worker_module_name, worker_class_name,
worker_class_name) = self._get_worker_module_and_class() worker_class_fn) = self._get_worker_module_and_class()
return dict( return dict(
worker_module_name=worker_module_name, worker_module_name=worker_module_name,
worker_class_name=worker_class_name, worker_class_name=worker_class_name,
worker_class_fn=worker_class_fn,
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.model_config.trust_remote_code,
) )
......
from typing import List, Optional, Tuple, Union from typing import Callable, List, Optional, Tuple, Type, Union
import torch import torch
...@@ -11,6 +11,7 @@ from vllm.executor.gpu_executor import GPUExecutor ...@@ -11,6 +11,7 @@ from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
from vllm.utils import make_async from vllm.utils import make_async
from vllm.worker.worker_base import WorkerBase
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -52,14 +53,16 @@ class XPUExecutor(GPUExecutor): ...@@ -52,14 +53,16 @@ class XPUExecutor(GPUExecutor):
# Instantiate the worker and load the model to GPU. # Instantiate the worker and load the model to GPU.
self._init_executor() self._init_executor()
def _get_worker_module_and_class(self) -> Tuple[str, str]: def _get_worker_module_and_class(
self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
worker_class_fn = None
if self.speculative_config is not None: if self.speculative_config is not None:
raise NotImplementedError( raise NotImplementedError(
"XPU does not support speculative decoding") "XPU does not support speculative decoding")
else: else:
worker_module_name = "vllm.worker.xpu_worker" worker_module_name = "vllm.worker.xpu_worker"
worker_class_name = "XPUWorker" worker_class_name = "XPUWorker"
return (worker_module_name, worker_class_name) return (worker_module_name, worker_class_name, worker_class_fn)
def execute_model( def execute_model(
self, execute_model_req: ExecuteModelRequest self, execute_model_req: ExecuteModelRequest
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment