Commit d916e714 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-pdncclinit' into 'v0.9.2-dev'

V0.9.2 dev pdncclinit

See merge request dcutoolkit/deeplearing/vllm!333
parents 1b5aa25e 294cc23a
......@@ -9,15 +9,82 @@ import uuid
import aiohttp
import msgpack
import zmq
from typing import Any
from quart import Quart, make_response, request
from dataclasses import dataclass, field
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
from collections import deque, defaultdict
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
# @dataclass
# class Request:
# request_id: str
# p_http_address: str = ""
# p_dp_rank: int = -1
# d_http_address: str = ""
# d_dp_rank: int = -1
@dataclass
class Instance:
ins_type: str = "P"
http_address: str = ""
zmq_address: str = ""
p_unique_id: bytes = b""
dp_size: int = 0
pp_size: int = 0
tp_size: int = 0
# [dp, pp, tp] : zmq_address
rank_table: dict[int, dict[int, dict[int, str]]] = field(
default_factory=lambda: defaultdict(lambda: defaultdict(dict))
)
# [dp, pp, tp] : global rank
comm_rank_table: dict[int, dict[int, dict[int, int]]] = field(
default_factory=lambda: defaultdict(lambda: defaultdict(dict))
)
def count_rank_table_elements(self):
count = 0
for first_dict in self.rank_table.values():
for second_dict in first_dict.values():
count += len(second_dict)
return count
def is_ready(self):
world_size = self.dp_size * self.pp_size * self.tp_size
inited_rank = self.count_rank_table_elements()
all_ranks_ready = world_size and inited_rank == world_size
if self.ins_type == "P" :
logger.info(f"""[Router] P is_ready? : {self.http_address} world_size = {world_size} inited_rank = {inited_rank}""")
# return all_ranks_ready and self.p_unique_id != b""
return all_ranks_ready
else :
logger.info(f"""[Router] D is_ready? : {self.http_address} world_size = {world_size} inited_rank = {inited_rank}""")
return all_ranks_ready
count = 0
prefill_instances: dict[str, str] = {} # http_address: zmq_address
decode_instances: dict[str, str] = {} # http_address: zmq_address
# prefill_instances: dict[str, str] = {} # http_address: zmq_address
# decode_instances: dict[str, str] = {} # http_address: zmq_address
prefill_instances: dict[str, Instance] = {}
decode_instances: dict[str, Instance] = {}
pending_prefill_ins: list[str] = []
pending_decode_ins: list[str] = []
ready_prefill_ins: list[str] = []
ready_decode_ins: list[str] = []
pd_pair : dict[str, bytes] = {}
router_nccl = NCCLLibrary()
prefill_cv = threading.Condition()
decode_cv = threading.Condition()
instance_cv = threading.Condition()
sock_cache : dict[str, Any] = {}
def _listen_for_register(poller, router_socket):
while True:
......@@ -27,16 +94,61 @@ def _listen_for_register(poller, router_socket):
# data: {"type": "P", "http_address": "ip:port",
# "zmq_address": "ip:port"}
data = msgpack.loads(message)
global prefill_instances
global instance_cv
global decode_instances
if data["type"] == "P":
global prefill_instances
global prefill_cv
with prefill_cv:
prefill_instances[data["http_address"]] = data["zmq_address"]
with instance_cv:
if data["http_address"] not in prefill_instances:
prefill_instances[data["http_address"]] = Instance(http_address=data["http_address"])
p_instance = prefill_instances[data["http_address"]]
p_instance.rank_table[int(data["dp_rank"])][int(data["pp_rank"])][int(data["tp_rank"])] = data["zmq_address"]
if p_instance.is_ready():
pending_prefill_ins.append(p_instance.http_address)
logger.info(f"""[Router] pending_prefill_ins appended {p_instance.http_address} ZMQ:{p_instance.zmq_address}""")
instance_cv.notify()
logger.info(f"""[Router] add P rank [{data["dp_rank"]}, {data["pp_rank"]}, {data["tp_rank"]}] : {data["zmq_address"]}""")
elif data["type"] == "D":
global decode_instances
global decode_cv
with decode_cv:
decode_instances[data["http_address"]] = data["zmq_address"]
with instance_cv:
if data["http_address"] not in decode_instances:
decode_instances[data["http_address"]] = Instance(ins_type="D", http_address=data["http_address"])
d_instance = decode_instances[data["http_address"]]
d_instance.rank_table[int(data["dp_rank"])][int(data["pp_rank"])][int(data["tp_rank"])] = data["zmq_address"]
if d_instance.is_ready():
pending_decode_ins.append(d_instance.http_address)
logger.info(f"""[Router] pending_decode_ins appended {d_instance.http_address} ZMQ:{d_instance.zmq_address}""")
instance_cv.notify()
logger.info(f"""[Router] add D rank [{data["dp_rank"]}, {data["pp_rank"]}, {data["tp_rank"]}] : {data["zmq_address"]}""")
elif data["type"] == "P_init":
with instance_cv:
if data["http_address"] not in prefill_instances:
prefill_instances[data["http_address"]] = Instance(http_address=data["http_address"], dp_size=int(data["dp_size"]), pp_size=int(data["pp_size"]), tp_size=int(data["tp_size"]))
prefill_instances[data["http_address"]].zmq_address = data["zmq_address"]
continue
p_instance = prefill_instances[data["http_address"]]
p_instance.dp_size=int(data["dp_size"])
p_instance.pp_size=int(data["pp_size"])
p_instance.tp_size=int(data["tp_size"])
p_instance.zmq_address=data["zmq_address"]
if p_instance.is_ready():
pending_prefill_ins.append(p_instance.http_address)
logger.info(f"""[Router] pending_prefill_ins appended {p_instance.http_address} ZMQ:{p_instance.zmq_address}""")
instance_cv.notify()
elif data["type"] == "D_init":
with instance_cv:
if data["http_address"] not in decode_instances:
decode_instances[data["http_address"]] = Instance(ins_type="D", http_address=data["http_address"], dp_size=int(data["dp_size"]), pp_size=int(data["pp_size"]), tp_size=int(data["tp_size"]))
decode_instances[data["http_address"]].zmq_address = data["zmq_address"]
continue
d_instance = decode_instances[data["http_address"]]
d_instance.dp_size=int(data["dp_size"])
d_instance.pp_size=int(data["pp_size"])
d_instance.tp_size=int(data["tp_size"])
d_instance.zmq_address=data["zmq_address"]
if d_instance.is_ready():
pending_decode_ins.append(d_instance.http_address)
logger.info(f"""[Router] pending_decode_ins appended {d_instance.http_address} ZMQ:{d_instance.zmq_address}""")
instance_cv.notify()
else:
print(
"Unexpected, Received message from %s, data: %s",
......@@ -44,15 +156,19 @@ def _listen_for_register(poller, router_socket):
data,
)
zmq_context = None
def start_service_discovery(hostname, port):
if not hostname:
hostname = socket.gethostname()
if port == 0:
raise ValueError("Port cannot be 0")
context = zmq.Context()
router_socket = context.socket(zmq.ROUTER)
# context = zmq.Context()
# router_socket = context.socket(zmq.ROUTER)
global zmq_context
zmq_context = zmq.Context()
router_socket = zmq_context.socket(zmq.ROUTER)
router_socket.bind(f"tcp://{hostname}:{port}")
poller = zmq.Poller()
......@@ -90,6 +206,109 @@ async def forward_request(url, data, request_id):
yield content
def unique_id_dispatch(prefill_instance : str,
decode_instance : str) :
global zmq_context
global sock_cache
global router_nccl
global pd_pair
pd_pair_id = prefill_instance.zmq_address + "_" + decode_instance.zmq_address
if pd_pair_id in pd_pair:
logger.info(f"""[Router] pd pair {pd_pair_id} already exist""")
return
logger.info(f"""[Router] initing pd pair {pd_pair_id}""")
unique_id = router_nccl.ncclGetUniqueId()
unique_id = bytes(unique_id.internal)
rank = 0
p_rank_num = prefill_instance.dp_size * prefill_instance.pp_size * prefill_instance.tp_size
d_rank_num = decode_instance.dp_size * decode_instance.pp_size * decode_instance.tp_size
world_size = p_rank_num + d_rank_num
for dp_rank in range(prefill_instance.dp_size):
for pp_rank in range(prefill_instance.pp_size):
for tp_rank in range(prefill_instance.tp_size):
if prefill_instance.rank_table[dp_rank][pp_rank][tp_rank] not in sock_cache:
sock = zmq_context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, "router")
sock.connect(f"tcp://{prefill_instance.rank_table[dp_rank][pp_rank][tp_rank]}")
sock_cache[prefill_instance.rank_table[dp_rank][pp_rank][tp_rank]] = sock
data = {
"cmd": "comm_init",
"pd_pair_id": pd_pair_id,
"unique_id" : unique_id,
"world_size": world_size,
"rank": rank
}
sock_cache[prefill_instance.rank_table[dp_rank][pp_rank][tp_rank]].send(msgpack.dumps(data))
prefill_instance.comm_rank_table[dp_rank][pp_rank][tp_rank] = rank
rank += 1
logger.info(f"""[Router] dispatch unique_id of pd pair {pd_pair_id} to [P] [{dp_rank}, {pp_rank}, {tp_rank}]""")
for dp_rank in range(decode_instance.dp_size):
for pp_rank in range(decode_instance.pp_size):
for tp_rank in range(decode_instance.tp_size):
if decode_instance.rank_table[dp_rank][pp_rank][tp_rank] not in sock_cache:
sock = zmq_context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, "router")
sock.connect(f"tcp://{decode_instance.rank_table[dp_rank][pp_rank][tp_rank]}")
sock_cache[decode_instance.rank_table[dp_rank][pp_rank][tp_rank]] = sock
data = {
"cmd": "comm_init",
"pd_pair_id": pd_pair_id,
"unique_id" : unique_id,
"world_size": world_size,
"rank": rank
}
sock_cache[decode_instance.rank_table[dp_rank][pp_rank][tp_rank]].send(msgpack.dumps(data))
decode_instance.comm_rank_table[dp_rank][pp_rank][tp_rank] = rank
rank += 1
logger.info(f"""[Router] dispatch unique_id of pd pair {pd_pair_id} to [D] [{dp_rank}, {pp_rank}, {tp_rank}]""")
pd_pair[pd_pair_id] = unique_id
def pd_pair_init():
global prefill_instances
global decode_instances
global pending_prefill_ins
global pending_decode_ins
global ready_prefill_ins
global ready_decode_ins
global instance_cv
while True:
with instance_cv:
while len(pending_prefill_ins) == 0 and len(pending_decode_ins) == 0:
logger.info(f"""[Router] pd_pair_init: waiting for instance_cv""")
instance_cv.wait()
logger.info(f"""[Router] pd_pair_init: instance_cv finished waiting""")
while pending_prefill_ins:
p_ins = pending_prefill_ins[0]
logger.info(f"""[Router] pd_pair_init: processing {p_ins} from pending_prefill_ins""")
for d_ins in ready_decode_ins:
unique_id_dispatch(prefill_instances[p_ins], decode_instances[d_ins])
ready_prefill_ins.append(p_ins)
pending_prefill_ins.remove(p_ins)
while pending_decode_ins:
d_ins = pending_decode_ins[0]
logger.info(f"""[Router] pd_pair_init: processing {d_ins} from pending_decode_ins""")
for p_ins in ready_prefill_ins:
unique_id_dispatch(prefill_instances[p_ins], decode_instances[d_ins])
ready_decode_ins.append(d_ins)
pending_decode_ins.remove(d_ins)
def start_pd_pair_init():
_thread = threading.Thread(
target=pd_pair_init, daemon=True
)
_thread.start()
return _thread
@app.route("/v1/completions", methods=["POST"])
async def handle_request():
try:
......@@ -104,24 +323,25 @@ async def handle_request():
global prefill_cv
with prefill_cv:
prefill_list = list(prefill_instances.items())
prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)]
prefill_addr, prefill_instance = prefill_list[count % len(prefill_list)]
global decode_instances
global decode_cv
with decode_cv:
decode_list = list(decode_instances.items())
decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)]
decode_addr, decode_instance = decode_list[count % len(decode_list)]
print(
f"handle_request count: {count}, [HTTP:{prefill_addr}, "
f"ZMQ:{prefill_zmq_addr}] 👉 [HTTP:{decode_addr}, "
f"ZMQ:{decode_zmq_addr}]"
f"ZMQ:{prefill_instance.zmq_address}] 👉 [HTTP:{decode_addr}, "
f"ZMQ:{decode_instance.zmq_address}]"
)
count += 1
request_id = (
f"___prefill_addr_{prefill_zmq_addr}___decode_addr_"
f"{decode_zmq_addr}_{random_uuid()}"
f"___prefill_addr_{prefill_instance.zmq_address}___decode_addr_"
f"{decode_instance.zmq_address}_{random_uuid()}"
)
# finish prefill
......@@ -151,5 +371,7 @@ async def handle_request():
if __name__ == "__main__":
t = start_service_discovery("0.0.0.0", 30001)
t_1 = start_pd_pair_init()
app.run(host="0.0.0.0", port=10001)
t.join()
t_1.join()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import socket
import threading
import uuid
import aiohttp
import msgpack
import zmq
from quart import Quart, make_response, request
count = 0
prefill_instances: dict[str, str] = {} # http_address: zmq_address
decode_instances: dict[str, str] = {} # http_address: zmq_address
prefill_cv = threading.Condition()
decode_cv = threading.Condition()
def _listen_for_register(poller, router_socket):
while True:
socks = dict(poller.poll())
if router_socket in socks:
remote_address, message = router_socket.recv_multipart()
# data: {"type": "P", "http_address": "ip:port",
# "zmq_address": "ip:port"}
data = msgpack.loads(message)
if data["type"] == "P":
global prefill_instances
global prefill_cv
with prefill_cv:
prefill_instances[data["http_address"]] = data["zmq_address"]
elif data["type"] == "D":
global decode_instances
global decode_cv
with decode_cv:
decode_instances[data["http_address"]] = data["zmq_address"]
else:
print(
"Unexpected, Received message from %s, data: %s",
remote_address,
data,
)
def start_service_discovery(hostname, port):
if not hostname:
hostname = socket.gethostname()
if port == 0:
raise ValueError("Port cannot be 0")
context = zmq.Context()
router_socket = context.socket(zmq.ROUTER)
router_socket.bind(f"tcp://{hostname}:{port}")
poller = zmq.Poller()
poller.register(router_socket, zmq.POLLIN)
_listener_thread = threading.Thread(
target=_listen_for_register, args=[poller, router_socket], daemon=True
)
_listener_thread.start()
return _listener_thread
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
app = Quart(__name__)
def random_uuid() -> str:
return str(uuid.uuid4().hex)
async def forward_request(url, data, request_id):
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id,
}
async with session.post(url=url, json=data, headers=headers) as response:
if response.status == 200:
if True:
async for chunk_bytes in response.content.iter_chunked(1024):
yield chunk_bytes
else:
content = await response.read()
yield content
@app.route("/v1/completions", methods=["POST"])
async def handle_request():
try:
original_request_data = await request.get_json()
prefill_request = original_request_data.copy()
# change max_tokens = 1 to let it only do prefill
prefill_request["max_tokens"] = 1
global count
global prefill_instances
global prefill_cv
with prefill_cv:
prefill_list = list(prefill_instances.items())
prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)]
global decode_instances
global decode_cv
with decode_cv:
decode_list = list(decode_instances.items())
decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)]
print(
f"handle_request count: {count}, [HTTP:{prefill_addr}, "
f"ZMQ:{prefill_zmq_addr}] 👉 [HTTP:{decode_addr}, "
f"ZMQ:{decode_zmq_addr}]"
)
count += 1
request_id = (
f"___prefill_addr_{prefill_zmq_addr}___decode_addr_"
f"{decode_zmq_addr}_{random_uuid()}"
)
# finish prefill
async for _ in forward_request(
f"http://{prefill_addr}/v1/completions", prefill_request, request_id
):
continue
# return decode
generator = forward_request(
f"http://{decode_addr}/v1/completions", original_request_data, request_id
)
response = await make_response(generator)
response.timeout = None
return response
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server")
print(e)
print("".join(traceback.format_exception(*exc_info)))
if __name__ == "__main__":
t = start_service_discovery("0.0.0.0", 30001)
app.run(host="0.0.0.0", port=10001)
t.join()
......@@ -12,7 +12,7 @@ from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
P2pNcclEngine)
P2pNcclEngine, RemoteAddr)
from vllm.distributed.parallel_state import get_world_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
......@@ -108,6 +108,12 @@ class P2pNcclConnector(KVConnectorBase_V1):
port_offset=self._rank,
config=self.config,
model_config=vllm_config.model_config,
dp_rank=self._dp_rank,
pp_rank=self._pp_rank,
tp_rank=self._tp_rank,
dp_size=self._dp_size,
pp_size=self._pp_size,
tp_size=self._tp_size
) if role == KVConnectorRole.WORKER else None
self.parallel_config = vllm_config.parallel_config
......@@ -177,13 +183,11 @@ class P2pNcclConnector(KVConnectorBase_V1):
# Only consumer/decode loads KV Cache
if self.is_producer:
return
assert self.p2p_nccl_engine is not None
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
def inject_kv_into_layer(
dst_kv_cache_layer: torch.Tensor,
src_kv_cache: torch.Tensor,
......@@ -274,7 +278,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
logger.warning("🚧src_kv_cache is None, %s",
request.request_id)
continue
inject_kv_into_layer(kv_cache_layer, kv_cache,
request.slot_mapping, request.request_id)
tensor_id = request.request_id + "#" + layer_name
......@@ -436,7 +439,9 @@ class P2pNcclConnector(KVConnectorBase_V1):
for request in connector_metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, True)
p_ip, p_port = self.parse_request_id(request_id, False)
remote_address = ip + ":" + str(port + self._rank)
# pd_pair_id = p_ip + ":" + p_port + "_" + ip + ":" + port
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size
......@@ -467,6 +472,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
logger.error("Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!")
elif (not self.multiple_machines_p and not self.multiple_machines_d):
# remote_addr = RemoteAddr(pd_pair_id, remote_address, self._rank + self.num_card)
self.p2p_nccl_engine.send_tensor_new(request_id, layer_name, kv_cache,
is_mla)
# if (self.pp_size == 1):
......
......@@ -72,6 +72,13 @@ def set_p2p_nccl_context(num_channels: str):
os.environ.pop(var, None)
@dataclass
class RemoteAddr:
pd_pair_id: str = ""
zmq_address: str = ""
comm_rank: int = 0
class P2pNcclEngine:
def __init__(self,
......@@ -79,11 +86,23 @@ class P2pNcclEngine:
port_offset: int,
config: KVTransferConfig,
model_config: ModelConfig,
dp_rank: int = 0,
pp_rank: int = 0,
tp_rank: int = 0,
dp_size: int = 0,
pp_size: int = 0,
tp_size: int = 0,
library_path: Optional[str] = None) -> None:
self.config = config
self.model_config = model_config
self.rank = port_offset
self.local_rank = local_rank
self.dp_rank = dp_rank
self.pp_rank = pp_rank
self.tp_rank = tp_rank
self.dp_size = dp_size
self.pp_size = pp_size
self.tp_size = tp_size
self.device = torch.device(f"cuda:{self.local_rank}")
self.nccl = NCCLLibrary(library_path)
......@@ -104,7 +123,8 @@ class P2pNcclEngine:
if self.remote_tp_size % self.tp_size != 0:
logger.error(" the Prefill TP size must be less than or equal to the Decode TP size!!!!")
self.multp = int(self.remote_tp_size / self.tp_size)
self.multiple_machines = self.config.get_from_extra_config(
"enable_multiple_machines", False)
port = int(self.config.kv_port) + port_offset
if port == 0:
raise ValueError("Port cannot be 0")
......@@ -184,11 +204,16 @@ class P2pNcclEngine:
self._listener_thread.start()
self._ping_thread = None
if port_offset == 0 and self.proxy_address != "":
self._ping_thread = threading.Thread(target=self._ping,
daemon=True)
self._ping_thread.start()
if self.multiple_machines:
if port_offset == 0 and self.proxy_address != "":
self._ping_thread = threading.Thread(target=self._ping,
daemon=True)
self._ping_thread.start()
else:
if self.proxy_address != "":
self._ping_thread = threading.Thread(target=self._ping_new,
daemon=True)
self._ping_thread.start()
logger.info(
"💯P2pNcclEngine init, rank:%d, local_rank:%d, http_address:%s, "
"zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_"
......@@ -196,6 +221,21 @@ class P2pNcclEngine:
self.http_address, self.zmq_address, self.proxy_address,
self.send_type, self.buffer_size_threshold, self.nccl_num_channels)
def _create_connect_new(self, remote_address: typing.Optional[str] = None):
assert remote_address is not None
if remote_address not in self.socks:
sock = self.context.socket(zmq.DEALER)
sock.setsockopt(zmq.SNDHWM, 10000)
sock.setsockopt(zmq.RCVHWM, 5000)
sock.setsockopt(zmq.LINGER, 0)
sock.setsockopt(zmq.TCP_KEEPALIVE, 1)
sock.setsockopt_string(zmq.IDENTITY, f"P-{self.zmq_address}")
sock.connect(f"tcp://{remote_address}")
self.socks[remote_address] = sock
return self.socks[remote_address]
def _create_connect(self, remote_address: typing.Optional[str] = None):
assert remote_address is not None
if remote_address not in self.socks:
......@@ -228,30 +268,36 @@ class P2pNcclEngine:
is_mla: bool) -> list[any]:
tensor_id = self.get_tensor_id(request_id, layer_name)
remote_ip, remote_port = self.parse_request_id(request_id, True)
p_ip, p_port = self.parse_request_id(request_id, False)
pd_pair_id = p_ip + ":" + str(p_port) + "_" + remote_ip + ":" + str(remote_port)
if not self.enable_asymmetric_p2p:
remote_address = remote_ip + ":" + str(remote_port + self.rank)
return [(tensor_id, remote_address, tensor)]
remote_addr = RemoteAddr(pd_pair_id, remote_address, self.rank + self.pp_size * self.tp_size)
# logger.info(f"""+++++xiabo tensor_id:{tensor_id} request_id:{request_id} remote_address:{remote_address}""")
return [(tensor_id, remote_addr, tensor)]
if not is_mla:
logger.error(" P2PNCCL only support mla model symmetric PP/TP!!!!")
remote_pp_rank = self.compute_remote_pp_rank(layer_name)
items: list[Any] = []
up_down = 1
# remote_tp_rank = self.tp_rank * self.multp
for d_tp_rank in range(self.remote_tp_size):
for mul_tp in range(self.multp):
if self.tp_rank + mul_tp * self.tp_size == d_tp_rank:
remote_port_offset = remote_pp_rank * self.remote_tp_size + d_tp_rank
remote_address = remote_ip + ":" + str(remote_port + remote_port_offset)
remote_addr = RemoteAddr(pd_pair_id, remote_address, remote_port_offset + self.pp_size * self.tp_size)
logger.debug(
"📥 [PUT] Wait to send: tensor_id:%s, tensor_shape:%s, "
"(pp=%d, tp=%d) -> remote_address=%s(pp=%d, tp=%d)", tensor_id,
"Wait to send::%s, tensor_shape:%s, "
"(pp=%d, tp=%d) -> remote_address=%s(pp=%d, tp=%d) comm_rank (%d -> %d)", tensor_id,
tensor.shape, self.pp_rank, self.tp_rank, remote_address,
remote_pp_rank, self.rank * mul_tp + self.rank)
remote_pp_rank, self.rank * mul_tp + self.rank, self.rank, remote_port_offset + self.pp_size * self.tp_size)
items.append([tensor_id, remote_address, tensor])
items.append([tensor_id, remote_addr, tensor])
return items
def send_tensor_new(
......@@ -265,7 +311,7 @@ class P2pNcclEngine:
if self.send_type == "PUT":
return all(
self.send_sync(item) for item in self.get_send_queue_items(
self._send_sync_new(item) for item in self.get_send_queue_items(
request_id, layer_name, tensor, is_mla))
if self.send_type == "PUT_ASYNC":
......@@ -481,7 +527,56 @@ class P2pNcclEngine:
self.recv_store[tensor_id] = tensor
self._have_received_tensor_id(tensor_id)
self.recv_store_cv.notify()
elif data["cmd"] == "PUT_NEW":
tensor_id = data["tensor_id"]
if "tensor_split_num" in data:
self.tensor_split_num = data["tensor_split_num"]
try:
with torch.cuda.stream(self.recv_stream):
tensor = torch.empty(data["shape"],
dtype=getattr(
torch, data["dtype"]),
device=self.device)
self.router_socket.send_multipart(
[remote_address, b"0"])
# comm, rank = self.comms[remote_address.decode()]
# self._recv(comm, tensor, rank ^ 1, self.recv_stream)
comm, rank = self.comms[data["pd_pair_id"]]
self._recv(comm, tensor, int(data["comm_rank"]), self.recv_stream)
tensor_size = tensor.element_size() * tensor.numel()
if (self.buffer_size + tensor_size
> self.buffer_size_threshold):
# Store Tensor in memory pool
addr = self.pool.store_tensor(tensor)
tensor = (addr, tensor.dtype, tensor.shape)
else:
self.buffer_size += tensor_size
except torch.cuda.OutOfMemoryError:
self.router_socket.send_multipart(
[remote_address, b"1"])
tensor = None
logger.warning(
"🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
"data:%s", self.zmq_address,
remote_address.decode(), data)
with self.recv_store_cv:
self.recv_store[tensor_id] = tensor
self._have_received_tensor_id(tensor_id)
self.recv_store_cv.notify()
elif data["cmd"] == "comm_init":
unique_id = self.nccl.unique_id_from_bytes(
bytes(data["unique_id"]))
with torch.cuda.device(self.device):
rank = int(data["rank"])
world_size = int(data["world_size"])
with set_p2p_nccl_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(
world_size, unique_id, rank)
self.comms[data["pd_pair_id"]] = (comm, rank)
logger.info(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, data["pd_pair_id"], rank)
elif data["cmd"] == "GET":
tensor_id = data["tensor_id"]
with self.send_store_cv:
......@@ -538,7 +633,11 @@ class P2pNcclEngine:
self.send_stream.wait_event(tbo_evt)
self._send_kv_p2p_sync(tensor_id, kv_layer, slot_mapping, remote_address)
else:
self._send_sync(tensor_id, tensor, remote_address)
if self.multiple_machines:
self._send_sync(tensor_id, tensor, remote_address)
else:
# logger.info(f"""=============xiabo tensor_id:{tensor_id} remote_address:{remote_address}""")
self._send_sync_new(tensor_id, tensor, remote_address)
def wait_for_sent(self):
if self.send_type == "PUT_ASYNC":
......@@ -620,6 +719,48 @@ class P2pNcclEngine:
return True
def _send_sync_new(
self,
tensor_id: str,
tensor: torch.Tensor,
remote_address: typing.Optional[RemoteAddr] = None,
) -> bool:
if remote_address is None:
return False
if remote_address.zmq_address not in self.socks:
# logger.info(f"""=============xiabo remote_address.zmq_address:{remote_address.zmq_address}""")
self._create_connect_new(remote_address.zmq_address)
sock = self.socks[remote_address.zmq_address]
comm, rank = self.comms[remote_address.pd_pair_id]
data = {
"cmd": "PUT_NEW",
"tensor_id": tensor_id,
"shape": tensor.shape,
"dtype": str(tensor.dtype).replace("torch.", ""),
"pd_pair_id": remote_address.pd_pair_id,
"comm_rank": rank
}
logger.info(f"""_send_sync_new:{data}""")
sock.send(msgpack.dumps(data))
response = sock.recv()
if response != b"0":
logger.error(
"🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
"MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
self.zmq_address, remote_address.zmq_address, rank, data, tensor.shape,
tensor.element_size() * tensor.numel() / 1024**3,
response.decode())
return False
self._send(comm, tensor.to(self.device), remote_address.comm_rank, self.send_stream)
if self.send_type == "PUT_ASYNC":
self._have_sent_tensor_id(tensor_id)
return True
def _send_sync(
self,
tensor_id: str,
......@@ -710,6 +851,36 @@ class P2pNcclEngine:
sock.send(msgpack.dumps(data))
time.sleep(3)
def _ping_new(self):
sock = self.context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
logger.debug("ping start, zmq_address:%s", self.zmq_address)
sock.connect(f"tcp://{self.proxy_address}")
if self.rank == 0:
data = {
"type": "P_init" if self.config.is_kv_producer else "D_init",
"http_address": self.http_address,
"zmq_address": self.zmq_address,
"dp_size" : self.dp_size,
"pp_size" : self.pp_size,
"tp_size" : self.tp_size
}
# logger.info(f"""_ping data:{data}""")
sock.send(msgpack.dumps(data))
data = {
"type": "P" if self.config.is_kv_producer else "D",
"http_address": self.http_address,
"dp_rank" : self.dp_rank,
"pp_rank" : self.pp_rank,
"tp_rank" : self.tp_rank,
"zmq_address": self.zmq_address
}
# while True:
# logger.info(f"""_ping data:{data}""")
sock.send(msgpack.dumps(data))
# time.sleep(3)
def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
......@@ -743,11 +914,45 @@ class P2pNcclEngine:
if self._ping_thread is not None:
self._ping_thread.join()
def get_pp_indices_d(self, num_hidden_layers: int, pp_rank: int,
pp_size: int) -> tuple[int, int]:
partition_list_str = envs.VLLM_PP_LAYER_PARTITION_D
if partition_list_str is not None:
try:
partitions = [
int(layer) for layer in partition_list_str.split(",")
]
except ValueError as err:
raise ValueError("Invalid partition string: {}".format(
partition_list_str)) from err
if len(partitions) != pp_size:
raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
if sum(partitions) != num_hidden_layers:
raise ValueError(
f"{sum(partitions)=} does not match {num_hidden_layers=}.")
else:
layers_per_partition = num_hidden_layers // pp_size
partitions = [layers_per_partition for _ in range(pp_size)]
if remaining_layers := num_hidden_layers % pp_size:
for i in range(2, remaining_layers + 2):
partitions[-i] += 1
logger.info(
"Hidden layers were unevenly partitioned: [%s]. "
"This can be manually overridden using the "
"VLLM_PP_LAYER_PARTITION_D environment variable",
",".join(str(p) for p in partitions))
start_layer = sum(partitions[:pp_rank])
end_layer = start_layer + partitions[pp_rank]
return (start_layer, end_layer)
def compute_remote_pp_rank(self, layer_name: str) -> int:
current_layer_idx = extract_layer_index(layer_name)
for d_pp_rank in range(self.remote_pp_size):
start, end = get_pp_indices(self.total_num_hidden_layers, d_pp_rank, self.remote_pp_size)
logger.info(f"""compute_remote_pp_rank : current_layer_idx:{current_layer_idx} start:{start} end:{end}""")
start, end = self.get_pp_indices_d(self.total_num_hidden_layers, d_pp_rank, self.remote_pp_size)
# logger.info(f"""compute_remote_pp_rank : current_layer_idx:{current_layer_idx} start:{start} end:{end}""")
if (current_layer_idx == self.total_num_hidden_layers):
return self.remote_pp_size - 1
if start <= current_layer_idx < end:
......
......@@ -122,7 +122,7 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
if remaining_layers := num_hidden_layers % pp_size:
for i in range(2, remaining_layers + 2):
partitions[-i] += 1
logger.info(
logger.debug(
"Hidden layers were unevenly partitioned: [%s]. "
"This can be manually overridden using the "
"VLLM_PP_LAYER_PARTITION environment variable",
......
......@@ -42,6 +42,7 @@ if TYPE_CHECKING:
VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_PP_LAYER_PARTITION_D: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_OMP_THREADS_BIND: str = ""
VLLM_CPU_NUM_OF_RESERVED_CPU: int = 0
......@@ -490,6 +491,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_PP_LAYER_PARTITION":
lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
# Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION_D":
lambda: os.getenv("VLLM_PP_LAYER_PARTITION_D", None),
# (CPU backend only) CPU key-value cache space.
# default is 4 GiB
"VLLM_CPU_KVCACHE_SPACE":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment