Unverified Commit 022c5c69 authored by Rui Qiao's avatar Rui Qiao Committed by GitHub
Browse files

[V1] Refactor get_executor_cls (#11754)

parent f8fcca10
...@@ -8,8 +8,8 @@ from vllm import SamplingParams ...@@ -8,8 +8,8 @@ from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core import EngineCore from vllm.v1.engine.core import EngineCore
from vllm.v1.executor.abstract import Executor
if not current_platform.is_cuda(): if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.", pytest.skip(reason="V1 currently only supported on CUDA.",
...@@ -43,7 +43,7 @@ def test_engine_core(monkeypatch): ...@@ -43,7 +43,7 @@ def test_engine_core(monkeypatch):
"""Setup the EngineCore.""" """Setup the EngineCore."""
engine_args = EngineArgs(model=MODEL_NAME) engine_args = EngineArgs(model=MODEL_NAME)
vllm_config = engine_args.create_engine_config() vllm_config = engine_args.create_engine_config()
executor_class = AsyncLLM._get_executor_cls(vllm_config) executor_class = Executor.get_class(vllm_config)
engine_core = EngineCore(vllm_config=vllm_config, engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class) executor_class=executor_class)
...@@ -149,7 +149,7 @@ def test_engine_core_advanced_sampling(monkeypatch): ...@@ -149,7 +149,7 @@ def test_engine_core_advanced_sampling(monkeypatch):
"""Setup the EngineCore.""" """Setup the EngineCore."""
engine_args = EngineArgs(model=MODEL_NAME) engine_args = EngineArgs(model=MODEL_NAME)
vllm_config = engine_args.create_engine_config() vllm_config = engine_args.create_engine_config()
executor_class = AsyncLLM._get_executor_cls(vllm_config) executor_class = Executor.get_class(vllm_config)
engine_core = EngineCore(vllm_config=vllm_config, engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class) executor_class=executor_class)
......
...@@ -11,8 +11,8 @@ from vllm.engine.arg_utils import EngineArgs ...@@ -11,8 +11,8 @@ from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.executor.abstract import Executor
if not current_platform.is_cuda(): if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.", pytest.skip(reason="V1 currently only supported on CUDA.",
...@@ -84,7 +84,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): ...@@ -84,7 +84,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3) engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3)
vllm_config = engine_args.create_engine_config( vllm_config = engine_args.create_engine_config(
UsageContext.UNKNOWN_CONTEXT) UsageContext.UNKNOWN_CONTEXT)
executor_class = AsyncLLM._get_executor_cls(vllm_config) executor_class = Executor.get_class(vllm_config)
client = EngineCoreClient.make_client( client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode, multiprocess_mode=multiprocessing_mode,
asyncio_mode=False, asyncio_mode=False,
...@@ -152,7 +152,7 @@ async def test_engine_core_client_asyncio(monkeypatch): ...@@ -152,7 +152,7 @@ async def test_engine_core_client_asyncio(monkeypatch):
engine_args = EngineArgs(model=MODEL_NAME) engine_args = EngineArgs(model=MODEL_NAME)
vllm_config = engine_args.create_engine_config( vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.UNKNOWN_CONTEXT) usage_context=UsageContext.UNKNOWN_CONTEXT)
executor_class = AsyncLLM._get_executor_cls(vllm_config) executor_class = Executor.get_class(vllm_config)
client = EngineCoreClient.make_client( client = EngineCoreClient.make_client(
multiprocess_mode=True, multiprocess_mode=True,
asyncio_mode=True, asyncio_mode=True,
......
...@@ -22,7 +22,6 @@ from vllm.v1.engine.core_client import EngineCoreClient ...@@ -22,7 +22,6 @@ from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.detokenizer import Detokenizer from vllm.v1.engine.detokenizer import Detokenizer
from vllm.v1.engine.processor import Processor from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.ray_utils import initialize_ray_cluster
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -105,7 +104,7 @@ class AsyncLLM(EngineClient): ...@@ -105,7 +104,7 @@ class AsyncLLM(EngineClient):
else: else:
vllm_config = engine_config vllm_config = engine_config
executor_class = cls._get_executor_cls(vllm_config) executor_class = Executor.get_class(vllm_config)
# Create the AsyncLLM. # Create the AsyncLLM.
return cls( return cls(
...@@ -127,24 +126,6 @@ class AsyncLLM(EngineClient): ...@@ -127,24 +126,6 @@ class AsyncLLM(EngineClient):
if handler := getattr(self, "output_handler", None): if handler := getattr(self, "output_handler", None):
handler.cancel() handler.cancel()
@classmethod
def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
    """Pick the Executor subclass for the configured distributed backend.

    The choice is driven by
    ``vllm_config.parallel_config.distributed_executor_backend``:

    * ``"ray"`` -> ``RayExecutor`` (after ``initialize_ray_cluster`` is
      called with the parallel config),
    * ``"mp"``  -> ``MultiprocExecutor``,
    * ``None``  -> ``UniprocExecutor``.

    Each executor module is imported only inside the branch that selects it.
    """
    backend = vllm_config.parallel_config.distributed_executor_backend

    if backend == "ray":
        # Cluster setup runs before the Ray executor module is imported.
        initialize_ray_cluster(vllm_config.parallel_config)
        from vllm.v1.executor.ray_executor import RayExecutor
        return RayExecutor

    if backend == "mp":
        from vllm.v1.executor.multiproc_executor import MultiprocExecutor
        return MultiprocExecutor

    # Anything that is neither "ray" nor "mp" is expected to be None.
    assert (backend is None)
    from vllm.v1.executor.uniproc_executor import UniprocExecutor
    return UniprocExecutor
async def add_request( async def add_request(
self, self,
request_id: str, request_id: str,
......
...@@ -89,7 +89,7 @@ class LLMEngine: ...@@ -89,7 +89,7 @@ class LLMEngine:
# Create the engine configs. # Create the engine configs.
vllm_config = engine_args.create_engine_config(usage_context) vllm_config = engine_args.create_engine_config(usage_context)
executor_class = cls._get_executor_cls(vllm_config) executor_class = Executor.get_class(vllm_config)
if VLLM_ENABLE_V1_MULTIPROCESSING: if VLLM_ENABLE_V1_MULTIPROCESSING:
logger.debug("Enabling multiprocessing for LLMEngine.") logger.debug("Enabling multiprocessing for LLMEngine.")
...@@ -103,24 +103,6 @@ class LLMEngine: ...@@ -103,24 +103,6 @@ class LLMEngine:
stat_loggers=stat_loggers, stat_loggers=stat_loggers,
multiprocess_mode=enable_multiprocessing) multiprocess_mode=enable_multiprocessing)
@classmethod
def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
    """Map the configured distributed executor backend to an Executor class.

    Looks at ``vllm_config.parallel_config.distributed_executor_backend``
    and returns ``RayExecutor`` for ``"ray"``, ``MultiprocExecutor`` for
    ``"mp"`` and ``UniprocExecutor`` for ``None``. The executor modules are
    imported lazily, inside the branch that needs them.
    """
    selected_backend = (
        vllm_config.parallel_config.distributed_executor_backend)

    if selected_backend == "ray":
        from vllm.v1.executor.ray_executor import RayExecutor
        return RayExecutor

    if selected_backend == "mp":
        from vllm.v1.executor.multiproc_executor import MultiprocExecutor
        return MultiprocExecutor

    # The only remaining accepted value is "no backend configured".
    assert (selected_backend is None)
    from vllm.v1.executor.uniproc_executor import UniprocExecutor
    return UniprocExecutor
def get_num_unfinished_requests(self) -> int: def get_num_unfinished_requests(self) -> int:
return self.detokenizer.get_num_unfinished_requests() return self.detokenizer.get_num_unfinished_requests()
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Tuple from typing import Tuple, Type
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
...@@ -8,6 +8,23 @@ from vllm.v1.outputs import ModelRunnerOutput ...@@ -8,6 +8,23 @@ from vllm.v1.outputs import ModelRunnerOutput
class Executor(ABC): class Executor(ABC):
"""Abstract class for executors.""" """Abstract class for executors."""
@staticmethod
def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
    """Return the Executor subclass for the configured distributed backend.

    Args:
        vllm_config: Engine configuration. Only
            ``parallel_config.distributed_executor_backend`` is read.

    Returns:
        ``RayExecutor`` for ``"ray"``, ``MultiprocExecutor`` for ``"mp"``,
        and ``UniprocExecutor`` when the backend is ``None``.

    Raises:
        ValueError: If the backend is any other value. This is an explicit
            raise rather than an ``assert`` so the check still runs when
            Python is started with ``-O``.
    """
    distributed_executor_backend = (
        vllm_config.parallel_config.distributed_executor_backend)

    # Executor modules are imported lazily, inside the branch that selects
    # them, so only the chosen backend's module is loaded here.
    if distributed_executor_backend == "ray":
        # NOTE(review): the AsyncLLM helper this replaces called
        # initialize_ray_cluster(vllm_config.parallel_config) before
        # selecting RayExecutor; this method does not. Confirm the Ray
        # cluster is initialized elsewhere.
        from vllm.v1.executor.ray_executor import RayExecutor
        return RayExecutor

    if distributed_executor_backend == "mp":
        from vllm.v1.executor.multiproc_executor import MultiprocExecutor
        return MultiprocExecutor

    if distributed_executor_backend is None:
        from vllm.v1.executor.uniproc_executor import UniprocExecutor
        return UniprocExecutor

    raise ValueError("Unknown distributed executor backend: "
                     f"{distributed_executor_backend!r}")
@abstractmethod @abstractmethod
def __init__(self, vllm_config: VllmConfig) -> None: def __init__(self, vllm_config: VllmConfig) -> None:
raise NotImplementedError raise NotImplementedError
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment