Unverified Commit 23027e2d authored by CYJiang's avatar CYJiang Committed by GitHub
Browse files

[Misc] refactor: simplify EngineCoreClient.make_async_mp_client in AsyncLLM (#18817)


Signed-off-by: default avatargoogs1025 <googs1025@gmail.com>
parent c3fd4d66
...@@ -28,8 +28,7 @@ from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs ...@@ -28,8 +28,7 @@ from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, cdiv from vllm.utils import Device, cdiv
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import (AsyncMPClient, DPAsyncMPClient, from vllm.v1.engine.core_client import EngineCoreClient
RayDPClient)
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.v1.engine.output_processor import (OutputProcessor, from vllm.v1.engine.output_processor import (OutputProcessor,
RequestOutputCollector) RequestOutputCollector)
...@@ -121,15 +120,8 @@ class AsyncLLM(EngineClient): ...@@ -121,15 +120,8 @@ class AsyncLLM(EngineClient):
log_stats=self.log_stats) log_stats=self.log_stats)
# EngineCore (starts the engine in background process). # EngineCore (starts the engine in background process).
core_client_class: type[AsyncMPClient]
if vllm_config.parallel_config.data_parallel_size == 1: self.engine_core = EngineCoreClient.make_async_mp_client(
core_client_class = AsyncMPClient
elif vllm_config.parallel_config.data_parallel_backend == "ray":
core_client_class = RayDPClient
else:
core_client_class = DPAsyncMPClient
self.engine_core = core_client_class(
vllm_config=vllm_config, vllm_config=vllm_config,
executor_class=executor_class, executor_class=executor_class,
log_stats=self.log_stats, log_stats=self.log_stats,
......
...@@ -68,18 +68,31 @@ class EngineCoreClient(ABC): ...@@ -68,18 +68,31 @@ class EngineCoreClient(ABC):
"is not currently supported.") "is not currently supported.")
if multiprocess_mode and asyncio_mode: if multiprocess_mode and asyncio_mode:
if vllm_config.parallel_config.data_parallel_size > 1: return EngineCoreClient.make_async_mp_client(
if vllm_config.parallel_config.data_parallel_backend == "ray": vllm_config, executor_class, log_stats)
return RayDPClient(vllm_config, executor_class, log_stats)
return DPAsyncMPClient(vllm_config, executor_class, log_stats)
return AsyncMPClient(vllm_config, executor_class, log_stats)
if multiprocess_mode and not asyncio_mode: if multiprocess_mode and not asyncio_mode:
return SyncMPClient(vllm_config, executor_class, log_stats) return SyncMPClient(vllm_config, executor_class, log_stats)
return InprocClient(vllm_config, executor_class, log_stats) return InprocClient(vllm_config, executor_class, log_stats)
@staticmethod
def make_async_mp_client(
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0,
) -> "MPClient":
if vllm_config.parallel_config.data_parallel_size > 1:
if vllm_config.parallel_config.data_parallel_backend == "ray":
return RayDPClient(vllm_config, executor_class, log_stats,
client_addresses, client_index)
return DPAsyncMPClient(vllm_config, executor_class, log_stats,
client_addresses, client_index)
return AsyncMPClient(vllm_config, executor_class, log_stats,
client_addresses, client_index)
@abstractmethod @abstractmethod
def shutdown(self): def shutdown(self):
... ...
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment