Commit 236266a9 authored by xuxz's avatar xuxz
Browse files

[PD]支持dp的分支

parent fb597c49
......@@ -54,6 +54,7 @@ class KVConnectorFactory:
cls,
config: "VllmConfig",
role: KVConnectorRole,
dp_rank: int = -1,
) -> KVConnectorBase_V1:
if not envs.VLLM_USE_V1:
raise ValueError("Attempting to initialize a V1 Connector, "
......@@ -81,7 +82,7 @@ class KVConnectorFactory:
# - Co-locate with worker process
# - Should only be used inside the forward context & attention layer
# We build separately to enforce strict separation
return connector_cls(config, role)
return connector_cls(config, role, dp_rank)
# Register various connectors here.
......
......@@ -20,6 +20,9 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.distributed.parallel_state import get_pp_group, get_tp_group, get_dp_group
import zmq
import msgpack
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
......@@ -78,7 +81,7 @@ class DuSwiftConnectorMetadata(KVConnectorMetadata):
class DuSwiftConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole, dp_rank : int = -1):
super().__init__(vllm_config=vllm_config, role=role)
self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Any] = {}
......@@ -157,10 +160,39 @@ class DuSwiftConnector(KVConnectorBase_V1):
except Exception as e:
print(f"Error: Exception occurred while reading configuration file - {e}")
if role == KVConnectorRole.SCHEDULER :
self.dp_rank = dp_rank
proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
proxy_port = self.config.get_from_extra_config("proxy_port", "")
if proxy_ip == "" or proxy_port == "":
self.proxy_address = ""
else:
self.proxy_address = proxy_ip + ":" + proxy_port
self.http_address = (
f"{self.config.kv_connector_extra_config['instance_ip']}:"
f"{self.config.kv_connector_extra_config['http_port']}")
self.context = zmq.Context()
req_sock = self.context.socket(zmq.DEALER)
req_sock.setsockopt_string(zmq.IDENTITY, f"{self.http_address}_rank{self.dp_rank}")
req_sock.connect(f"tcp://{self.proxy_address}")
self.req_sock = req_sock
def get_ip_value(self, key):
return self.ip_map.get(key)
def register_req(self, request_id: str) :
data = {
"type": "Req",
"instance_type": "P" if self.config.is_kv_producer else "D",
"http_address": self.http_address,
"request_id": request_id,
"dp_rank" : self.dp_rank
}
self.req_sock.send(msgpack.dumps(data))
# ==============================
# Worker-side methods
......@@ -438,68 +470,83 @@ class DuSwiftConnector(KVConnectorBase_V1):
else:
for request in connector_metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, True)
p_ip, p_port = self.parse_request_id(request_id, False)
remote_address = ip + ":" + str(port + self._rank)
# ip, port = self.parse_request_id(request_id, True)
# p_ip, p_port = self.parse_request_id(request_id, False)
# remote_address = ip + ":" + str(port + self._rank)
# pd_pair_id = p_ip + ":" + p_port + "_" + ip + ":" + port
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size
) % self.parallel_config.pipeline_parallel_size
if (self.multiple_machines_p and self.multiple_machines_d):
ip_second = self.get_ip_value(ip)
if (self.pp_size == 1):
if self._rank < 8:
self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, str(ip_second) + ":" + str(port + self._rank + 8))
elif (self.pp_size == 2):
if (pp_rank == 0):
self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
else:
self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, str(ip_second) + ":" + str(port + self._rank))
else:
logger.error("Error: multiple machines only suppprt pp1tp16 and pp2tp8!!!!!!")
elif (self.multiple_machines_p and not self.multiple_machines_d):
if (self.pp_size == 2):
remote_address = ip + ":" + str(port + self._tp_rank)
# pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size
# ) % self.parallel_config.pipeline_parallel_size
# if (self.multiple_machines_p and self.multiple_machines_d):
# ip_second = self.get_ip_value(ip)
# if (self.pp_size == 1):
# if self._rank < 8:
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, str(ip_second) + ":" + str(port + self._rank + 8))
# elif (self.pp_size == 2):
# if (pp_rank == 0):
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# else:
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, str(ip_second) + ":" + str(port + self._rank))
# else:
# logger.error("Error: multiple machines only suppprt pp1tp16 and pp2tp8!!!!!!")
# elif (self.multiple_machines_p and not self.multiple_machines_d):
# if (self.pp_size == 2):
# remote_address = ip + ":" + str(port + self._tp_rank)
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# else:
# logger.error("Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!")
# elif (not self.multiple_machines_p and not self.multiple_machines_d):
# # remote_addr = RemoteAddr(pd_pair_id, remote_address, self._rank + self.num_card)
# self.du_swift_engine.send_tensor_new(request_id, layer_name, kv_cache,
# is_mla)
# # if (self.pp_size == 1):
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, remote_address)
# # elif (self.pp_size == 2):
# # if (pp_rank == 0):
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, remote_address)
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, ip + ":" + str(port + self._rank + 4))
# # else:
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, remote_address)
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, ip + ":" + str(port + self._rank - 4))
# # elif (self.pp_size == 8):
# # for i in range(8):
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, ip + ":" + str(port + i))
# # elif (self.enable_asymmetric_p2p):
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, remote_address)
# # else:
# # logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!")
# else:
# logger.error("Error: not support!!!!!!")
pending = False
with self.du_swift_engine.req_status_cv:
if request_id not in self.du_swift_engine.req_status:
pending = True
if pending:
self.du_swift_engine.pending_tensor(request_id, layer_name,
kv_cache)
logger.info("[%d] pending for request: %s layer: %s", self._rank, request_id, layer_name)
else :
req_data = self.du_swift_engine.req_status[request_id]
assert(req_data.dst_num == len(req_data.zmq_address_and_comm_rank))
for i in range(req_data.dst_num):
remote_addr = RemoteAddr(req_data.pd_pair_id, *(req_data.zmq_address_and_comm_rank[i]))
self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
else:
logger.error("Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!")
elif (not self.multiple_machines_p and not self.multiple_machines_d):
# remote_addr = RemoteAddr(pd_pair_id, remote_address, self._rank + self.num_card)
self.du_swift_engine.send_tensor_new(request_id, layer_name, kv_cache,
is_mla)
# if (self.pp_size == 1):
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# elif (self.pp_size == 2):
# if (pp_rank == 0):
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank + 4))
# else:
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank - 4))
# elif (self.pp_size == 8):
# for i in range(8):
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + i))
# elif (self.enable_asymmetric_p2p):
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# else:
# logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!")
else:
logger.error("Error: not support!!!!!!")
kv_cache, remote_addr)
def wait_for_save(self):
pass
# if self.is_producer:
......
......@@ -6,9 +6,10 @@ import os
import threading
import time
import typing
from collections import deque
from collections import deque, defaultdict
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional
from dataclasses import dataclass, field
import msgpack
import torch
......@@ -72,6 +73,12 @@ def set_du_swift_context(num_channels: str):
os.environ.pop(var, None)
@dataclass
class ReqKVDest:
dst_num: int = 0
pd_pair_id: str = ""
zmq_address_and_comm_rank: list[tuple[str, int]] = field(default_factory=list)
@dataclass
class RemoteAddr:
pd_pair_id: str = ""
......@@ -125,6 +132,10 @@ class DuSwiftEngine:
self.multp = int(self.remote_tp_size / self.tp_size)
self.multiple_machines = self.config.get_from_extra_config(
"enable_multiple_machines", False)
self.instance_ip = self.config.get_from_extra_config(
"instance_ip", None)
if self.instance_ip :
self.multiple_machines = False
port = int(self.config.kv_port) + port_offset
if port == 0:
raise ValueError("Port cannot be 0")
......@@ -135,9 +146,14 @@ class DuSwiftEngine:
self.zmq_address = f"{self._hostname}:{self._port}"
# The `http_port` must be consistent with the port of OpenAI.
self.http_address = (
f"{self._hostname}:"
f"{self.config.kv_connector_extra_config['http_port']}")
if self.instance_ip:
self.http_address = (
f"{self.config.kv_connector_extra_config['instance_ip']}:"
f"{self.config.kv_connector_extra_config['http_port']}")
else:
self.http_address = (
f"{self._hostname}:"
f"{self.config.kv_connector_extra_config['http_port']}")
# If `proxy_ip` or `proxy_port` is `""`,
# then the ping thread will not be enabled.
......@@ -148,16 +164,27 @@ class DuSwiftEngine:
else:
self.proxy_address = proxy_ip + ":" + proxy_port
self.kv_cache_layer_num = 0
self.context = zmq.Context()
self.router_socket = self.context.socket(zmq.ROUTER)
self.router_socket.setsockopt(zmq.RCVHWM, 10000)
self.router_socket.setsockopt(zmq.SNDHWM, 5000)
self.router_socket.setsockopt(zmq.LINGER, 0)
self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
self.router_socket.setsockopt(zmq.TCP_KEEPALIVE, 1)
self.router_socket.bind(f"tcp://{self.zmq_address}")
self.poller = zmq.Poller()
self.poller.register(self.router_socket, zmq.POLLIN)
self.req_status: dict[str, ReqKVDest] = {}
self.req_status_cv = threading.Condition()
self.send_store_cv = threading.Condition()
self.send_queue_cv = threading.Condition()
self.recv_store_cv = threading.Condition()
self.pending_queue_cv = threading.Condition()
self.send_stream = torch.cuda.Stream()
self.recv_stream = torch.cuda.Stream()
......@@ -181,11 +208,16 @@ class DuSwiftEngine:
# PUT or PUT_ASYNC
# tensor_id: torch.Tensor
self.send_queue: deque[list[Any]] = deque()
self.pending_queue: dict[str, list[list[Any]]] = defaultdict(list)
self.requests_to_release: dict[str, bool] = {}
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
if self.send_type == "PUT_ASYNC":
self._send_thread = threading.Thread(target=self._send_async,
daemon=True)
self._send_thread.start()
self._pending_check_thread = threading.Thread(target=self._pending_check,
daemon=True)
self._pending_check_thread.start()
# tensor_id: torch.Tensor/(addr, dtype, shape)
self.recv_store: dict[str, Any] = {}
......@@ -328,7 +360,7 @@ class DuSwiftEngine:
self,
tensor_id: str,
tensor: torch.Tensor,
remote_address: typing.Optional[str] = None,
remote_address: typing.Optional[RemoteAddr] = None,
tbo_evt = None,
) -> bool:
if remote_address is None:
......@@ -356,7 +388,7 @@ class DuSwiftEngine:
logger.info(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d",
remote_address, tensor_id, tensor_size,
remote_address.zmq_address, tensor_id, tensor_size,
self.buffer_size, oldest_tenser_size, self.rank)
self.send_store[tensor_id] = tensor
......@@ -364,7 +396,7 @@ class DuSwiftEngine:
logger.debug(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
remote_address, tensor_id, tensor_size, tensor.shape,
remote_address.zmq_address, tensor_id, tensor_size, tensor.shape,
self.rank, self.buffer_size,
self.buffer_size / self.buffer_size_threshold * 100)
......@@ -416,6 +448,55 @@ class DuSwiftEngine:
self.buffer_size / self.buffer_size_threshold * 100)
return True
def pending_tensor(
self,
reuqest_id: str,
layer_name: str,
tensor: torch.Tensor,
tbo_evt = None,
) -> bool:
with self.pending_queue_cv:
self.pending_queue[reuqest_id].append([layer_name, tensor, tbo_evt])
self.pending_queue_cv.notify()
return True
def unpending_tensor(
self,
request_id: str,
req_data: ReqKVDest,
) -> bool:
with self.pending_queue_cv:
tensor_list = self.pending_queue.pop(request_id)
if request_id in self.requests_to_release:
self.requests_to_release[request_id] = True
logger.info("[%d] unpending request: %s", self.rank, request_id)
if req_data.dst_num <= 0:
return False
for layer_name, tensor, tbo_evt in tensor_list:
for i in range(req_data.dst_num) :
remote_addr = RemoteAddr(req_data.pd_pair_id, *(req_data.zmq_address_and_comm_rank[i]))
if (envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC) and tbo_evt is not None:
self.p2p_async_send_tensor(request_id + "#" + layer_name,
tensor, remote_addr, tbo_evt)
else :
self.send_tensor(request_id + "#" + layer_name,
tensor, remote_addr)
return True
def _pending_check(self) :
while True:
with self.pending_queue_cv:
while not self.pending_queue:
self.pending_queue_cv.wait()
pending_queue = self.pending_queue.copy()
for request_id in pending_queue:
with self.req_status_cv:
if request_id not in self.req_status:
continue
req_data = self.req_status[request_id]
assert(len(req_data.zmq_address_and_comm_rank) == req_data.dst_num)
self.unpending_tensor(request_id, req_data)
def recv_tensor(
self,
......@@ -475,22 +556,12 @@ class DuSwiftEngine:
def _listen_for_requests(self):
while True:
socks = dict(self.poller.poll())
socks = dict(self.poller.poll(5000))
if self.router_socket in socks:
remote_address, message = self.router_socket.recv_multipart()
data = msgpack.loads(message)
if data["cmd"] == "NEW":
unique_id = self.nccl.unique_id_from_bytes(
bytes(data["unique_id"]))
with torch.cuda.device(self.device):
rank = 1
with set_du_swift_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(
2, unique_id, rank)
self.comms[remote_address.decode()] = (comm, rank)
logger.info(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, remote_address.decode(), rank)
logger.info(f"unexpected message from {remote_address.decode()}")
elif data["cmd"] == "PUT":
tensor_id = data["tensor_id"]
if "tensor_split_num" in data:
......@@ -577,6 +648,15 @@ class DuSwiftEngine:
logger.info(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, data["pd_pair_id"], rank)
elif data["cmd"] == "req_to_transfer":
with self.req_status_cv:
assert(data["request_id"] not in self.req_status)
self.req_status[data["request_id"]] = ReqKVDest(dst_num=int(data["dst_num"]), pd_pair_id=data["pd_pair_id"], zmq_address_and_comm_rank=list(zip(data["remote_address"], data["remote_rank"])))
self.req_status_cv.notify_all()
elif data["cmd"] == "req_not_to_transfer":
with self.req_status_cv:
self.req_status[data["request_id"]] = ReqKVDest(dst_num=0)
self.req_status_cv.notify_all()
elif data["cmd"] == "GET":
tensor_id = data["tensor_id"]
with self.send_store_cv:
......@@ -814,20 +894,31 @@ class DuSwiftEngine:
"""
# Clear the buffer upon request completion.
requests_to_release : list[str] = []
with self.pending_queue_cv:
for request_id, release in self.requests_to_release.items():
if release :
requests_to_release.append(request_id)
self.requests_to_release.pop(request_id)
for request_id in finished_req_ids:
for layer_name in forward_context.no_compile_layers:
tensor_id = request_id + "#" + layer_name
if tensor_id in self.recv_store:
with self.recv_store_cv:
tensor = self.recv_store.pop(tensor_id, None)
self.send_request_id_to_tensor_ids.pop(
request_id, None)
self.recv_request_id_to_tensor_ids.pop(
request_id, None)
addr = 0
with self.pending_queue_cv:
if request_id in self.pending_queue:
self.requests_to_release[request_id] = False
logger.info("[%d] pending request: %s", self.rank, request_id)
continue
requests_to_release.append(request_id)
for request_id in requests_to_release:
ids = self.recv_request_id_to_tensor_ids.pop(request_id, set())
with self.recv_store_cv:
for tensor_id in ids:
tensor = self.recv_store.pop(tensor_id, None)
if isinstance(tensor, tuple):
addr, _, _ = tensor
self.pool.free(addr)
self.send_request_id_to_tensor_ids.pop(request_id, None)
# TODO:Retrieve requests that have already sent the KV cache.
finished_sending: set[str] = set()
......@@ -835,6 +926,19 @@ class DuSwiftEngine:
# TODO:Retrieve requests that have already received the KV cache.
finished_recving: set[str] = set()
if self.kv_cache_layer_num == 0 :
for layer_name in forward_context.no_compile_layers:
layer = forward_context.no_compile_layers[layer_name]
kv_cache = getattr(layer, 'kv_cache', None)
if kv_cache is None:
continue
self.kv_cache_layer_num += 1
with self.recv_store_cv:
for req in self.recv_request_id_to_tensor_ids:
if len(self.recv_request_id_to_tensor_ids[req]) == self.kv_cache_layer_num:
finished_recving.add(req)
return finished_sending or None, finished_recving or None
def _ping(self):
......@@ -911,6 +1015,7 @@ class DuSwiftEngine:
self._listener_thread.join()
if self.send_type == "PUT_ASYNC":
self._send_thread.join()
self._pending_check_thread.join()
if self._ping_thread is not None:
self._ping_thread.join()
......
......@@ -86,7 +86,9 @@ class Scheduler(SchedulerInterface):
"Multiple KV cache groups are not currently supported "
"with KV connectors")
self.connector = KVConnectorFactory.create_connector_v1(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
config=self.vllm_config, role=KVConnectorRole.SCHEDULER,
dp_rank=self.parallel_config.data_parallel_rank)
self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config,
......@@ -380,6 +382,10 @@ class Scheduler(SchedulerInterface):
if request.is_finished():
self.waiting.pop_request()
continue
if self.connector and not self.connector.is_producer and request.request_id not in self.finished_recving_kv_req_ids :
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
......@@ -673,7 +679,10 @@ class Scheduler(SchedulerInterface):
+ len(scheduled_running_reqs) >= max_batch_running):
break
request = self.waiting.peek_request()
if self.connector and not self.connector.is_producer and request.request_id not in self.finished_recving_kv_req_ids :
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
......@@ -1326,7 +1335,7 @@ class Scheduler(SchedulerInterface):
request.num_nans_in_logits = num_nans_in_logits[req_id]
# Add newly generated spec token ids to the request.
if spec_token_ids is not None:
if spec_token_ids is not None and (self.connector is None or not self.connector.is_producer):
if self.structured_output_manager.should_advance(request):
metadata = request.structured_output_request
# Needs to happen after new_token_ids are accepted.
......
......@@ -763,6 +763,10 @@ class EngineCoreProc(EngineCore):
# Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request))
if isinstance(request, EngineCoreRequest) and self.scheduler.connector is not None:
if request_type == EngineCoreRequestType.ADD:
self.scheduler.connector.register_req(request.request_id)
def process_output_sockets(self, output_paths: list[str],
coord_output_path: Optional[str],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment