Commit 236266a9 authored by xuxz's avatar xuxz
Browse files

[PD]支持dp的分支

parent fb597c49
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import socket
import threading
import uuid
import aiohttp
import msgpack
import zmq
from typing import Any
from quart import Quart, make_response, request
from dataclasses import dataclass, field
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
import time
import asyncio
from collections import deque, defaultdict
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
@dataclass
class Request:
request_id: str
p_http_address: str = ""
p_dp_rank: int = -1
d_http_address: str = ""
d_dp_rank: int = -1
@dataclass
class Instance:
ins_type: str = "P"
http_address: str = ""
zmq_address: str = ""
p_unique_id: bytes = b""
dp_size: int = 0
pp_size: int = 0
tp_size: int = 0
# [dp, pp, tp] : zmq_address
rank_table: dict[int, dict[int, dict[int, str]]] = field(
default_factory=lambda: defaultdict(lambda: defaultdict(dict))
)
# [dp, pp, tp] : global rank
comm_rank_table: dict[int, dict[int, dict[int, int]]] = field(
default_factory=lambda: defaultdict(lambda: defaultdict(dict))
)
def count_rank_table_elements(self):
count = 0
for first_dict in self.rank_table.values():
for second_dict in first_dict.values():
count += len(second_dict)
return count
def is_ready(self):
world_size = self.dp_size * self.pp_size * self.tp_size
inited_rank = self.count_rank_table_elements()
all_ranks_ready = world_size and inited_rank == world_size
if self.ins_type == "P" :
logger.info(f"""[Router] P is_ready? : {self.http_address} world_size = {world_size} inited_rank = {inited_rank}""")
return all_ranks_ready
else :
logger.info(f"""[Router] D is_ready? : {self.http_address} world_size = {world_size} inited_rank = {inited_rank}""")
return all_ranks_ready
count = 0
prefill_instances: dict[str, Instance] = {}
decode_instances: dict[str, Instance] = {}
running_requests: dict[str, Request] = {}
healthy_instances: dict[str, float] = {}
pending_prefill_ins: list[str] = []
pending_decode_ins: list[str] = []
ready_prefill_ins: list[str] = []
ready_decode_ins: list[str] = []
pd_pair : dict[str, bytes] = {}
router_nccl = NCCLLibrary()
instance_cv = threading.Condition()
request_cv = threading.Condition()
health_cv = threading.Condition()
request_queue_cv = threading.Condition()
request_queue: deque[list[Any]] = deque()
sock_cache : dict[str, Any] = {}
def _listen_for_register(poller, router_socket):
while True:
socks = dict(poller.poll())
if router_socket in socks:
remote_address, message = router_socket.recv_multipart()
# data: {"type": "P", "http_address": "ip:port",
# "zmq_address": "ip:port"}
data = msgpack.loads(message)
global prefill_instances
global instance_cv
global decode_instances
if data["type"] == "P":
with instance_cv:
if data["http_address"] not in prefill_instances:
prefill_instances[data["http_address"]] = Instance(http_address=data["http_address"])
p_instance = prefill_instances[data["http_address"]]
p_instance.rank_table[int(data["dp_rank"])][int(data["pp_rank"])][int(data["tp_rank"])] = data["zmq_address"]
if p_instance.is_ready():
pending_prefill_ins.append(p_instance.http_address)
logger.info(f"""[Router] pending_prefill_ins appended {p_instance.http_address} ZMQ:{p_instance.zmq_address}""")
instance_cv.notify()
logger.info(f"""[Router] add P rank [{data["dp_rank"]}, {data["pp_rank"]}, {data["tp_rank"]}] : {data["zmq_address"]}""")
elif data["type"] == "D":
with instance_cv:
if data["http_address"] not in decode_instances:
decode_instances[data["http_address"]] = Instance(ins_type="D", http_address=data["http_address"])
d_instance = decode_instances[data["http_address"]]
d_instance.rank_table[int(data["dp_rank"])][int(data["pp_rank"])][int(data["tp_rank"])] = data["zmq_address"]
if d_instance.is_ready():
pending_decode_ins.append(d_instance.http_address)
logger.info(f"""[Router] pending_decode_ins appended {d_instance.http_address} ZMQ:{d_instance.zmq_address}""")
instance_cv.notify()
logger.info(f"""[Router] add D rank [{data["dp_rank"]}, {data["pp_rank"]}, {data["tp_rank"]}] : {data["zmq_address"]}""")
elif data["type"] == "P_init":
with instance_cv:
if data["http_address"] not in prefill_instances:
prefill_instances[data["http_address"]] = Instance(http_address=data["http_address"], dp_size=int(data["dp_size"]), pp_size=int(data["pp_size"]), tp_size=int(data["tp_size"]))
prefill_instances[data["http_address"]].zmq_address = data["zmq_address"]
continue
p_instance = prefill_instances[data["http_address"]]
p_instance.dp_size=int(data["dp_size"])
p_instance.pp_size=int(data["pp_size"])
p_instance.tp_size=int(data["tp_size"])
p_instance.zmq_address=data["zmq_address"]
if p_instance.is_ready():
pending_prefill_ins.append(p_instance.http_address)
logger.info(f"""[Router] pending_prefill_ins appended {p_instance.http_address} ZMQ:{p_instance.zmq_address}""")
instance_cv.notify()
elif data["type"] == "D_init":
with instance_cv:
if data["http_address"] not in decode_instances:
decode_instances[data["http_address"]] = Instance(ins_type="D", http_address=data["http_address"], dp_size=int(data["dp_size"]), pp_size=int(data["pp_size"]), tp_size=int(data["tp_size"]))
decode_instances[data["http_address"]].zmq_address = data["zmq_address"]
continue
d_instance = decode_instances[data["http_address"]]
d_instance.dp_size=int(data["dp_size"])
d_instance.pp_size=int(data["pp_size"])
d_instance.tp_size=int(data["tp_size"])
d_instance.zmq_address=data["zmq_address"]
if d_instance.is_ready():
pending_decode_ins.append(d_instance.http_address)
logger.info(f"""[Router] pending_decode_ins appended {d_instance.http_address} ZMQ:{d_instance.zmq_address}""")
instance_cv.notify()
elif data["type"] == "heartbeat":
global healthy_instances
global health_cv
with health_cv:
healthy_instances[data["http_address"]] = time.time()
elif data["type"] == "Req":
# logger.info(f"""[Router] recv Request {data["request_id"]} : {data["instance_type"]}""")
global running_requests
global request_cv
with request_cv:
if data["request_id"] in running_requests:
request = running_requests[data["request_id"]]
if data["instance_type"] == "P":
request.p_http_address = data["http_address"]
request.p_dp_rank = int(data["dp_rank"])
elif data["instance_type"] == "D":
request.d_http_address = data["http_address"]
request.d_dp_rank = int(data["dp_rank"])
assert(request.p_dp_rank >= 0 and request.d_dp_rank >=0)
with request_queue_cv:
request_queue.append(request)
# logger.info(f"""[Router] add Request {data["request_id"]} [{request.p_http_address}:{request.p_dp_rank}, {request.d_http_address}:{request.d_dp_rank}]""")
request_queue_cv.notify()
else:
if data["instance_type"] == "P":
running_requests[data["request_id"]] = Request(request_id=data["request_id"], p_http_address=data["http_address"], p_dp_rank=int(data["dp_rank"]))
elif data["instance_type"] == "D":
running_requests[data["request_id"]] = Request(request_id=data["request_id"], d_http_address=data["http_address"], d_dp_rank=int(data["dp_rank"]))
else:
print(
"Unexpected, Received message from %s, data: %s",
remote_address,
data,
)
zmq_context = None
tp_mapping_of_pd_pair : dict[str, dict[int, list[str]]] = {}
tp_comm_mapping_of_pd_pair : dict[str, dict[int, list[int]]] = {}
active_p_tp_rank_of_pd_pair : dict[str, set[int]] = {}
def start_service_discovery(hostname, port):
if not hostname:
hostname = socket.gethostname()
if port == 0:
raise ValueError("Port cannot be 0")
# context = zmq.Context()
# router_socket = context.socket(zmq.ROUTER)
global zmq_context
zmq_context = zmq.Context()
router_socket = zmq_context.socket(zmq.ROUTER)
router_socket.bind(f"tcp://{hostname}:{port}")
poller = zmq.Poller()
poller.register(router_socket, zmq.POLLIN)
_listener_thread = threading.Thread(
target=_listen_for_register, args=[poller, router_socket], daemon=True
)
_listener_thread.start()
return _listener_thread
def dispatch_to_P(request : Request):
global prefill_instances
global decode_instances
p_ins = prefill_instances[request.p_http_address]
d_ins = decode_instances[request.d_http_address]
global zmq_context
global sock_cache
pd_pair_id = p_ins.http_address + "_" + d_ins.http_address
p_dp_rank = request.p_dp_rank
d_dp_rank = request.d_dp_rank
tp_dst_id = pd_pair_id + "_" + str(d_dp_rank)
assert(d_ins.pp_size == 1)
d_pp_rank = 0
global tp_mapping_of_pd_pair
global tp_comm_mapping_of_pd_pair
global active_p_tp_rank_of_pd_pair
if tp_dst_id not in active_p_tp_rank_of_pd_pair:
p_active_tp_rank = set()
p_tp_rank_to_dst : dict[int, list[str]] = defaultdict(list)
p_tp_rank_to_dst_comm : dict[int, list[int]] = defaultdict(list)
for d_tp_rank in range(d_ins.tp_size):
p_tp_rank = d_tp_rank % p_ins.tp_size
p_active_tp_rank.add(p_tp_rank)
p_tp_rank_to_dst[p_tp_rank].append(d_ins.rank_table[d_dp_rank][d_pp_rank][d_tp_rank])
p_tp_rank_to_dst_comm[p_tp_rank].append(d_ins.comm_rank_table[d_dp_rank][d_pp_rank][d_tp_rank])
tp_mapping_of_pd_pair[tp_dst_id] = p_tp_rank_to_dst
tp_comm_mapping_of_pd_pair[tp_dst_id] = p_tp_rank_to_dst_comm
active_p_tp_rank_of_pd_pair[tp_dst_id] = p_active_tp_rank
p_active_tp_rank = active_p_tp_rank_of_pd_pair[tp_dst_id]
p_tp_rank_to_dst = tp_mapping_of_pd_pair[tp_dst_id]
p_tp_rank_to_dst_comm = tp_comm_mapping_of_pd_pair[tp_dst_id]
for p_pp_rank in range(p_ins.pp_size):
for p_tp_rank in p_active_tp_rank:
if p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank] not in sock_cache:
sock = zmq_context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, "router")
sock.connect(f"tcp://{p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank]}")
sock_cache[p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank]] = sock
data = {
"cmd": "req_to_transfer",
"request_id": request.request_id,
"dst_num": len(p_tp_rank_to_dst[p_tp_rank]),
"pd_pair_id": pd_pair_id,
"remote_address": p_tp_rank_to_dst[p_tp_rank],
"remote_rank": p_tp_rank_to_dst_comm[p_tp_rank],
}
sock_cache[p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank]].send(msgpack.dumps(data))
logger.info(f"""[Router] dispatch Request {request.request_id} [{p_dp_rank}, {p_pp_rank}, {p_tp_rank}] -> [{d_dp_rank}, {d_pp_rank}]""")
for p_tp_rank in range(p_ins.tp_size):
if p_tp_rank not in p_active_tp_rank:
for p_pp_rank in range(p_ins.pp_size):
if p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank] not in sock_cache:
sock = zmq_context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, "router")
sock.connect(f"tcp://{p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank]}")
sock_cache[p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank]] = sock
data = {
"cmd": "req_not_to_transfer",
"request_id": request.request_id,
}
sock_cache[p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank]].send(msgpack.dumps(data))
def dp_dispatch():
global request_queue_cv
global request_queue
while True:
with request_queue_cv:
while not request_queue:
request_queue_cv.wait()
request = request_queue.pop()
dispatch_to_P(request)
def start_dp_dispatch():
_thread = threading.Thread(
target=dp_dispatch, daemon=True
)
_thread.start()
return _thread
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
app = Quart(__name__)
def random_uuid() -> str:
return str(uuid.uuid4().hex)
async def forward_request(url, data, request_id):
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id,
}
async with session.post(url=url, json=data, headers=headers) as response:
if response.status == 200:
if True:
async for chunk_bytes in response.content.iter_chunked(1024):
yield chunk_bytes
else:
content = await response.read()
yield content
def unique_id_dispatch(prefill_instance : Instance,
decode_instance : Instance) :
global zmq_context
global sock_cache
global router_nccl
global pd_pair
pd_pair_id = prefill_instance.http_address + "_" + decode_instance.http_address
if pd_pair_id in pd_pair:
logger.info(f"""[Router] pd pair {pd_pair_id} already exist""")
return
logger.info(f"""[Router] initing pd pair {pd_pair_id}""")
unique_id = router_nccl.ncclGetUniqueId()
unique_id = bytes(unique_id.internal)
rank = 0
p_rank_num = prefill_instance.dp_size * prefill_instance.pp_size * prefill_instance.tp_size
d_rank_num = decode_instance.dp_size * decode_instance.pp_size * decode_instance.tp_size
world_size = p_rank_num + d_rank_num
for dp_rank in range(prefill_instance.dp_size):
for pp_rank in range(prefill_instance.pp_size):
for tp_rank in range(prefill_instance.tp_size):
if prefill_instance.rank_table[dp_rank][pp_rank][tp_rank] not in sock_cache:
sock = zmq_context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, "router")
sock.connect(f"tcp://{prefill_instance.rank_table[dp_rank][pp_rank][tp_rank]}")
sock_cache[prefill_instance.rank_table[dp_rank][pp_rank][tp_rank]] = sock
data = {
"cmd": "comm_init",
"pd_pair_id": pd_pair_id,
"unique_id" : unique_id,
"world_size": world_size,
"rank": rank
}
sock_cache[prefill_instance.rank_table[dp_rank][pp_rank][tp_rank]].send(msgpack.dumps(data))
prefill_instance.comm_rank_table[dp_rank][pp_rank][tp_rank] = rank
rank += 1
logger.info(f"""[Router] dispatch unique_id of pd pair {pd_pair_id} to [P] [{dp_rank}, {pp_rank}, {tp_rank}]""")
for dp_rank in range(decode_instance.dp_size):
for pp_rank in range(decode_instance.pp_size):
for tp_rank in range(decode_instance.tp_size):
if decode_instance.rank_table[dp_rank][pp_rank][tp_rank] not in sock_cache:
sock = zmq_context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, "router")
sock.connect(f"tcp://{decode_instance.rank_table[dp_rank][pp_rank][tp_rank]}")
sock_cache[decode_instance.rank_table[dp_rank][pp_rank][tp_rank]] = sock
data = {
"cmd": "comm_init",
"pd_pair_id": pd_pair_id,
"unique_id" : unique_id,
"world_size": world_size,
"rank": rank
}
sock_cache[decode_instance.rank_table[dp_rank][pp_rank][tp_rank]].send(msgpack.dumps(data))
decode_instance.comm_rank_table[dp_rank][pp_rank][tp_rank] = rank
rank += 1
logger.info(f"""[Router] dispatch unique_id of pd pair {pd_pair_id} to [D] [{dp_rank}, {pp_rank}, {tp_rank}]""")
pd_pair[pd_pair_id] = unique_id
def pd_pair_init():
global prefill_instances
global decode_instances
global pending_prefill_ins
global pending_decode_ins
global ready_prefill_ins
global ready_decode_ins
global instance_cv
while True:
with instance_cv:
while len(pending_prefill_ins) == 0 and len(pending_decode_ins) == 0:
logger.info(f"""[Router] pd_pair_init: waiting for instance_cv""")
instance_cv.wait()
logger.info(f"""[Router] pd_pair_init: instance_cv finished waiting""")
while pending_prefill_ins:
p_ins = pending_prefill_ins[0]
logger.info(f"""[Router] pd_pair_init: processing {p_ins} from pending_prefill_ins""")
for d_ins in ready_decode_ins:
unique_id_dispatch(prefill_instances[p_ins], decode_instances[d_ins])
ready_prefill_ins.append(p_ins)
pending_prefill_ins.remove(p_ins)
while pending_decode_ins:
d_ins = pending_decode_ins[0]
logger.info(f"""[Router] pd_pair_init: processing {d_ins} from pending_decode_ins""")
for p_ins in ready_prefill_ins:
unique_id_dispatch(prefill_instances[p_ins], decode_instances[d_ins])
ready_decode_ins.append(d_ins)
pending_decode_ins.remove(d_ins)
def start_pd_pair_init():
_thread = threading.Thread(
target=pd_pair_init, daemon=True
)
_thread.start()
return _thread
@app.route("/v1/completions", methods=["POST"])
async def handle_request():
try:
original_request_data = await request.get_json()
prefill_request = original_request_data.copy()
# change max_tokens = 1 to let it only do prefill
prefill_request["max_tokens"] = 1
global count
global prefill_instances
global instance_cv
with instance_cv:
prefill_list = list(prefill_instances.items())
prefill_addr, prefill_instance = prefill_list[count % len(prefill_list)]
global decode_instances
with instance_cv:
decode_list = list(decode_instances.items())
decode_addr, decode_instance = decode_list[count % len(decode_list)]
global pd_pair
if prefill_instance.http_address + "_" + decode_instance.http_address not in pd_pair:
raise RuntimeError("Selected PD pair was not inited")
logger.info(
f"handle_request count: {count}, [HTTP:{prefill_addr}, 👉 HTTP:{decode_addr}]"
)
count += 1
request_id = f"{random_uuid()}"
async def run_prefill():
async for _ in forward_request(
f"http://{prefill_addr}/v1/completions", prefill_request, request_id
):
pass
prefill_task = asyncio.create_task(run_prefill())
# return decode
generator = forward_request(
f"http://{decode_addr}/v1/completions", original_request_data, request_id
)
response = await make_response(generator)
response.timeout = None
return response
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server")
print(e)
print("".join(traceback.format_exception(*exc_info)))
if __name__ == "__main__":
t = start_service_discovery("0.0.0.0", 30001)
t_1 = start_pd_pair_init()
t_2 = start_dp_dispatch()
app.run(host="0.0.0.0", port=10001)
t.join()
t_1.join()
t_2.join()
\ No newline at end of file
......@@ -54,6 +54,7 @@ class KVConnectorFactory:
cls,
config: "VllmConfig",
role: KVConnectorRole,
dp_rank: int = -1,
) -> KVConnectorBase_V1:
if not envs.VLLM_USE_V1:
raise ValueError("Attempting to initialize a V1 Connector, "
......@@ -81,7 +82,7 @@ class KVConnectorFactory:
# - Co-locate with worker process
# - Should only be used inside the forward context & attention layer
# We build separately to enforce strict separation
return connector_cls(config, role)
return connector_cls(config, role, dp_rank)
# Register various connectors here.
......
......@@ -20,6 +20,9 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.distributed.parallel_state import get_pp_group, get_tp_group, get_dp_group
import zmq
import msgpack
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
......@@ -78,7 +81,7 @@ class DuSwiftConnectorMetadata(KVConnectorMetadata):
class DuSwiftConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole, dp_rank : int = -1):
super().__init__(vllm_config=vllm_config, role=role)
self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Any] = {}
......@@ -157,10 +160,39 @@ class DuSwiftConnector(KVConnectorBase_V1):
except Exception as e:
print(f"Error: Exception occurred while reading configuration file - {e}")
if role == KVConnectorRole.SCHEDULER :
self.dp_rank = dp_rank
proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
proxy_port = self.config.get_from_extra_config("proxy_port", "")
if proxy_ip == "" or proxy_port == "":
self.proxy_address = ""
else:
self.proxy_address = proxy_ip + ":" + proxy_port
self.http_address = (
f"{self.config.kv_connector_extra_config['instance_ip']}:"
f"{self.config.kv_connector_extra_config['http_port']}")
self.context = zmq.Context()
req_sock = self.context.socket(zmq.DEALER)
req_sock.setsockopt_string(zmq.IDENTITY, f"{self.http_address}_rank{self.dp_rank}")
req_sock.connect(f"tcp://{self.proxy_address}")
self.req_sock = req_sock
def get_ip_value(self, key):
return self.ip_map.get(key)
def register_req(self, request_id: str) :
data = {
"type": "Req",
"instance_type": "P" if self.config.is_kv_producer else "D",
"http_address": self.http_address,
"request_id": request_id,
"dp_rank" : self.dp_rank
}
self.req_sock.send(msgpack.dumps(data))
# ==============================
# Worker-side methods
......@@ -438,68 +470,83 @@ class DuSwiftConnector(KVConnectorBase_V1):
else:
for request in connector_metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, True)
p_ip, p_port = self.parse_request_id(request_id, False)
remote_address = ip + ":" + str(port + self._rank)
# ip, port = self.parse_request_id(request_id, True)
# p_ip, p_port = self.parse_request_id(request_id, False)
# remote_address = ip + ":" + str(port + self._rank)
# pd_pair_id = p_ip + ":" + p_port + "_" + ip + ":" + port
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size
) % self.parallel_config.pipeline_parallel_size
if (self.multiple_machines_p and self.multiple_machines_d):
ip_second = self.get_ip_value(ip)
if (self.pp_size == 1):
if self._rank < 8:
self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, str(ip_second) + ":" + str(port + self._rank + 8))
elif (self.pp_size == 2):
if (pp_rank == 0):
self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
else:
self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, str(ip_second) + ":" + str(port + self._rank))
else:
logger.error("Error: multiple machines only suppprt pp1tp16 and pp2tp8!!!!!!")
elif (self.multiple_machines_p and not self.multiple_machines_d):
if (self.pp_size == 2):
remote_address = ip + ":" + str(port + self._tp_rank)
# pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size
# ) % self.parallel_config.pipeline_parallel_size
# if (self.multiple_machines_p and self.multiple_machines_d):
# ip_second = self.get_ip_value(ip)
# if (self.pp_size == 1):
# if self._rank < 8:
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, str(ip_second) + ":" + str(port + self._rank + 8))
# elif (self.pp_size == 2):
# if (pp_rank == 0):
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# else:
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, str(ip_second) + ":" + str(port + self._rank))
# else:
# logger.error("Error: multiple machines only suppprt pp1tp16 and pp2tp8!!!!!!")
# elif (self.multiple_machines_p and not self.multiple_machines_d):
# if (self.pp_size == 2):
# remote_address = ip + ":" + str(port + self._tp_rank)
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# else:
# logger.error("Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!")
# elif (not self.multiple_machines_p and not self.multiple_machines_d):
# # remote_addr = RemoteAddr(pd_pair_id, remote_address, self._rank + self.num_card)
# self.du_swift_engine.send_tensor_new(request_id, layer_name, kv_cache,
# is_mla)
# # if (self.pp_size == 1):
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, remote_address)
# # elif (self.pp_size == 2):
# # if (pp_rank == 0):
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, remote_address)
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, ip + ":" + str(port + self._rank + 4))
# # else:
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, remote_address)
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, ip + ":" + str(port + self._rank - 4))
# # elif (self.pp_size == 8):
# # for i in range(8):
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, ip + ":" + str(port + i))
# # elif (self.enable_asymmetric_p2p):
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, remote_address)
# # else:
# # logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!")
# else:
# logger.error("Error: not support!!!!!!")
pending = False
with self.du_swift_engine.req_status_cv:
if request_id not in self.du_swift_engine.req_status:
pending = True
if pending:
self.du_swift_engine.pending_tensor(request_id, layer_name,
kv_cache)
logger.info("[%d] pending for request: %s layer: %s", self._rank, request_id, layer_name)
else :
req_data = self.du_swift_engine.req_status[request_id]
assert(req_data.dst_num == len(req_data.zmq_address_and_comm_rank))
for i in range(req_data.dst_num):
remote_addr = RemoteAddr(req_data.pd_pair_id, *(req_data.zmq_address_and_comm_rank[i]))
self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
else:
logger.error("Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!")
elif (not self.multiple_machines_p and not self.multiple_machines_d):
# remote_addr = RemoteAddr(pd_pair_id, remote_address, self._rank + self.num_card)
self.du_swift_engine.send_tensor_new(request_id, layer_name, kv_cache,
is_mla)
# if (self.pp_size == 1):
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# elif (self.pp_size == 2):
# if (pp_rank == 0):
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank + 4))
# else:
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank - 4))
# elif (self.pp_size == 8):
# for i in range(8):
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + i))
# elif (self.enable_asymmetric_p2p):
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# else:
# logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!")
else:
logger.error("Error: not support!!!!!!")
kv_cache, remote_addr)
def wait_for_save(self):
pass
# if self.is_producer:
......
......@@ -6,9 +6,10 @@ import os
import threading
import time
import typing
from collections import deque
from collections import deque, defaultdict
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional
from dataclasses import dataclass, field
import msgpack
import torch
......@@ -72,6 +73,12 @@ def set_du_swift_context(num_channels: str):
os.environ.pop(var, None)
@dataclass
class ReqKVDest:
dst_num: int = 0
pd_pair_id: str = ""
zmq_address_and_comm_rank: list[tuple[str, int]] = field(default_factory=list)
@dataclass
class RemoteAddr:
pd_pair_id: str = ""
......@@ -125,6 +132,10 @@ class DuSwiftEngine:
self.multp = int(self.remote_tp_size / self.tp_size)
self.multiple_machines = self.config.get_from_extra_config(
"enable_multiple_machines", False)
self.instance_ip = self.config.get_from_extra_config(
"instance_ip", None)
if self.instance_ip :
self.multiple_machines = False
port = int(self.config.kv_port) + port_offset
if port == 0:
raise ValueError("Port cannot be 0")
......@@ -135,9 +146,14 @@ class DuSwiftEngine:
self.zmq_address = f"{self._hostname}:{self._port}"
# The `http_port` must be consistent with the port of OpenAI.
self.http_address = (
f"{self._hostname}:"
f"{self.config.kv_connector_extra_config['http_port']}")
if self.instance_ip:
self.http_address = (
f"{self.config.kv_connector_extra_config['instance_ip']}:"
f"{self.config.kv_connector_extra_config['http_port']}")
else:
self.http_address = (
f"{self._hostname}:"
f"{self.config.kv_connector_extra_config['http_port']}")
# If `proxy_ip` or `proxy_port` is `""`,
# then the ping thread will not be enabled.
......@@ -148,16 +164,27 @@ class DuSwiftEngine:
else:
self.proxy_address = proxy_ip + ":" + proxy_port
self.kv_cache_layer_num = 0
self.context = zmq.Context()
self.router_socket = self.context.socket(zmq.ROUTER)
self.router_socket.setsockopt(zmq.RCVHWM, 10000)
self.router_socket.setsockopt(zmq.SNDHWM, 5000)
self.router_socket.setsockopt(zmq.LINGER, 0)
self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
self.router_socket.setsockopt(zmq.TCP_KEEPALIVE, 1)
self.router_socket.bind(f"tcp://{self.zmq_address}")
self.poller = zmq.Poller()
self.poller.register(self.router_socket, zmq.POLLIN)
self.req_status: dict[str, ReqKVDest] = {}
self.req_status_cv = threading.Condition()
self.send_store_cv = threading.Condition()
self.send_queue_cv = threading.Condition()
self.recv_store_cv = threading.Condition()
self.pending_queue_cv = threading.Condition()
self.send_stream = torch.cuda.Stream()
self.recv_stream = torch.cuda.Stream()
......@@ -181,11 +208,16 @@ class DuSwiftEngine:
# PUT or PUT_ASYNC
# tensor_id: torch.Tensor
self.send_queue: deque[list[Any]] = deque()
self.pending_queue: dict[str, list[list[Any]]] = defaultdict(list)
self.requests_to_release: dict[str, bool] = {}
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
if self.send_type == "PUT_ASYNC":
self._send_thread = threading.Thread(target=self._send_async,
daemon=True)
self._send_thread.start()
self._pending_check_thread = threading.Thread(target=self._pending_check,
daemon=True)
self._pending_check_thread.start()
# tensor_id: torch.Tensor/(addr, dtype, shape)
self.recv_store: dict[str, Any] = {}
......@@ -328,7 +360,7 @@ class DuSwiftEngine:
self,
tensor_id: str,
tensor: torch.Tensor,
remote_address: typing.Optional[str] = None,
remote_address: typing.Optional[RemoteAddr] = None,
tbo_evt = None,
) -> bool:
if remote_address is None:
......@@ -356,7 +388,7 @@ class DuSwiftEngine:
logger.info(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d",
remote_address, tensor_id, tensor_size,
remote_address.zmq_address, tensor_id, tensor_size,
self.buffer_size, oldest_tenser_size, self.rank)
self.send_store[tensor_id] = tensor
......@@ -364,7 +396,7 @@ class DuSwiftEngine:
logger.debug(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
remote_address, tensor_id, tensor_size, tensor.shape,
remote_address.zmq_address, tensor_id, tensor_size, tensor.shape,
self.rank, self.buffer_size,
self.buffer_size / self.buffer_size_threshold * 100)
......@@ -416,6 +448,55 @@ class DuSwiftEngine:
self.buffer_size / self.buffer_size_threshold * 100)
return True
def pending_tensor(
self,
reuqest_id: str,
layer_name: str,
tensor: torch.Tensor,
tbo_evt = None,
) -> bool:
with self.pending_queue_cv:
self.pending_queue[reuqest_id].append([layer_name, tensor, tbo_evt])
self.pending_queue_cv.notify()
return True
def unpending_tensor(
self,
request_id: str,
req_data: ReqKVDest,
) -> bool:
with self.pending_queue_cv:
tensor_list = self.pending_queue.pop(request_id)
if request_id in self.requests_to_release:
self.requests_to_release[request_id] = True
logger.info("[%d] unpending request: %s", self.rank, request_id)
if req_data.dst_num <= 0:
return False
for layer_name, tensor, tbo_evt in tensor_list:
for i in range(req_data.dst_num) :
remote_addr = RemoteAddr(req_data.pd_pair_id, *(req_data.zmq_address_and_comm_rank[i]))
if (envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC) and tbo_evt is not None:
self.p2p_async_send_tensor(request_id + "#" + layer_name,
tensor, remote_addr, tbo_evt)
else :
self.send_tensor(request_id + "#" + layer_name,
tensor, remote_addr)
return True
def _pending_check(self) :
while True:
with self.pending_queue_cv:
while not self.pending_queue:
self.pending_queue_cv.wait()
pending_queue = self.pending_queue.copy()
for request_id in pending_queue:
with self.req_status_cv:
if request_id not in self.req_status:
continue
req_data = self.req_status[request_id]
assert(len(req_data.zmq_address_and_comm_rank) == req_data.dst_num)
self.unpending_tensor(request_id, req_data)
def recv_tensor(
self,
......@@ -475,22 +556,12 @@ class DuSwiftEngine:
def _listen_for_requests(self):
while True:
socks = dict(self.poller.poll())
socks = dict(self.poller.poll(5000))
if self.router_socket in socks:
remote_address, message = self.router_socket.recv_multipart()
data = msgpack.loads(message)
if data["cmd"] == "NEW":
unique_id = self.nccl.unique_id_from_bytes(
bytes(data["unique_id"]))
with torch.cuda.device(self.device):
rank = 1
with set_du_swift_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(
2, unique_id, rank)
self.comms[remote_address.decode()] = (comm, rank)
logger.info(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, remote_address.decode(), rank)
logger.info(f"unexpected message from {remote_address.decode()}")
elif data["cmd"] == "PUT":
tensor_id = data["tensor_id"]
if "tensor_split_num" in data:
......@@ -577,6 +648,15 @@ class DuSwiftEngine:
logger.info(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, data["pd_pair_id"], rank)
elif data["cmd"] == "req_to_transfer":
with self.req_status_cv:
assert(data["request_id"] not in self.req_status)
self.req_status[data["request_id"]] = ReqKVDest(dst_num=int(data["dst_num"]), pd_pair_id=data["pd_pair_id"], zmq_address_and_comm_rank=list(zip(data["remote_address"], data["remote_rank"])))
self.req_status_cv.notify_all()
elif data["cmd"] == "req_not_to_transfer":
with self.req_status_cv:
self.req_status[data["request_id"]] = ReqKVDest(dst_num=0)
self.req_status_cv.notify_all()
elif data["cmd"] == "GET":
tensor_id = data["tensor_id"]
with self.send_store_cv:
......@@ -814,20 +894,31 @@ class DuSwiftEngine:
"""
# Clear the buffer upon request completion.
requests_to_release : list[str] = []
with self.pending_queue_cv:
for request_id, release in self.requests_to_release.items():
if release :
requests_to_release.append(request_id)
self.requests_to_release.pop(request_id)
for request_id in finished_req_ids:
for layer_name in forward_context.no_compile_layers:
tensor_id = request_id + "#" + layer_name
if tensor_id in self.recv_store:
with self.recv_store_cv:
tensor = self.recv_store.pop(tensor_id, None)
self.send_request_id_to_tensor_ids.pop(
request_id, None)
self.recv_request_id_to_tensor_ids.pop(
request_id, None)
addr = 0
with self.pending_queue_cv:
if request_id in self.pending_queue:
self.requests_to_release[request_id] = False
logger.info("[%d] pending request: %s", self.rank, request_id)
continue
requests_to_release.append(request_id)
for request_id in requests_to_release:
ids = self.recv_request_id_to_tensor_ids.pop(request_id, set())
with self.recv_store_cv:
for tensor_id in ids:
tensor = self.recv_store.pop(tensor_id, None)
if isinstance(tensor, tuple):
addr, _, _ = tensor
self.pool.free(addr)
self.send_request_id_to_tensor_ids.pop(request_id, None)
# TODO:Retrieve requests that have already sent the KV cache.
finished_sending: set[str] = set()
......@@ -835,6 +926,19 @@ class DuSwiftEngine:
# TODO:Retrieve requests that have already received the KV cache.
finished_recving: set[str] = set()
if self.kv_cache_layer_num == 0 :
for layer_name in forward_context.no_compile_layers:
layer = forward_context.no_compile_layers[layer_name]
kv_cache = getattr(layer, 'kv_cache', None)
if kv_cache is None:
continue
self.kv_cache_layer_num += 1
with self.recv_store_cv:
for req in self.recv_request_id_to_tensor_ids:
if len(self.recv_request_id_to_tensor_ids[req]) == self.kv_cache_layer_num:
finished_recving.add(req)
return finished_sending or None, finished_recving or None
def _ping(self):
......@@ -911,6 +1015,7 @@ class DuSwiftEngine:
self._listener_thread.join()
if self.send_type == "PUT_ASYNC":
self._send_thread.join()
self._pending_check_thread.join()
if self._ping_thread is not None:
self._ping_thread.join()
......
......@@ -86,7 +86,9 @@ class Scheduler(SchedulerInterface):
"Multiple KV cache groups are not currently supported "
"with KV connectors")
self.connector = KVConnectorFactory.create_connector_v1(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
config=self.vllm_config, role=KVConnectorRole.SCHEDULER,
dp_rank=self.parallel_config.data_parallel_rank)
self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config,
......@@ -380,6 +382,10 @@ class Scheduler(SchedulerInterface):
if request.is_finished():
self.waiting.pop_request()
continue
if self.connector and not self.connector.is_producer and request.request_id not in self.finished_recving_kv_req_ids :
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
......@@ -673,7 +679,10 @@ class Scheduler(SchedulerInterface):
+ len(scheduled_running_reqs) >= max_batch_running):
break
request = self.waiting.peek_request()
if self.connector and not self.connector.is_producer and request.request_id not in self.finished_recving_kv_req_ids :
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
......@@ -1326,7 +1335,7 @@ class Scheduler(SchedulerInterface):
request.num_nans_in_logits = num_nans_in_logits[req_id]
# Add newly generated spec token ids to the request.
if spec_token_ids is not None:
if spec_token_ids is not None and (self.connector is None or not self.connector.is_producer):
if self.structured_output_manager.should_advance(request):
metadata = request.structured_output_request
# Needs to happen after new_token_ids are accepted.
......
......@@ -763,6 +763,10 @@ class EngineCoreProc(EngineCore):
# Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request))
if isinstance(request, EngineCoreRequest) and self.scheduler.connector is not None:
if request_type == EngineCoreRequestType.ADD:
self.scheduler.connector.register_req(request.request_id)
def process_output_sockets(self, output_paths: list[str],
coord_output_path: Optional[str],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment