Commit 236266a9 authored by xuxz
Browse files

[PD]支持dp的分支

parent fb597c49
...@@ -54,6 +54,7 @@ class KVConnectorFactory: ...@@ -54,6 +54,7 @@ class KVConnectorFactory:
cls, cls,
config: "VllmConfig", config: "VllmConfig",
role: KVConnectorRole, role: KVConnectorRole,
dp_rank: int = -1,
) -> KVConnectorBase_V1: ) -> KVConnectorBase_V1:
if not envs.VLLM_USE_V1: if not envs.VLLM_USE_V1:
raise ValueError("Attempting to initialize a V1 Connector, " raise ValueError("Attempting to initialize a V1 Connector, "
...@@ -81,7 +82,7 @@ class KVConnectorFactory: ...@@ -81,7 +82,7 @@ class KVConnectorFactory:
# - Co-locate with worker process # - Co-locate with worker process
# - Should only be used inside the forward context & attention layer # - Should only be used inside the forward context & attention layer
# We build separately to enforce strict separation # We build separately to enforce strict separation
return connector_cls(config, role) return connector_cls(config, role, dp_rank)
# Register various connectors here. # Register various connectors here.
......
...@@ -20,6 +20,9 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata ...@@ -20,6 +20,9 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.distributed.parallel_state import get_pp_group, get_tp_group, get_dp_group from vllm.distributed.parallel_state import get_pp_group, get_tp_group, get_dp_group
import zmq
import msgpack
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
...@@ -78,7 +81,7 @@ class DuSwiftConnectorMetadata(KVConnectorMetadata): ...@@ -78,7 +81,7 @@ class DuSwiftConnectorMetadata(KVConnectorMetadata):
class DuSwiftConnector(KVConnectorBase_V1): class DuSwiftConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole, dp_rank : int = -1):
super().__init__(vllm_config=vllm_config, role=role) super().__init__(vllm_config=vllm_config, role=role)
self._block_size = vllm_config.cache_config.block_size self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Any] = {} self._requests_need_load: dict[str, Any] = {}
...@@ -158,9 +161,38 @@ class DuSwiftConnector(KVConnectorBase_V1): ...@@ -158,9 +161,38 @@ class DuSwiftConnector(KVConnectorBase_V1):
print(f"Error: Exception occurred while reading configuration file - {e}") print(f"Error: Exception occurred while reading configuration file - {e}")
if role == KVConnectorRole.SCHEDULER :
self.dp_rank = dp_rank
proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
proxy_port = self.config.get_from_extra_config("proxy_port", "")
if proxy_ip == "" or proxy_port == "":
self.proxy_address = ""
else:
self.proxy_address = proxy_ip + ":" + proxy_port
self.http_address = (
f"{self.config.kv_connector_extra_config['instance_ip']}:"
f"{self.config.kv_connector_extra_config['http_port']}")
self.context = zmq.Context()
req_sock = self.context.socket(zmq.DEALER)
req_sock.setsockopt_string(zmq.IDENTITY, f"{self.http_address}_rank{self.dp_rank}")
req_sock.connect(f"tcp://{self.proxy_address}")
self.req_sock = req_sock
def get_ip_value(self, key): def get_ip_value(self, key):
return self.ip_map.get(key) return self.ip_map.get(key)
def register_req(self, request_id: str):
    """Announce a newly added request on ``self.req_sock``.

    Builds a ``"Req"`` message carrying this instance's kind ("P" when
    ``self.config.is_kv_producer`` is truthy, otherwise "D"), its HTTP
    address, the request id and the data-parallel rank, then sends it
    msgpack-encoded.

    NOTE(review): ``req_sock``, ``http_address`` and ``dp_rank`` appear to
    be set only when the connector is constructed with the SCHEDULER
    role -- confirm this is never called on a worker-side connector.

    Args:
        request_id: Identifier of the request being registered.
    """
    if self.config.is_kv_producer:
        instance_kind = "P"
    else:
        instance_kind = "D"

    # Key order is kept as-is so the encoded payload stays byte-identical.
    message = {}
    message["type"] = "Req"
    message["instance_type"] = instance_kind
    message["http_address"] = self.http_address
    message["request_id"] = request_id
    message["dp_rank"] = self.dp_rank

    encoded = msgpack.dumps(message)
    self.req_sock.send(encoded)
# ============================== # ==============================
# Worker-side methods # Worker-side methods
...@@ -438,68 +470,83 @@ class DuSwiftConnector(KVConnectorBase_V1): ...@@ -438,68 +470,83 @@ class DuSwiftConnector(KVConnectorBase_V1):
else: else:
for request in connector_metadata.requests: for request in connector_metadata.requests:
request_id = request.request_id request_id = request.request_id
ip, port = self.parse_request_id(request_id, True) # ip, port = self.parse_request_id(request_id, True)
p_ip, p_port = self.parse_request_id(request_id, False) # p_ip, p_port = self.parse_request_id(request_id, False)
remote_address = ip + ":" + str(port + self._rank) # remote_address = ip + ":" + str(port + self._rank)
# pd_pair_id = p_ip + ":" + p_port + "_" + ip + ":" + port # pd_pair_id = p_ip + ":" + p_port + "_" + ip + ":" + port
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping) kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size # pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size
) % self.parallel_config.pipeline_parallel_size # ) % self.parallel_config.pipeline_parallel_size
if (self.multiple_machines_p and self.multiple_machines_d): # if (self.multiple_machines_p and self.multiple_machines_d):
ip_second = self.get_ip_value(ip) # ip_second = self.get_ip_value(ip)
if (self.pp_size == 1):
if self._rank < 8:
self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, str(ip_second) + ":" + str(port + self._rank + 8))
elif (self.pp_size == 2):
if (pp_rank == 0):
self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
else:
self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, str(ip_second) + ":" + str(port + self._rank))
else:
logger.error("Error: multiple machines only suppprt pp1tp16 and pp2tp8!!!!!!")
elif (self.multiple_machines_p and not self.multiple_machines_d):
if (self.pp_size == 2):
remote_address = ip + ":" + str(port + self._tp_rank)
self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
else:
logger.error("Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!")
elif (not self.multiple_machines_p and not self.multiple_machines_d):
# remote_addr = RemoteAddr(pd_pair_id, remote_address, self._rank + self.num_card)
self.du_swift_engine.send_tensor_new(request_id, layer_name, kv_cache,
is_mla)
# if (self.pp_size == 1): # if (self.pp_size == 1):
# if self._rank < 8:
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name, # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address) # kv_cache, remote_address)
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, str(ip_second) + ":" + str(port + self._rank + 8))
# elif (self.pp_size == 2): # elif (self.pp_size == 2):
# if (pp_rank == 0): # if (pp_rank == 0):
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name, # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address) # kv_cache, remote_address)
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank + 4))
# else: # else:
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name, # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address) # kv_cache, str(ip_second) + ":" + str(port + self._rank))
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name, # else:
# kv_cache, ip + ":" + str(port + self._rank - 4)) # logger.error("Error: multiple machines only suppprt pp1tp16 and pp2tp8!!!!!!")
# elif (self.pp_size == 8): # elif (self.multiple_machines_p and not self.multiple_machines_d):
# for i in range(8): # if (self.pp_size == 2):
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name, # remote_address = ip + ":" + str(port + self._tp_rank)
# kv_cache, ip + ":" + str(port + i))
# elif (self.enable_asymmetric_p2p):
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name, # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address) # kv_cache, remote_address)
# else: # else:
# logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!") # logger.error("Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!")
else:
logger.error("Error: not support!!!!!!") # elif (not self.multiple_machines_p and not self.multiple_machines_d):
# # remote_addr = RemoteAddr(pd_pair_id, remote_address, self._rank + self.num_card)
# self.du_swift_engine.send_tensor_new(request_id, layer_name, kv_cache,
# is_mla)
# # if (self.pp_size == 1):
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, remote_address)
# # elif (self.pp_size == 2):
# # if (pp_rank == 0):
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, remote_address)
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, ip + ":" + str(port + self._rank + 4))
# # else:
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, remote_address)
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, ip + ":" + str(port + self._rank - 4))
# # elif (self.pp_size == 8):
# # for i in range(8):
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, ip + ":" + str(port + i))
# # elif (self.enable_asymmetric_p2p):
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, remote_address)
# # else:
# # logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!")
# else:
# logger.error("Error: not support!!!!!!")
pending = False
with self.du_swift_engine.req_status_cv:
if request_id not in self.du_swift_engine.req_status:
pending = True
if pending:
self.du_swift_engine.pending_tensor(request_id, layer_name,
kv_cache)
logger.info("[%d] pending for request: %s layer: %s", self._rank, request_id, layer_name)
else :
req_data = self.du_swift_engine.req_status[request_id]
assert(req_data.dst_num == len(req_data.zmq_address_and_comm_rank))
for i in range(req_data.dst_num):
remote_addr = RemoteAddr(req_data.pd_pair_id, *(req_data.zmq_address_and_comm_rank[i]))
self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_addr)
def wait_for_save(self): def wait_for_save(self):
pass pass
# if self.is_producer: # if self.is_producer:
......
...@@ -6,9 +6,10 @@ import os ...@@ -6,9 +6,10 @@ import os
import threading import threading
import time import time
import typing import typing
from collections import deque from collections import deque, defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
from dataclasses import dataclass, field
import msgpack import msgpack
import torch import torch
...@@ -72,6 +73,12 @@ def set_du_swift_context(num_channels: str): ...@@ -72,6 +73,12 @@ def set_du_swift_context(num_channels: str):
os.environ.pop(var, None) os.environ.pop(var, None)
@dataclass
class ReqKVDest:
    """Where the KV cache of a single request should be sent.

    An instance with ``dst_num == 0`` carries no destinations.
    """

    # Number of destinations for the request.
    dst_num: int = 0

    # Identifier of the prefill/decode pair; passed as the first argument
    # when a RemoteAddr is built from an entry of the list below.
    pd_pair_id: str = ""

    # One (zmq address, communicator rank) tuple per destination.
    # Uses a factory so that instances never share the same list object.
    zmq_address_and_comm_rank: list[tuple[str, int]] = field(
        default_factory=list)
@dataclass @dataclass
class RemoteAddr: class RemoteAddr:
pd_pair_id: str = "" pd_pair_id: str = ""
...@@ -125,6 +132,10 @@ class DuSwiftEngine: ...@@ -125,6 +132,10 @@ class DuSwiftEngine:
self.multp = int(self.remote_tp_size / self.tp_size) self.multp = int(self.remote_tp_size / self.tp_size)
self.multiple_machines = self.config.get_from_extra_config( self.multiple_machines = self.config.get_from_extra_config(
"enable_multiple_machines", False) "enable_multiple_machines", False)
self.instance_ip = self.config.get_from_extra_config(
"instance_ip", None)
if self.instance_ip :
self.multiple_machines = False
port = int(self.config.kv_port) + port_offset port = int(self.config.kv_port) + port_offset
if port == 0: if port == 0:
raise ValueError("Port cannot be 0") raise ValueError("Port cannot be 0")
...@@ -135,6 +146,11 @@ class DuSwiftEngine: ...@@ -135,6 +146,11 @@ class DuSwiftEngine:
self.zmq_address = f"{self._hostname}:{self._port}" self.zmq_address = f"{self._hostname}:{self._port}"
# The `http_port` must be consistent with the port of OpenAI. # The `http_port` must be consistent with the port of OpenAI.
if self.instance_ip:
self.http_address = (
f"{self.config.kv_connector_extra_config['instance_ip']}:"
f"{self.config.kv_connector_extra_config['http_port']}")
else:
self.http_address = ( self.http_address = (
f"{self._hostname}:" f"{self._hostname}:"
f"{self.config.kv_connector_extra_config['http_port']}") f"{self.config.kv_connector_extra_config['http_port']}")
...@@ -148,16 +164,27 @@ class DuSwiftEngine: ...@@ -148,16 +164,27 @@ class DuSwiftEngine:
else: else:
self.proxy_address = proxy_ip + ":" + proxy_port self.proxy_address = proxy_ip + ":" + proxy_port
self.kv_cache_layer_num = 0
self.context = zmq.Context() self.context = zmq.Context()
self.router_socket = self.context.socket(zmq.ROUTER) self.router_socket = self.context.socket(zmq.ROUTER)
self.router_socket.setsockopt(zmq.RCVHWM, 10000)
self.router_socket.setsockopt(zmq.SNDHWM, 5000)
self.router_socket.setsockopt(zmq.LINGER, 0)
self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
self.router_socket.setsockopt(zmq.TCP_KEEPALIVE, 1)
self.router_socket.bind(f"tcp://{self.zmq_address}") self.router_socket.bind(f"tcp://{self.zmq_address}")
self.poller = zmq.Poller() self.poller = zmq.Poller()
self.poller.register(self.router_socket, zmq.POLLIN) self.poller.register(self.router_socket, zmq.POLLIN)
self.req_status: dict[str, ReqKVDest] = {}
self.req_status_cv = threading.Condition()
self.send_store_cv = threading.Condition() self.send_store_cv = threading.Condition()
self.send_queue_cv = threading.Condition() self.send_queue_cv = threading.Condition()
self.recv_store_cv = threading.Condition() self.recv_store_cv = threading.Condition()
self.pending_queue_cv = threading.Condition()
self.send_stream = torch.cuda.Stream() self.send_stream = torch.cuda.Stream()
self.recv_stream = torch.cuda.Stream() self.recv_stream = torch.cuda.Stream()
...@@ -181,11 +208,16 @@ class DuSwiftEngine: ...@@ -181,11 +208,16 @@ class DuSwiftEngine:
# PUT or PUT_ASYNC # PUT or PUT_ASYNC
# tensor_id: torch.Tensor # tensor_id: torch.Tensor
self.send_queue: deque[list[Any]] = deque() self.send_queue: deque[list[Any]] = deque()
self.pending_queue: dict[str, list[list[Any]]] = defaultdict(list)
self.requests_to_release: dict[str, bool] = {}
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {} self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
if self.send_type == "PUT_ASYNC": if self.send_type == "PUT_ASYNC":
self._send_thread = threading.Thread(target=self._send_async, self._send_thread = threading.Thread(target=self._send_async,
daemon=True) daemon=True)
self._send_thread.start() self._send_thread.start()
self._pending_check_thread = threading.Thread(target=self._pending_check,
daemon=True)
self._pending_check_thread.start()
# tensor_id: torch.Tensor/(addr, dtype, shape) # tensor_id: torch.Tensor/(addr, dtype, shape)
self.recv_store: dict[str, Any] = {} self.recv_store: dict[str, Any] = {}
...@@ -328,7 +360,7 @@ class DuSwiftEngine: ...@@ -328,7 +360,7 @@ class DuSwiftEngine:
self, self,
tensor_id: str, tensor_id: str,
tensor: torch.Tensor, tensor: torch.Tensor,
remote_address: typing.Optional[str] = None, remote_address: typing.Optional[RemoteAddr] = None,
tbo_evt = None, tbo_evt = None,
) -> bool: ) -> bool:
if remote_address is None: if remote_address is None:
...@@ -356,7 +388,7 @@ class DuSwiftEngine: ...@@ -356,7 +388,7 @@ class DuSwiftEngine:
logger.info( logger.info(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d," "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d", " buffer_size:%d, oldest_tenser_size:%d, rank:%d",
remote_address, tensor_id, tensor_size, remote_address.zmq_address, tensor_id, tensor_size,
self.buffer_size, oldest_tenser_size, self.rank) self.buffer_size, oldest_tenser_size, self.rank)
self.send_store[tensor_id] = tensor self.send_store[tensor_id] = tensor
...@@ -364,7 +396,7 @@ class DuSwiftEngine: ...@@ -364,7 +396,7 @@ class DuSwiftEngine:
logger.debug( logger.debug(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, " "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)", "shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
remote_address, tensor_id, tensor_size, tensor.shape, remote_address.zmq_address, tensor_id, tensor_size, tensor.shape,
self.rank, self.buffer_size, self.rank, self.buffer_size,
self.buffer_size / self.buffer_size_threshold * 100) self.buffer_size / self.buffer_size_threshold * 100)
...@@ -417,6 +449,55 @@ class DuSwiftEngine: ...@@ -417,6 +449,55 @@ class DuSwiftEngine:
return True return True
def pending_tensor(
self,
reuqest_id: str,
layer_name: str,
tensor: torch.Tensor,
tbo_evt = None,
) -> bool:
with self.pending_queue_cv:
self.pending_queue[reuqest_id].append([layer_name, tensor, tbo_evt])
self.pending_queue_cv.notify()
return True
def unpending_tensor(
    self,
    request_id: str,
    req_data: ReqKVDest,
) -> bool:
    """Flush every tensor parked for ``request_id`` to its destinations.

    Removes the request's entries from ``self.pending_queue`` and, if the
    request is tracked in ``self.requests_to_release``, flips its flag to
    ``True``. Each parked tensor is then sent once per destination listed
    in ``req_data``.

    Args:
        request_id: Request whose parked tensors should be flushed; it
            must currently be a key of ``self.pending_queue``.
        req_data: Destination record for the request.

    Returns:
        ``False`` when ``req_data.dst_num <= 0`` (the parked tensors are
        dropped without being sent), otherwise ``True``.

    Raises:
        KeyError: If ``request_id`` is not in ``self.pending_queue``.
    """
    with self.pending_queue_cv:
        parked = self.pending_queue.pop(request_id)
        if request_id in self.requests_to_release:
            self.requests_to_release[request_id] = True
    logger.info("[%d] unpending request: %s", self.rank, request_id)

    if not req_data.dst_num > 0:
        return False

    for layer_name, tensor, tbo_evt in parked:
        tensor_id = request_id + "#" + layer_name
        for idx in range(req_data.dst_num):
            address, comm_rank = req_data.zmq_address_and_comm_rank[idx][:2] \
                if False else (None, None)
            remote_addr = RemoteAddr(
                req_data.pd_pair_id,
                *(req_data.zmq_address_and_comm_rank[idx]))
            async_enabled = envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC
            if not async_enabled or tbo_evt is None:
                self.send_tensor(tensor_id, tensor, remote_addr)
            else:
                self.p2p_async_send_tensor(tensor_id, tensor, remote_addr,
                                           tbo_evt)
    return True
def _pending_check(self) :
while True:
with self.pending_queue_cv:
while not self.pending_queue:
self.pending_queue_cv.wait()
pending_queue = self.pending_queue.copy()
for request_id in pending_queue:
with self.req_status_cv:
if request_id not in self.req_status:
continue
req_data = self.req_status[request_id]
assert(len(req_data.zmq_address_and_comm_rank) == req_data.dst_num)
self.unpending_tensor(request_id, req_data)
def recv_tensor( def recv_tensor(
self, self,
tensor_id: str, tensor_id: str,
...@@ -475,22 +556,12 @@ class DuSwiftEngine: ...@@ -475,22 +556,12 @@ class DuSwiftEngine:
def _listen_for_requests(self): def _listen_for_requests(self):
while True: while True:
socks = dict(self.poller.poll()) socks = dict(self.poller.poll(5000))
if self.router_socket in socks: if self.router_socket in socks:
remote_address, message = self.router_socket.recv_multipart() remote_address, message = self.router_socket.recv_multipart()
data = msgpack.loads(message) data = msgpack.loads(message)
if data["cmd"] == "NEW": if data["cmd"] == "NEW":
unique_id = self.nccl.unique_id_from_bytes( logger.info(f"unexpected message from {remote_address.decode()}")
bytes(data["unique_id"]))
with torch.cuda.device(self.device):
rank = 1
with set_du_swift_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(
2, unique_id, rank)
self.comms[remote_address.decode()] = (comm, rank)
logger.info(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, remote_address.decode(), rank)
elif data["cmd"] == "PUT": elif data["cmd"] == "PUT":
tensor_id = data["tensor_id"] tensor_id = data["tensor_id"]
if "tensor_split_num" in data: if "tensor_split_num" in data:
...@@ -577,6 +648,15 @@ class DuSwiftEngine: ...@@ -577,6 +648,15 @@ class DuSwiftEngine:
logger.info( logger.info(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s", "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, data["pd_pair_id"], rank) self.zmq_address, data["pd_pair_id"], rank)
elif data["cmd"] == "req_to_transfer":
with self.req_status_cv:
assert(data["request_id"] not in self.req_status)
self.req_status[data["request_id"]] = ReqKVDest(dst_num=int(data["dst_num"]), pd_pair_id=data["pd_pair_id"], zmq_address_and_comm_rank=list(zip(data["remote_address"], data["remote_rank"])))
self.req_status_cv.notify_all()
elif data["cmd"] == "req_not_to_transfer":
with self.req_status_cv:
self.req_status[data["request_id"]] = ReqKVDest(dst_num=0)
self.req_status_cv.notify_all()
elif data["cmd"] == "GET": elif data["cmd"] == "GET":
tensor_id = data["tensor_id"] tensor_id = data["tensor_id"]
with self.send_store_cv: with self.send_store_cv:
...@@ -814,20 +894,31 @@ class DuSwiftEngine: ...@@ -814,20 +894,31 @@ class DuSwiftEngine:
""" """
# Clear the buffer upon request completion. # Clear the buffer upon request completion.
requests_to_release : list[str] = []
with self.pending_queue_cv:
for request_id, release in self.requests_to_release.items():
if release :
requests_to_release.append(request_id)
self.requests_to_release.pop(request_id)
for request_id in finished_req_ids: for request_id in finished_req_ids:
for layer_name in forward_context.no_compile_layers: with self.pending_queue_cv:
tensor_id = request_id + "#" + layer_name if request_id in self.pending_queue:
if tensor_id in self.recv_store: self.requests_to_release[request_id] = False
logger.info("[%d] pending request: %s", self.rank, request_id)
continue
requests_to_release.append(request_id)
for request_id in requests_to_release:
ids = self.recv_request_id_to_tensor_ids.pop(request_id, set())
with self.recv_store_cv: with self.recv_store_cv:
for tensor_id in ids:
tensor = self.recv_store.pop(tensor_id, None) tensor = self.recv_store.pop(tensor_id, None)
self.send_request_id_to_tensor_ids.pop(
request_id, None)
self.recv_request_id_to_tensor_ids.pop(
request_id, None)
addr = 0
if isinstance(tensor, tuple): if isinstance(tensor, tuple):
addr, _, _ = tensor addr, _, _ = tensor
self.pool.free(addr) self.pool.free(addr)
self.send_request_id_to_tensor_ids.pop(request_id, None)
# TODO:Retrieve requests that have already sent the KV cache. # TODO:Retrieve requests that have already sent the KV cache.
finished_sending: set[str] = set() finished_sending: set[str] = set()
...@@ -835,6 +926,19 @@ class DuSwiftEngine: ...@@ -835,6 +926,19 @@ class DuSwiftEngine:
# TODO:Retrieve requests that have already received the KV cache. # TODO:Retrieve requests that have already received the KV cache.
finished_recving: set[str] = set() finished_recving: set[str] = set()
if self.kv_cache_layer_num == 0 :
for layer_name in forward_context.no_compile_layers:
layer = forward_context.no_compile_layers[layer_name]
kv_cache = getattr(layer, 'kv_cache', None)
if kv_cache is None:
continue
self.kv_cache_layer_num += 1
with self.recv_store_cv:
for req in self.recv_request_id_to_tensor_ids:
if len(self.recv_request_id_to_tensor_ids[req]) == self.kv_cache_layer_num:
finished_recving.add(req)
return finished_sending or None, finished_recving or None return finished_sending or None, finished_recving or None
def _ping(self): def _ping(self):
...@@ -911,6 +1015,7 @@ class DuSwiftEngine: ...@@ -911,6 +1015,7 @@ class DuSwiftEngine:
self._listener_thread.join() self._listener_thread.join()
if self.send_type == "PUT_ASYNC": if self.send_type == "PUT_ASYNC":
self._send_thread.join() self._send_thread.join()
self._pending_check_thread.join()
if self._ping_thread is not None: if self._ping_thread is not None:
self._ping_thread.join() self._ping_thread.join()
......
...@@ -86,7 +86,9 @@ class Scheduler(SchedulerInterface): ...@@ -86,7 +86,9 @@ class Scheduler(SchedulerInterface):
"Multiple KV cache groups are not currently supported " "Multiple KV cache groups are not currently supported "
"with KV connectors") "with KV connectors")
self.connector = KVConnectorFactory.create_connector_v1( self.connector = KVConnectorFactory.create_connector_v1(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER) config=self.vllm_config, role=KVConnectorRole.SCHEDULER,
dp_rank=self.parallel_config.data_parallel_rank)
self.kv_event_publisher = EventPublisherFactory.create( self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config, self.kv_events_config,
...@@ -380,6 +382,10 @@ class Scheduler(SchedulerInterface): ...@@ -380,6 +382,10 @@ class Scheduler(SchedulerInterface):
if request.is_finished(): if request.is_finished():
self.waiting.pop_request() self.waiting.pop_request()
continue continue
if self.connector and not self.connector.is_producer and request.request_id not in self.finished_recving_kv_req_ids :
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# KVTransfer: skip request if still waiting for remote kvs. # KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request) is_ready = self._update_waiting_for_remote_kv(request)
...@@ -673,7 +679,10 @@ class Scheduler(SchedulerInterface): ...@@ -673,7 +679,10 @@ class Scheduler(SchedulerInterface):
+ len(scheduled_running_reqs) >= max_batch_running): + len(scheduled_running_reqs) >= max_batch_running):
break break
request = self.waiting.peek_request() request = self.waiting.peek_request()
if self.connector and not self.connector.is_producer and request.request_id not in self.finished_recving_kv_req_ids :
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# KVTransfer: skip request if still waiting for remote kvs. # KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request) is_ready = self._update_waiting_for_remote_kv(request)
...@@ -1326,7 +1335,7 @@ class Scheduler(SchedulerInterface): ...@@ -1326,7 +1335,7 @@ class Scheduler(SchedulerInterface):
request.num_nans_in_logits = num_nans_in_logits[req_id] request.num_nans_in_logits = num_nans_in_logits[req_id]
# Add newly generated spec token ids to the request. # Add newly generated spec token ids to the request.
if spec_token_ids is not None: if spec_token_ids is not None and (self.connector is None or not self.connector.is_producer):
if self.structured_output_manager.should_advance(request): if self.structured_output_manager.should_advance(request):
metadata = request.structured_output_request metadata = request.structured_output_request
# Needs to happen after new_token_ids are accepted. # Needs to happen after new_token_ids are accepted.
......
...@@ -763,6 +763,10 @@ class EngineCoreProc(EngineCore): ...@@ -763,6 +763,10 @@ class EngineCoreProc(EngineCore):
# Push to input queue for core busy loop. # Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request)) self.input_queue.put_nowait((request_type, request))
if isinstance(request, EngineCoreRequest) and self.scheduler.connector is not None:
if request_type == EngineCoreRequestType.ADD:
self.scheduler.connector.register_req(request.request_id)
def process_output_sockets(self, output_paths: list[str], def process_output_sockets(self, output_paths: list[str],
coord_output_path: Optional[str], coord_output_path: Optional[str],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment