Commit 294cc23a authored by xiabo's avatar xiabo
Browse files

解决pd分离非对称切分通信组过多问题

parent 84e5aba2
...@@ -9,15 +9,82 @@ import uuid ...@@ -9,15 +9,82 @@ import uuid
import aiohttp import aiohttp
import msgpack import msgpack
import zmq import zmq
from typing import Any
from quart import Quart, make_response, request from quart import Quart, make_response, request
from dataclasses import dataclass, field
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
from collections import deque, defaultdict
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
# @dataclass
# class Request:
# request_id: str
# p_http_address: str = ""
# p_dp_rank: int = -1
# d_http_address: str = ""
# d_dp_rank: int = -1
@dataclass
class Instance:
ins_type: str = "P"
http_address: str = ""
zmq_address: str = ""
p_unique_id: bytes = b""
dp_size: int = 0
pp_size: int = 0
tp_size: int = 0
# [dp, pp, tp] : zmq_address
rank_table: dict[int, dict[int, dict[int, str]]] = field(
default_factory=lambda: defaultdict(lambda: defaultdict(dict))
)
# [dp, pp, tp] : global rank
comm_rank_table: dict[int, dict[int, dict[int, int]]] = field(
default_factory=lambda: defaultdict(lambda: defaultdict(dict))
)
def count_rank_table_elements(self):
count = 0
for first_dict in self.rank_table.values():
for second_dict in first_dict.values():
count += len(second_dict)
return count
def is_ready(self):
world_size = self.dp_size * self.pp_size * self.tp_size
inited_rank = self.count_rank_table_elements()
all_ranks_ready = world_size and inited_rank == world_size
if self.ins_type == "P" :
logger.info(f"""[Router] P is_ready? : {self.http_address} world_size = {world_size} inited_rank = {inited_rank}""")
# return all_ranks_ready and self.p_unique_id != b""
return all_ranks_ready
else :
logger.info(f"""[Router] D is_ready? : {self.http_address} world_size = {world_size} inited_rank = {inited_rank}""")
return all_ranks_ready
count = 0 count = 0
prefill_instances: dict[str, str] = {} # http_address: zmq_address # prefill_instances: dict[str, str] = {} # http_address: zmq_address
decode_instances: dict[str, str] = {} # http_address: zmq_address # decode_instances: dict[str, str] = {} # http_address: zmq_address
prefill_instances: dict[str, Instance] = {}
decode_instances: dict[str, Instance] = {}
pending_prefill_ins: list[str] = []
pending_decode_ins: list[str] = []
ready_prefill_ins: list[str] = []
ready_decode_ins: list[str] = []
pd_pair : dict[str, bytes] = {}
router_nccl = NCCLLibrary()
prefill_cv = threading.Condition() prefill_cv = threading.Condition()
decode_cv = threading.Condition() decode_cv = threading.Condition()
instance_cv = threading.Condition()
sock_cache : dict[str, Any] = {}
def _listen_for_register(poller, router_socket): def _listen_for_register(poller, router_socket):
while True: while True:
...@@ -27,16 +94,61 @@ def _listen_for_register(poller, router_socket): ...@@ -27,16 +94,61 @@ def _listen_for_register(poller, router_socket):
# data: {"type": "P", "http_address": "ip:port", # data: {"type": "P", "http_address": "ip:port",
# "zmq_address": "ip:port"} # "zmq_address": "ip:port"}
data = msgpack.loads(message) data = msgpack.loads(message)
global prefill_instances
global instance_cv
global decode_instances
if data["type"] == "P": if data["type"] == "P":
global prefill_instances with instance_cv:
global prefill_cv if data["http_address"] not in prefill_instances:
with prefill_cv: prefill_instances[data["http_address"]] = Instance(http_address=data["http_address"])
prefill_instances[data["http_address"]] = data["zmq_address"] p_instance = prefill_instances[data["http_address"]]
p_instance.rank_table[int(data["dp_rank"])][int(data["pp_rank"])][int(data["tp_rank"])] = data["zmq_address"]
if p_instance.is_ready():
pending_prefill_ins.append(p_instance.http_address)
logger.info(f"""[Router] pending_prefill_ins appended {p_instance.http_address} ZMQ:{p_instance.zmq_address}""")
instance_cv.notify()
logger.info(f"""[Router] add P rank [{data["dp_rank"]}, {data["pp_rank"]}, {data["tp_rank"]}] : {data["zmq_address"]}""")
elif data["type"] == "D": elif data["type"] == "D":
global decode_instances with instance_cv:
global decode_cv if data["http_address"] not in decode_instances:
with decode_cv: decode_instances[data["http_address"]] = Instance(ins_type="D", http_address=data["http_address"])
decode_instances[data["http_address"]] = data["zmq_address"] d_instance = decode_instances[data["http_address"]]
d_instance.rank_table[int(data["dp_rank"])][int(data["pp_rank"])][int(data["tp_rank"])] = data["zmq_address"]
if d_instance.is_ready():
pending_decode_ins.append(d_instance.http_address)
logger.info(f"""[Router] pending_decode_ins appended {d_instance.http_address} ZMQ:{d_instance.zmq_address}""")
instance_cv.notify()
logger.info(f"""[Router] add D rank [{data["dp_rank"]}, {data["pp_rank"]}, {data["tp_rank"]}] : {data["zmq_address"]}""")
elif data["type"] == "P_init":
with instance_cv:
if data["http_address"] not in prefill_instances:
prefill_instances[data["http_address"]] = Instance(http_address=data["http_address"], dp_size=int(data["dp_size"]), pp_size=int(data["pp_size"]), tp_size=int(data["tp_size"]))
prefill_instances[data["http_address"]].zmq_address = data["zmq_address"]
continue
p_instance = prefill_instances[data["http_address"]]
p_instance.dp_size=int(data["dp_size"])
p_instance.pp_size=int(data["pp_size"])
p_instance.tp_size=int(data["tp_size"])
p_instance.zmq_address=data["zmq_address"]
if p_instance.is_ready():
pending_prefill_ins.append(p_instance.http_address)
logger.info(f"""[Router] pending_prefill_ins appended {p_instance.http_address} ZMQ:{p_instance.zmq_address}""")
instance_cv.notify()
elif data["type"] == "D_init":
with instance_cv:
if data["http_address"] not in decode_instances:
decode_instances[data["http_address"]] = Instance(ins_type="D", http_address=data["http_address"], dp_size=int(data["dp_size"]), pp_size=int(data["pp_size"]), tp_size=int(data["tp_size"]))
decode_instances[data["http_address"]].zmq_address = data["zmq_address"]
continue
d_instance = decode_instances[data["http_address"]]
d_instance.dp_size=int(data["dp_size"])
d_instance.pp_size=int(data["pp_size"])
d_instance.tp_size=int(data["tp_size"])
d_instance.zmq_address=data["zmq_address"]
if d_instance.is_ready():
pending_decode_ins.append(d_instance.http_address)
logger.info(f"""[Router] pending_decode_ins appended {d_instance.http_address} ZMQ:{d_instance.zmq_address}""")
instance_cv.notify()
else: else:
print( print(
"Unexpected, Received message from %s, data: %s", "Unexpected, Received message from %s, data: %s",
...@@ -44,15 +156,19 @@ def _listen_for_register(poller, router_socket): ...@@ -44,15 +156,19 @@ def _listen_for_register(poller, router_socket):
data, data,
) )
zmq_context = None
def start_service_discovery(hostname, port): def start_service_discovery(hostname, port):
if not hostname: if not hostname:
hostname = socket.gethostname() hostname = socket.gethostname()
if port == 0: if port == 0:
raise ValueError("Port cannot be 0") raise ValueError("Port cannot be 0")
context = zmq.Context() # context = zmq.Context()
router_socket = context.socket(zmq.ROUTER) # router_socket = context.socket(zmq.ROUTER)
global zmq_context
zmq_context = zmq.Context()
router_socket = zmq_context.socket(zmq.ROUTER)
router_socket.bind(f"tcp://{hostname}:{port}") router_socket.bind(f"tcp://{hostname}:{port}")
poller = zmq.Poller() poller = zmq.Poller()
...@@ -90,6 +206,109 @@ async def forward_request(url, data, request_id): ...@@ -90,6 +206,109 @@ async def forward_request(url, data, request_id):
yield content yield content
def unique_id_dispatch(prefill_instance : str,
decode_instance : str) :
global zmq_context
global sock_cache
global router_nccl
global pd_pair
pd_pair_id = prefill_instance.zmq_address + "_" + decode_instance.zmq_address
if pd_pair_id in pd_pair:
logger.info(f"""[Router] pd pair {pd_pair_id} already exist""")
return
logger.info(f"""[Router] initing pd pair {pd_pair_id}""")
unique_id = router_nccl.ncclGetUniqueId()
unique_id = bytes(unique_id.internal)
rank = 0
p_rank_num = prefill_instance.dp_size * prefill_instance.pp_size * prefill_instance.tp_size
d_rank_num = decode_instance.dp_size * decode_instance.pp_size * decode_instance.tp_size
world_size = p_rank_num + d_rank_num
for dp_rank in range(prefill_instance.dp_size):
for pp_rank in range(prefill_instance.pp_size):
for tp_rank in range(prefill_instance.tp_size):
if prefill_instance.rank_table[dp_rank][pp_rank][tp_rank] not in sock_cache:
sock = zmq_context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, "router")
sock.connect(f"tcp://{prefill_instance.rank_table[dp_rank][pp_rank][tp_rank]}")
sock_cache[prefill_instance.rank_table[dp_rank][pp_rank][tp_rank]] = sock
data = {
"cmd": "comm_init",
"pd_pair_id": pd_pair_id,
"unique_id" : unique_id,
"world_size": world_size,
"rank": rank
}
sock_cache[prefill_instance.rank_table[dp_rank][pp_rank][tp_rank]].send(msgpack.dumps(data))
prefill_instance.comm_rank_table[dp_rank][pp_rank][tp_rank] = rank
rank += 1
logger.info(f"""[Router] dispatch unique_id of pd pair {pd_pair_id} to [P] [{dp_rank}, {pp_rank}, {tp_rank}]""")
for dp_rank in range(decode_instance.dp_size):
for pp_rank in range(decode_instance.pp_size):
for tp_rank in range(decode_instance.tp_size):
if decode_instance.rank_table[dp_rank][pp_rank][tp_rank] not in sock_cache:
sock = zmq_context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, "router")
sock.connect(f"tcp://{decode_instance.rank_table[dp_rank][pp_rank][tp_rank]}")
sock_cache[decode_instance.rank_table[dp_rank][pp_rank][tp_rank]] = sock
data = {
"cmd": "comm_init",
"pd_pair_id": pd_pair_id,
"unique_id" : unique_id,
"world_size": world_size,
"rank": rank
}
sock_cache[decode_instance.rank_table[dp_rank][pp_rank][tp_rank]].send(msgpack.dumps(data))
decode_instance.comm_rank_table[dp_rank][pp_rank][tp_rank] = rank
rank += 1
logger.info(f"""[Router] dispatch unique_id of pd pair {pd_pair_id} to [D] [{dp_rank}, {pp_rank}, {tp_rank}]""")
pd_pair[pd_pair_id] = unique_id
def pd_pair_init():
global prefill_instances
global decode_instances
global pending_prefill_ins
global pending_decode_ins
global ready_prefill_ins
global ready_decode_ins
global instance_cv
while True:
with instance_cv:
while len(pending_prefill_ins) == 0 and len(pending_decode_ins) == 0:
logger.info(f"""[Router] pd_pair_init: waiting for instance_cv""")
instance_cv.wait()
logger.info(f"""[Router] pd_pair_init: instance_cv finished waiting""")
while pending_prefill_ins:
p_ins = pending_prefill_ins[0]
logger.info(f"""[Router] pd_pair_init: processing {p_ins} from pending_prefill_ins""")
for d_ins in ready_decode_ins:
unique_id_dispatch(prefill_instances[p_ins], decode_instances[d_ins])
ready_prefill_ins.append(p_ins)
pending_prefill_ins.remove(p_ins)
while pending_decode_ins:
d_ins = pending_decode_ins[0]
logger.info(f"""[Router] pd_pair_init: processing {d_ins} from pending_decode_ins""")
for p_ins in ready_prefill_ins:
unique_id_dispatch(prefill_instances[p_ins], decode_instances[d_ins])
ready_decode_ins.append(d_ins)
pending_decode_ins.remove(d_ins)
def start_pd_pair_init():
_thread = threading.Thread(
target=pd_pair_init, daemon=True
)
_thread.start()
return _thread
@app.route("/v1/completions", methods=["POST"]) @app.route("/v1/completions", methods=["POST"])
async def handle_request(): async def handle_request():
try: try:
...@@ -104,24 +323,25 @@ async def handle_request(): ...@@ -104,24 +323,25 @@ async def handle_request():
global prefill_cv global prefill_cv
with prefill_cv: with prefill_cv:
prefill_list = list(prefill_instances.items()) prefill_list = list(prefill_instances.items())
prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)] prefill_addr, prefill_instance = prefill_list[count % len(prefill_list)]
global decode_instances global decode_instances
global decode_cv global decode_cv
with decode_cv: with decode_cv:
decode_list = list(decode_instances.items()) decode_list = list(decode_instances.items())
decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)] decode_addr, decode_instance = decode_list[count % len(decode_list)]
print( print(
f"handle_request count: {count}, [HTTP:{prefill_addr}, " f"handle_request count: {count}, [HTTP:{prefill_addr}, "
f"ZMQ:{prefill_zmq_addr}] 👉 [HTTP:{decode_addr}, " f"ZMQ:{prefill_instance.zmq_address}] 👉 [HTTP:{decode_addr}, "
f"ZMQ:{decode_zmq_addr}]" f"ZMQ:{decode_instance.zmq_address}]"
) )
count += 1 count += 1
request_id = ( request_id = (
f"___prefill_addr_{prefill_zmq_addr}___decode_addr_" f"___prefill_addr_{prefill_instance.zmq_address}___decode_addr_"
f"{decode_zmq_addr}_{random_uuid()}" f"{decode_instance.zmq_address}_{random_uuid()}"
) )
# finish prefill # finish prefill
...@@ -151,5 +371,7 @@ async def handle_request(): ...@@ -151,5 +371,7 @@ async def handle_request():
if __name__ == "__main__": if __name__ == "__main__":
t = start_service_discovery("0.0.0.0", 30001) t = start_service_discovery("0.0.0.0", 30001)
t_1 = start_pd_pair_init()
app.run(host="0.0.0.0", port=10001) app.run(host="0.0.0.0", port=10001)
t.join() t.join()
t_1.join()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import socket
import threading
import uuid
import aiohttp
import msgpack
import zmq
from quart import Quart, make_response, request
count = 0
prefill_instances: dict[str, str] = {} # http_address: zmq_address
decode_instances: dict[str, str] = {} # http_address: zmq_address
prefill_cv = threading.Condition()
decode_cv = threading.Condition()
def _listen_for_register(poller, router_socket):
while True:
socks = dict(poller.poll())
if router_socket in socks:
remote_address, message = router_socket.recv_multipart()
# data: {"type": "P", "http_address": "ip:port",
# "zmq_address": "ip:port"}
data = msgpack.loads(message)
if data["type"] == "P":
global prefill_instances
global prefill_cv
with prefill_cv:
prefill_instances[data["http_address"]] = data["zmq_address"]
elif data["type"] == "D":
global decode_instances
global decode_cv
with decode_cv:
decode_instances[data["http_address"]] = data["zmq_address"]
else:
print(
"Unexpected, Received message from %s, data: %s",
remote_address,
data,
)
def start_service_discovery(hostname, port):
if not hostname:
hostname = socket.gethostname()
if port == 0:
raise ValueError("Port cannot be 0")
context = zmq.Context()
router_socket = context.socket(zmq.ROUTER)
router_socket.bind(f"tcp://{hostname}:{port}")
poller = zmq.Poller()
poller.register(router_socket, zmq.POLLIN)
_listener_thread = threading.Thread(
target=_listen_for_register, args=[poller, router_socket], daemon=True
)
_listener_thread.start()
return _listener_thread
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
app = Quart(__name__)
def random_uuid() -> str:
return str(uuid.uuid4().hex)
async def forward_request(url, data, request_id):
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id,
}
async with session.post(url=url, json=data, headers=headers) as response:
if response.status == 200:
if True:
async for chunk_bytes in response.content.iter_chunked(1024):
yield chunk_bytes
else:
content = await response.read()
yield content
@app.route("/v1/completions", methods=["POST"])
async def handle_request():
try:
original_request_data = await request.get_json()
prefill_request = original_request_data.copy()
# change max_tokens = 1 to let it only do prefill
prefill_request["max_tokens"] = 1
global count
global prefill_instances
global prefill_cv
with prefill_cv:
prefill_list = list(prefill_instances.items())
prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)]
global decode_instances
global decode_cv
with decode_cv:
decode_list = list(decode_instances.items())
decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)]
print(
f"handle_request count: {count}, [HTTP:{prefill_addr}, "
f"ZMQ:{prefill_zmq_addr}] 👉 [HTTP:{decode_addr}, "
f"ZMQ:{decode_zmq_addr}]"
)
count += 1
request_id = (
f"___prefill_addr_{prefill_zmq_addr}___decode_addr_"
f"{decode_zmq_addr}_{random_uuid()}"
)
# finish prefill
async for _ in forward_request(
f"http://{prefill_addr}/v1/completions", prefill_request, request_id
):
continue
# return decode
generator = forward_request(
f"http://{decode_addr}/v1/completions", original_request_data, request_id
)
response = await make_response(generator)
response.timeout = None
return response
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server")
print(e)
print("".join(traceback.format_exception(*exc_info)))
if __name__ == "__main__":
t = start_service_discovery("0.0.0.0", 30001)
app.run(host="0.0.0.0", port=10001)
t.join()
...@@ -123,7 +123,8 @@ class P2pNcclEngine: ...@@ -123,7 +123,8 @@ class P2pNcclEngine:
if self.remote_tp_size % self.tp_size != 0: if self.remote_tp_size % self.tp_size != 0:
logger.error(" the Prefill TP size must be less than or equal to the Decode TP size!!!!") logger.error(" the Prefill TP size must be less than or equal to the Decode TP size!!!!")
self.multp = int(self.remote_tp_size / self.tp_size) self.multp = int(self.remote_tp_size / self.tp_size)
self.multiple_machines = self.config.get_from_extra_config(
"enable_multiple_machines", False)
port = int(self.config.kv_port) + port_offset port = int(self.config.kv_port) + port_offset
if port == 0: if port == 0:
raise ValueError("Port cannot be 0") raise ValueError("Port cannot be 0")
...@@ -203,12 +204,16 @@ class P2pNcclEngine: ...@@ -203,12 +204,16 @@ class P2pNcclEngine:
self._listener_thread.start() self._listener_thread.start()
self._ping_thread = None self._ping_thread = None
# if port_offset == 0 and self.proxy_address != "": if self.multiple_machines:
if self.proxy_address != "": if port_offset == 0 and self.proxy_address != "":
self._ping_thread = threading.Thread(target=self._ping, self._ping_thread = threading.Thread(target=self._ping,
daemon=True) daemon=True)
self._ping_thread.start() self._ping_thread.start()
else:
if self.proxy_address != "":
self._ping_thread = threading.Thread(target=self._ping_new,
daemon=True)
self._ping_thread.start()
logger.info( logger.info(
"💯P2pNcclEngine init, rank:%d, local_rank:%d, http_address:%s, " "💯P2pNcclEngine init, rank:%d, local_rank:%d, http_address:%s, "
"zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_" "zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_"
...@@ -267,7 +272,7 @@ class P2pNcclEngine: ...@@ -267,7 +272,7 @@ class P2pNcclEngine:
p_ip, p_port = self.parse_request_id(request_id, False) p_ip, p_port = self.parse_request_id(request_id, False)
pd_pair_id = p_ip + ":" + str(p_port) + "_" + remote_ip + ":" + str(remote_port) pd_pair_id = p_ip + ":" + str(p_port) + "_" + remote_ip + ":" + str(remote_port)
# remote_port = 22001
if not self.enable_asymmetric_p2p: if not self.enable_asymmetric_p2p:
remote_address = remote_ip + ":" + str(remote_port + self.rank) remote_address = remote_ip + ":" + str(remote_port + self.rank)
remote_addr = RemoteAddr(pd_pair_id, remote_address, self.rank + self.pp_size * self.tp_size) remote_addr = RemoteAddr(pd_pair_id, remote_address, self.rank + self.pp_size * self.tp_size)
...@@ -279,7 +284,7 @@ class P2pNcclEngine: ...@@ -279,7 +284,7 @@ class P2pNcclEngine:
remote_pp_rank = self.compute_remote_pp_rank(layer_name) remote_pp_rank = self.compute_remote_pp_rank(layer_name)
items: list[Any] = [] items: list[Any] = []
# remote_tp_rank = self.tp_rank * self.multp
for d_tp_rank in range(self.remote_tp_size): for d_tp_rank in range(self.remote_tp_size):
for mul_tp in range(self.multp): for mul_tp in range(self.multp):
if self.tp_rank + mul_tp * self.tp_size == d_tp_rank: if self.tp_rank + mul_tp * self.tp_size == d_tp_rank:
...@@ -306,7 +311,7 @@ class P2pNcclEngine: ...@@ -306,7 +311,7 @@ class P2pNcclEngine:
if self.send_type == "PUT": if self.send_type == "PUT":
return all( return all(
self.send_sync(item) for item in self.get_send_queue_items( self._send_sync_new(item) for item in self.get_send_queue_items(
request_id, layer_name, tensor, is_mla)) request_id, layer_name, tensor, is_mla))
if self.send_type == "PUT_ASYNC": if self.send_type == "PUT_ASYNC":
...@@ -628,8 +633,11 @@ class P2pNcclEngine: ...@@ -628,8 +633,11 @@ class P2pNcclEngine:
self.send_stream.wait_event(tbo_evt) self.send_stream.wait_event(tbo_evt)
self._send_kv_p2p_sync(tensor_id, kv_layer, slot_mapping, remote_address) self._send_kv_p2p_sync(tensor_id, kv_layer, slot_mapping, remote_address)
else: else:
# logger.info(f"""=============xiabo tensor_id:{tensor_id} remote_address:{remote_address}""") if self.multiple_machines:
self._send_sync_new(tensor_id, tensor, remote_address) self._send_sync(tensor_id, tensor, remote_address)
else:
# logger.info(f"""=============xiabo tensor_id:{tensor_id} remote_address:{remote_address}""")
self._send_sync_new(tensor_id, tensor, remote_address)
def wait_for_sent(self): def wait_for_sent(self):
if self.send_type == "PUT_ASYNC": if self.send_type == "PUT_ASYNC":
...@@ -734,7 +742,7 @@ class P2pNcclEngine: ...@@ -734,7 +742,7 @@ class P2pNcclEngine:
"pd_pair_id": remote_address.pd_pair_id, "pd_pair_id": remote_address.pd_pair_id,
"comm_rank": rank "comm_rank": rank
} }
# logger.info(f"""_send_sync_new:{data}""") logger.info(f"""_send_sync_new:{data}""")
sock.send(msgpack.dumps(data)) sock.send(msgpack.dumps(data))
response = sock.recv() response = sock.recv()
...@@ -830,18 +838,20 @@ class P2pNcclEngine: ...@@ -830,18 +838,20 @@ class P2pNcclEngine:
return finished_sending or None, finished_recving or None return finished_sending or None, finished_recving or None
def _ping(self): def _ping(self):
# sock = self.context.socket(zmq.DEALER) sock = self.context.socket(zmq.DEALER)
# sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
# logger.debug("ping start, zmq_address:%s", self.zmq_address) logger.debug("ping start, zmq_address:%s", self.zmq_address)
# sock.connect(f"tcp://{self.proxy_address}") sock.connect(f"tcp://{self.proxy_address}")
# data = { data = {
# "type": "P" if self.config.is_kv_producer else "D", "type": "P" if self.config.is_kv_producer else "D",
# "http_address": self.http_address, "http_address": self.http_address,
# "zmq_address": self.zmq_address "zmq_address": self.zmq_address
# } }
# while True: while True:
# sock.send(msgpack.dumps(data)) sock.send(msgpack.dumps(data))
# time.sleep(3) time.sleep(3)
def _ping_new(self):
sock = self.context.socket(zmq.DEALER) sock = self.context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
logger.debug("ping start, zmq_address:%s", self.zmq_address) logger.debug("ping start, zmq_address:%s", self.zmq_address)
...@@ -856,7 +866,7 @@ class P2pNcclEngine: ...@@ -856,7 +866,7 @@ class P2pNcclEngine:
"pp_size" : self.pp_size, "pp_size" : self.pp_size,
"tp_size" : self.tp_size "tp_size" : self.tp_size
} }
logger.info(f"""_ping data:{data}""") # logger.info(f"""_ping data:{data}""")
sock.send(msgpack.dumps(data)) sock.send(msgpack.dumps(data))
data = { data = {
"type": "P" if self.config.is_kv_producer else "D", "type": "P" if self.config.is_kv_producer else "D",
...@@ -867,7 +877,7 @@ class P2pNcclEngine: ...@@ -867,7 +877,7 @@ class P2pNcclEngine:
"zmq_address": self.zmq_address "zmq_address": self.zmq_address
} }
# while True: # while True:
logger.info(f"""_ping data:{data}""") # logger.info(f"""_ping data:{data}""")
sock.send(msgpack.dumps(data)) sock.send(msgpack.dumps(data))
# time.sleep(3) # time.sleep(3)
...@@ -904,10 +914,44 @@ class P2pNcclEngine: ...@@ -904,10 +914,44 @@ class P2pNcclEngine:
if self._ping_thread is not None: if self._ping_thread is not None:
self._ping_thread.join() self._ping_thread.join()
def get_pp_indices_d(self, num_hidden_layers: int, pp_rank: int,
pp_size: int) -> tuple[int, int]:
partition_list_str = envs.VLLM_PP_LAYER_PARTITION_D
if partition_list_str is not None:
try:
partitions = [
int(layer) for layer in partition_list_str.split(",")
]
except ValueError as err:
raise ValueError("Invalid partition string: {}".format(
partition_list_str)) from err
if len(partitions) != pp_size:
raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
if sum(partitions) != num_hidden_layers:
raise ValueError(
f"{sum(partitions)=} does not match {num_hidden_layers=}.")
else:
layers_per_partition = num_hidden_layers // pp_size
partitions = [layers_per_partition for _ in range(pp_size)]
if remaining_layers := num_hidden_layers % pp_size:
for i in range(2, remaining_layers + 2):
partitions[-i] += 1
logger.info(
"Hidden layers were unevenly partitioned: [%s]. "
"This can be manually overridden using the "
"VLLM_PP_LAYER_PARTITION_D environment variable",
",".join(str(p) for p in partitions))
start_layer = sum(partitions[:pp_rank])
end_layer = start_layer + partitions[pp_rank]
return (start_layer, end_layer)
def compute_remote_pp_rank(self, layer_name: str) -> int: def compute_remote_pp_rank(self, layer_name: str) -> int:
current_layer_idx = extract_layer_index(layer_name) current_layer_idx = extract_layer_index(layer_name)
for d_pp_rank in range(self.remote_pp_size): for d_pp_rank in range(self.remote_pp_size):
start, end = get_pp_indices(self.total_num_hidden_layers, d_pp_rank, self.remote_pp_size) start, end = self.get_pp_indices_d(self.total_num_hidden_layers, d_pp_rank, self.remote_pp_size)
# logger.info(f"""compute_remote_pp_rank : current_layer_idx:{current_layer_idx} start:{start} end:{end}""") # logger.info(f"""compute_remote_pp_rank : current_layer_idx:{current_layer_idx} start:{start} end:{end}""")
if (current_layer_idx == self.total_num_hidden_layers): if (current_layer_idx == self.total_num_hidden_layers):
return self.remote_pp_size - 1 return self.remote_pp_size - 1
......
...@@ -42,6 +42,7 @@ if TYPE_CHECKING: ...@@ -42,6 +42,7 @@ if TYPE_CHECKING:
VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
VLLM_PP_LAYER_PARTITION: Optional[str] = None VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_PP_LAYER_PARTITION_D: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_OMP_THREADS_BIND: str = "" VLLM_CPU_OMP_THREADS_BIND: str = ""
VLLM_CPU_NUM_OF_RESERVED_CPU: int = 0 VLLM_CPU_NUM_OF_RESERVED_CPU: int = 0
...@@ -487,6 +488,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -487,6 +488,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_PP_LAYER_PARTITION": "VLLM_PP_LAYER_PARTITION":
lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
# Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION_D":
lambda: os.getenv("VLLM_PP_LAYER_PARTITION_D", None),
# (CPU backend only) CPU key-value cache space. # (CPU backend only) CPU key-value cache space.
# default is 4 GiB # default is 4 GiB
"VLLM_CPU_KVCACHE_SPACE": "VLLM_CPU_KVCACHE_SPACE":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment