Commit 56cc2ac8 authored by xuxz's avatar xuxz
Browse files

[PD]修改代理

parent 56fef1c3
...@@ -4,35 +4,87 @@ ...@@ -4,35 +4,87 @@
import os import os
import socket import socket
import threading import threading
import time
import uuid import uuid
from typing import Any
import aiohttp import aiohttp
import msgpack import msgpack
import zmq import zmq
from typing import Any
from quart import Quart, make_response, request from quart import Quart, make_response, request
from dataclasses import dataclass, field
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
from collections import deque, defaultdict
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
# @dataclass
# class Request:
# request_id: str
# p_http_address: str = ""
# p_dp_rank: int = -1
# d_http_address: str = ""
# d_dp_rank: int = -1
@dataclass
class Instance:
ins_type: str = "P"
http_address: str = ""
zmq_address: str = ""
p_unique_id: bytes = b""
dp_size: int = 0
pp_size: int = 0
tp_size: int = 0
# [dp, pp, tp] : zmq_address
rank_table: dict[int, dict[int, dict[int, str]]] = field(
default_factory=lambda: defaultdict(lambda: defaultdict(dict))
)
# [dp, pp, tp] : global rank
comm_rank_table: dict[int, dict[int, dict[int, int]]] = field(
default_factory=lambda: defaultdict(lambda: defaultdict(dict))
)
count = 0 def count_rank_table_elements(self):
prefill_instances: dict[str, Any] = {} # http_address: (zmq_address, stamp) count = 0
decode_instances: dict[str, Any] = {} # http_address: (zmq_address, stamp) for first_dict in self.rank_table.values():
for second_dict in first_dict.values():
count += len(second_dict)
return count
def is_ready(self):
world_size = self.dp_size * self.pp_size * self.tp_size
inited_rank = self.count_rank_table_elements()
all_ranks_ready = world_size and inited_rank == world_size
if self.ins_type == "P" :
logger.info(f"""[Router] P is_ready? : {self.http_address} world_size = {world_size} inited_rank = {inited_rank}""")
# return all_ranks_ready and self.p_unique_id != b""
return all_ranks_ready
else :
logger.info(f"""[Router] D is_ready? : {self.http_address} world_size = {world_size} inited_rank = {inited_rank}""")
return all_ranks_ready
prefill_cv = threading.Condition() count = 0
decode_cv = threading.Condition() # prefill_instances: dict[str, str] = {} # http_address: zmq_address
# decode_instances: dict[str, str] = {} # http_address: zmq_address
prefill_instances: dict[str, Instance] = {}
decode_instances: dict[str, Instance] = {}
DEFAULT_PING_SECONDS = 5 pending_prefill_ins: list[str] = []
pending_decode_ins: list[str] = []
ready_prefill_ins: list[str] = []
ready_decode_ins: list[str] = []
pd_pair : dict[str, bytes] = {}
router_nccl = NCCLLibrary()
def _remove_oldest_instances(instances: dict[str, Any]) -> None: prefill_cv = threading.Condition()
oldest_key = next(iter(instances), None) decode_cv = threading.Condition()
while oldest_key is not None: instance_cv = threading.Condition()
value = instances[oldest_key]
if value[1] > time.time():
break
print(f"🔴Remove [HTTP:{oldest_key}, ZMQ:{value[0]}, stamp:{value[1]}]")
instances.pop(oldest_key, None)
oldest_key = next(iter(instances), None)
sock_cache : dict[str, Any] = {}
def _listen_for_register(poller, router_socket): def _listen_for_register(poller, router_socket):
while True: while True:
...@@ -42,47 +94,81 @@ def _listen_for_register(poller, router_socket): ...@@ -42,47 +94,81 @@ def _listen_for_register(poller, router_socket):
# data: {"type": "P", "http_address": "ip:port", # data: {"type": "P", "http_address": "ip:port",
# "zmq_address": "ip:port"} # "zmq_address": "ip:port"}
data = msgpack.loads(message) data = msgpack.loads(message)
global prefill_instances
global instance_cv
global decode_instances
if data["type"] == "P": if data["type"] == "P":
global prefill_instances with instance_cv:
global prefill_cv if data["http_address"] not in prefill_instances:
with prefill_cv: prefill_instances[data["http_address"]] = Instance(http_address=data["http_address"])
node = prefill_instances.get(data["http_address"], None) p_instance = prefill_instances[data["http_address"]]
prefill_instances[data["http_address"]] = ( p_instance.rank_table[int(data["dp_rank"])][int(data["pp_rank"])][int(data["tp_rank"])] = data["zmq_address"]
data["zmq_address"], if p_instance.is_ready():
time.time() + DEFAULT_PING_SECONDS, pending_prefill_ins.append(p_instance.http_address)
) logger.info(f"""[Router] pending_prefill_ins appended {p_instance.http_address} ZMQ:{p_instance.zmq_address}""")
_remove_oldest_instances(prefill_instances) instance_cv.notify()
logger.info(f"""[Router] add P rank [{data["dp_rank"]}, {data["pp_rank"]}, {data["tp_rank"]}] : {data["zmq_address"]}""")
elif data["type"] == "D": elif data["type"] == "D":
global decode_instances with instance_cv:
global decode_cv if data["http_address"] not in decode_instances:
with decode_cv: decode_instances[data["http_address"]] = Instance(ins_type="D", http_address=data["http_address"])
node = decode_instances.get(data["http_address"], None) d_instance = decode_instances[data["http_address"]]
decode_instances[data["http_address"]] = ( d_instance.rank_table[int(data["dp_rank"])][int(data["pp_rank"])][int(data["tp_rank"])] = data["zmq_address"]
data["zmq_address"], if d_instance.is_ready():
time.time() + DEFAULT_PING_SECONDS, pending_decode_ins.append(d_instance.http_address)
) logger.info(f"""[Router] pending_decode_ins appended {d_instance.http_address} ZMQ:{d_instance.zmq_address}""")
_remove_oldest_instances(decode_instances) instance_cv.notify()
logger.info(f"""[Router] add D rank [{data["dp_rank"]}, {data["pp_rank"]}, {data["tp_rank"]}] : {data["zmq_address"]}""")
elif data["type"] == "P_init":
with instance_cv:
if data["http_address"] not in prefill_instances:
prefill_instances[data["http_address"]] = Instance(http_address=data["http_address"], dp_size=int(data["dp_size"]), pp_size=int(data["pp_size"]), tp_size=int(data["tp_size"]))
prefill_instances[data["http_address"]].zmq_address = data["zmq_address"]
continue
p_instance = prefill_instances[data["http_address"]]
p_instance.dp_size=int(data["dp_size"])
p_instance.pp_size=int(data["pp_size"])
p_instance.tp_size=int(data["tp_size"])
p_instance.zmq_address=data["zmq_address"]
if p_instance.is_ready():
pending_prefill_ins.append(p_instance.http_address)
logger.info(f"""[Router] pending_prefill_ins appended {p_instance.http_address} ZMQ:{p_instance.zmq_address}""")
instance_cv.notify()
elif data["type"] == "D_init":
with instance_cv:
if data["http_address"] not in decode_instances:
decode_instances[data["http_address"]] = Instance(ins_type="D", http_address=data["http_address"], dp_size=int(data["dp_size"]), pp_size=int(data["pp_size"]), tp_size=int(data["tp_size"]))
decode_instances[data["http_address"]].zmq_address = data["zmq_address"]
continue
d_instance = decode_instances[data["http_address"]]
d_instance.dp_size=int(data["dp_size"])
d_instance.pp_size=int(data["pp_size"])
d_instance.tp_size=int(data["tp_size"])
d_instance.zmq_address=data["zmq_address"]
if d_instance.is_ready():
pending_decode_ins.append(d_instance.http_address)
logger.info(f"""[Router] pending_decode_ins appended {d_instance.http_address} ZMQ:{d_instance.zmq_address}""")
instance_cv.notify()
else: else:
print( print(
"Unexpected, Received message from %s, data: %s", "Unexpected, Received message from %s, data: %s",
remote_address, remote_address,
data, data,
) )
return
if node is None:
print(f"🔵Add [HTTP:{data['http_address']}, ZMQ:{data['zmq_address']}]")
zmq_context = None
def start_service_discovery(hostname, port): def start_service_discovery(hostname, port):
if not hostname: if not hostname:
hostname = socket.gethostname() hostname = socket.gethostname()
if port == 0: if port == 0:
raise ValueError("Port cannot be 0") raise ValueError("Port cannot be 0")
context = zmq.Context() # context = zmq.Context()
router_socket = context.socket(zmq.ROUTER) # router_socket = context.socket(zmq.ROUTER)
global zmq_context
zmq_context = zmq.Context()
router_socket = zmq_context.socket(zmq.ROUTER)
router_socket.bind(f"tcp://{hostname}:{port}") router_socket.bind(f"tcp://{hostname}:{port}")
poller = zmq.Poller() poller = zmq.Poller()
...@@ -120,8 +206,110 @@ async def forward_request(url, data, request_id): ...@@ -120,8 +206,110 @@ async def forward_request(url, data, request_id):
yield content yield content
def unique_id_dispatch(prefill_instance : str,
decode_instance : str) :
global zmq_context
global sock_cache
global router_nccl
global pd_pair
pd_pair_id = prefill_instance.zmq_address + "_" + decode_instance.zmq_address
if pd_pair_id in pd_pair:
logger.info(f"""[Router] pd pair {pd_pair_id} already exist""")
return
logger.info(f"""[Router] initing pd pair {pd_pair_id}""")
unique_id = router_nccl.ncclGetUniqueId()
unique_id = bytes(unique_id.internal)
rank = 0
p_rank_num = prefill_instance.dp_size * prefill_instance.pp_size * prefill_instance.tp_size
d_rank_num = decode_instance.dp_size * decode_instance.pp_size * decode_instance.tp_size
world_size = p_rank_num + d_rank_num
for dp_rank in range(prefill_instance.dp_size):
for pp_rank in range(prefill_instance.pp_size):
for tp_rank in range(prefill_instance.tp_size):
if prefill_instance.rank_table[dp_rank][pp_rank][tp_rank] not in sock_cache:
sock = zmq_context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, "router")
sock.connect(f"tcp://{prefill_instance.rank_table[dp_rank][pp_rank][tp_rank]}")
sock_cache[prefill_instance.rank_table[dp_rank][pp_rank][tp_rank]] = sock
data = {
"cmd": "comm_init",
"pd_pair_id": pd_pair_id,
"unique_id" : unique_id,
"world_size": world_size,
"rank": rank
}
sock_cache[prefill_instance.rank_table[dp_rank][pp_rank][tp_rank]].send(msgpack.dumps(data))
prefill_instance.comm_rank_table[dp_rank][pp_rank][tp_rank] = rank
rank += 1
logger.info(f"""[Router] dispatch unique_id of pd pair {pd_pair_id} to [P] [{dp_rank}, {pp_rank}, {tp_rank}]""")
for dp_rank in range(decode_instance.dp_size):
for pp_rank in range(decode_instance.pp_size):
for tp_rank in range(decode_instance.tp_size):
if decode_instance.rank_table[dp_rank][pp_rank][tp_rank] not in sock_cache:
sock = zmq_context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, "router")
sock.connect(f"tcp://{decode_instance.rank_table[dp_rank][pp_rank][tp_rank]}")
sock_cache[decode_instance.rank_table[dp_rank][pp_rank][tp_rank]] = sock
data = {
"cmd": "comm_init",
"pd_pair_id": pd_pair_id,
"unique_id" : unique_id,
"world_size": world_size,
"rank": rank
}
sock_cache[decode_instance.rank_table[dp_rank][pp_rank][tp_rank]].send(msgpack.dumps(data))
decode_instance.comm_rank_table[dp_rank][pp_rank][tp_rank] = rank
rank += 1
logger.info(f"""[Router] dispatch unique_id of pd pair {pd_pair_id} to [D] [{dp_rank}, {pp_rank}, {tp_rank}]""")
pd_pair[pd_pair_id] = unique_id
def pd_pair_init():
global prefill_instances
global decode_instances
global pending_prefill_ins
global pending_decode_ins
global ready_prefill_ins
global ready_decode_ins
global instance_cv
while True:
with instance_cv:
while len(pending_prefill_ins) == 0 and len(pending_decode_ins) == 0:
logger.info(f"""[Router] pd_pair_init: waiting for instance_cv""")
instance_cv.wait()
logger.info(f"""[Router] pd_pair_init: instance_cv finished waiting""")
while pending_prefill_ins:
p_ins = pending_prefill_ins[0]
logger.info(f"""[Router] pd_pair_init: processing {p_ins} from pending_prefill_ins""")
for d_ins in ready_decode_ins:
unique_id_dispatch(prefill_instances[p_ins], decode_instances[d_ins])
ready_prefill_ins.append(p_ins)
pending_prefill_ins.remove(p_ins)
while pending_decode_ins:
d_ins = pending_decode_ins[0]
logger.info(f"""[Router] pd_pair_init: processing {d_ins} from pending_decode_ins""")
for p_ins in ready_prefill_ins:
unique_id_dispatch(prefill_instances[p_ins], decode_instances[d_ins])
ready_decode_ins.append(d_ins)
pending_decode_ins.remove(d_ins)
def start_pd_pair_init():
_thread = threading.Thread(
target=pd_pair_init, daemon=True
)
_thread.start()
return _thread
@app.route("/v1/completions", methods=["POST"]) @app.route("/v1/completions", methods=["POST"])
@app.route("/v1/chat/completions", methods=["POST"])
async def handle_request(): async def handle_request():
try: try:
original_request_data = await request.get_json() original_request_data = await request.get_json()
...@@ -129,45 +317,42 @@ async def handle_request(): ...@@ -129,45 +317,42 @@ async def handle_request():
prefill_request = original_request_data.copy() prefill_request = original_request_data.copy()
# change max_tokens = 1 to let it only do prefill # change max_tokens = 1 to let it only do prefill
prefill_request["max_tokens"] = 1 prefill_request["max_tokens"] = 1
if "max_completion_tokens" in prefill_request:
prefill_request["max_completion_tokens"] = 1
global count global count
global prefill_instances global prefill_instances
global prefill_cv global prefill_cv
with prefill_cv: with prefill_cv:
prefill_list = list(prefill_instances.items()) prefill_list = list(prefill_instances.items())
prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)] prefill_addr, prefill_instance = prefill_list[count % len(prefill_list)]
prefill_zmq_addr = prefill_zmq_addr[0]
global decode_instances global decode_instances
global decode_cv global decode_cv
with decode_cv: with decode_cv:
decode_list = list(decode_instances.items()) decode_list = list(decode_instances.items())
decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)] decode_addr, decode_instance = decode_list[count % len(decode_list)]
decode_zmq_addr = decode_zmq_addr[0]
print( print(
f"handle_request count: {count}, [HTTP:{prefill_addr}, " f"handle_request count: {count}, [HTTP:{prefill_addr}, "
f"ZMQ:{prefill_zmq_addr}] 👉 [HTTP:{decode_addr}, " f"ZMQ:{prefill_instance.zmq_address}] 👉 [HTTP:{decode_addr}, "
f"ZMQ:{decode_zmq_addr}]" f"ZMQ:{decode_instance.zmq_address}]"
) )
count += 1 count += 1
request_id = ( request_id = (
f"___prefill_addr_{prefill_zmq_addr}___decode_addr_" f"___prefill_addr_{prefill_instance.zmq_address}___decode_addr_"
f"{decode_zmq_addr}_{random_uuid()}" f"{decode_instance.zmq_address}_{random_uuid()}"
) )
# finish prefill # finish prefill
async for _ in forward_request( async for _ in forward_request(
f"http://{prefill_addr}{request.path}", prefill_request, request_id f"http://{prefill_addr}/v1/completions", prefill_request, request_id
): ):
continue continue
# return decode # return decode
generator = forward_request( generator = forward_request(
f"http://{decode_addr}{request.path}", original_request_data, request_id f"http://{decode_addr}/v1/completions", original_request_data, request_id
) )
response = await make_response(generator) response = await make_response(generator)
response.timeout = None response.timeout = None
...@@ -186,5 +371,7 @@ async def handle_request(): ...@@ -186,5 +371,7 @@ async def handle_request():
if __name__ == "__main__": if __name__ == "__main__":
t = start_service_discovery("0.0.0.0", 30001) t = start_service_discovery("0.0.0.0", 30001)
t_1 = start_pd_pair_init()
app.run(host="0.0.0.0", port=10001) app.run(host="0.0.0.0", port=10001)
t.join() t.join()
t_1.join()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment