"csrc/smxx/vscode:/vscode.git/clone" did not exist on "5d8e93f67a1bf5f96213ffe7e7f64633a8c0e8ea"
Commit 61ba33d5 authored by xuxz's avatar xuxz Committed by xuxz
Browse files

[PD][Feat]支持pd分离dp并行

parent ce47a56e
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from vllm import envs
from collections.abc import Callable
from typing import TYPE_CHECKING, Optional, cast
......@@ -45,6 +46,7 @@ class KVConnectorFactory:
config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
dp_rank: int = -1,
) -> KVConnectorBase:
kv_transfer_config = config.kv_transfer_config
if kv_transfer_config is None:
......@@ -77,6 +79,8 @@ class KVConnectorFactory:
if compat_sig:
# Old signature: __init__(self, vllm_config, role)
return connector_cls(config, role)
elif envs.VLLM_USE_DP_CONNECTOR:
return connector_cls(config, role, kv_cache_config, dp_rank)
else:
# New signature: __init__(self, vllm_config, role, kv_cache_config)
return connector_cls(config, role, kv_cache_config)
......@@ -160,6 +164,11 @@ KVConnectorFactory.register_connector(
"vllm.distributed.kv_transfer.kv_connector.v1.du.du_swift_connector",
"DuSwiftConnector")
KVConnectorFactory.register_connector(
"DuSwiftConnectorDp",
"vllm.distributed.kv_transfer.kv_connector.v1.du.du_swift_connector_dp",
"DuSwiftConnectorDp")
KVConnectorFactory.register_connector(
"LMCacheConnectorV1",
"vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector",
......
......@@ -1841,6 +1841,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vllm will use rmsquant fused op
"USE_FUSED_RMS_QUANT":
lambda: bool(int(os.getenv("USE_FUSED_RMS_QUANT", "0"))),
#vllm use dp connector
"VLLM_USE_DP_CONNECTOR":
lambda: bool(int(os.getenv("VLLM_USE_DP_CONNECTOR", "0"))),
# vllm pd separation will be used async
"VLLM_P2P_ASYNC":
lambda: bool(int(os.getenv("VLLM_P2P_ASYNC", "0"))),
......
......@@ -121,7 +121,7 @@ class Scheduler(SchedulerInterface):
config=self.vllm_config,
role=KVConnectorRole.SCHEDULER,
kv_cache_config=self.kv_cache_config,
)
dp_rank=self.parallel_config.data_parallel_rank)
if self.log_stats:
self.connector_prefix_cache_stats = PrefixCacheStats()
kv_load_failure_policy = (
......@@ -556,6 +556,12 @@ class Scheduler(SchedulerInterface):
+ len(scheduled_running_reqs) >= max_batch_running):
break
request = self.waiting.peek_request()
if self.connector and not self.connector.is_producer and \
request.request_id not in self.finished_recving_kv_req_ids and \
envs.VLLM_USE_DP_CONNECTOR:
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
......
......@@ -66,6 +66,7 @@ from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import compute_iteration_details
from vllm.version import __version__ as VLLM_VERSION
from vllm import envs
logger = init_logger(__name__)
......@@ -1155,6 +1156,11 @@ class EngineCoreProc(EngineCore):
# Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request))
if isinstance(request, tuple) and self.scheduler.connector is not None \
and envs.VLLM_USE_DP_CONNECTOR:
req, _ = request
if request_type == EngineCoreRequestType.ADD:
self.scheduler.connector.register_req(req.request_id)
def process_output_sockets(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment