Commit 4612aad6 authored by Your Name's avatar Your Name
Browse files

[P/D][Feat]支持dp并行

parent cd42bf87
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import socket
import threading
import uuid
import aiohttp
import msgpack
import zmq
from quart import Quart, make_response, request
from dataclasses import dataclass, field
from typing import Any
import time
from collections import deque, defaultdict
import asyncio
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
import logging
# Configure root logging once at import time: INFO level, with the timestamp,
# logger name and level name prepended to every message.
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
# Module-level logger used by the router code in this file.
logger = logging.getLogger(__name__)
@dataclass
class Request:
    """Routing record for one in-flight request.

    Stores the HTTP address and data-parallel rank reported for the prefill
    (``p_*``) side and for the decode (``d_*``) side. Address fields default
    to ``""`` and rank fields to ``-1`` until a value has been recorded.
    """

    request_id: str

    # Prefill side.
    p_http_address: str = ""
    p_dp_rank: int = -1

    # Decode side.
    d_http_address: str = ""
    d_dp_rank: int = -1
@dataclass
class Instance:
    """Registration state of one prefill ("P") or decode ("D") instance.

    The parallel sizes and the per-rank tables are filled in incrementally
    as registration messages arrive; ``is_ready`` reports whether everything
    needed has been received.
    """

    ins_type: str = "P"
    http_address: str = ""
    p_unique_id: bytes = b""
    dp_size: int = 0
    pp_size: int = 0
    tp_size: int = 0
    # [dp, pp, tp] : zmq_address
    rank_table: dict[int, dict[int, dict[int, str]]] = field(
        default_factory=lambda: defaultdict(lambda: defaultdict(dict))
    )
    # [dp, pp, tp] : global rank
    comm_rank_table: dict[int, dict[int, dict[int, int]]] = field(
        default_factory=lambda: defaultdict(lambda: defaultdict(dict))
    )

    def count_rank_table_elements(self) -> int:
        """Return the number of ranks currently present in ``rank_table``."""
        return sum(
            len(tp_map)
            for pp_map in self.rank_table.values()
            for tp_map in pp_map.values()
        )

    def is_ready(self) -> bool:
        """Return True once every rank of this instance has registered.

        The instance is ready when ``dp_size * pp_size * tp_size`` is
        non-zero and equals the number of entries in ``rank_table``. A "P"
        instance additionally needs a non-empty ``p_unique_id``.
        """
        world_size = self.dp_size * self.pp_size * self.tp_size
        inited_rank = self.count_rank_table_elements()
        # Explicit comparison so the result is always a bool; the sizes
        # default to 0, which must count as "not ready".
        all_ranks_ready = world_size > 0 and inited_rank == world_size
        if self.ins_type == "P":
            logger.info(
                "[Router] P is_ready? : %s world_size = %s inited_rank = %s",
                self.http_address, world_size, inited_rank)
            return all_ranks_ready and self.p_unique_id != b""
        logger.info(
            "[Router] D is_ready? : %s world_size = %s inited_rank = %s",
            self.http_address, world_size, inited_rank)
        return all_ranks_ready
# Number of HTTP requests handled so far; handle_request uses it modulo the
# number of registered instances to pick a prefill/decode instance.
count = 0
# Registered instances, keyed by their HTTP address.
prefill_instances: dict[str, Instance] = {}
decode_instances: dict[str, Instance] = {}
# Requests reported by only one side (P or D) so far, keyed by request id.
running_requests: dict[str, Request] = {}
# HTTP address -> time.time() of the most recent heartbeat message.
healthy_instances: dict[str, float] = {}
# HTTP addresses of instances that became ready and await processing by
# pd_pair_init.
pending_prefill_ins: list[str] = []
pending_decode_ins: list[str] = []
# HTTP addresses of instances that pd_pair_init has already processed.
ready_prefill_ins: list[str] = []
ready_decode_ins: list[str] = []
# "<prefill http address>_<decode http address>" -> NCCL unique id bytes
# that were dispatched to that pair.
pd_pair : dict[str, bytes] = {}
# NCCL library handle; used to generate unique ids in unique_id_dispatch.
router_nccl = NCCLLibrary()
# Held while touching the instance tables and the pending/ready lists.
instance_cv = threading.Condition()
# Held while touching running_requests.
request_cv = threading.Condition()
# Held while touching healthy_instances.
health_cv = threading.Condition()
# Held while touching request_queue; notified when a request is appended.
request_queue_cv = threading.Condition()
# Requests whose P and D sides are both known, waiting for dp_dispatch.
# NOTE(review): the listener appends Request objects here, although the
# annotation says list[Any].
request_queue: deque[list[Any]] = deque()
def _get_or_create_instance(instances, ins_type, http_address):
    """Return the Instance stored under *http_address*, creating it if absent."""
    instance = instances.get(http_address)
    if instance is None:
        instance = Instance(ins_type=ins_type, http_address=http_address)
        instances[http_address] = instance
    return instance


def _queue_if_ready(instance, pending, pending_name):
    """Queue *instance* for pairing once it reports ready.

    Must be called with ``instance_cv`` held, because it notifies it.
    """
    if instance.is_ready():
        pending.append(instance.http_address)
        logger.info("[Router] %s appended %s", pending_name,
                    instance.http_address)
        instance_cv.notify()


def _on_instance_init(instances, pending, pending_name, ins_type, data):
    """Handle a ``P_init`` / ``D_init`` message (dp/pp/tp sizes)."""
    address = data["http_address"]
    dp_size = int(data["dp_size"])
    pp_size = int(data["pp_size"])
    tp_size = int(data["tp_size"])
    if address not in instances:
        # First message for this address: no rank has registered yet, so
        # the instance cannot be ready and the readiness check is skipped.
        instances[address] = Instance(ins_type=ins_type, http_address=address,
                                      dp_size=dp_size, pp_size=pp_size,
                                      tp_size=tp_size)
        return
    instance = instances[address]
    instance.dp_size = dp_size
    instance.pp_size = pp_size
    instance.tp_size = tp_size
    _queue_if_ready(instance, pending, pending_name)


def _on_rank_report(instances, pending, pending_name, ins_type, data):
    """Handle a ``P_rank`` / ``D_rank`` message (one rank's ZMQ address)."""
    instance = _get_or_create_instance(instances, ins_type,
                                       data["http_address"])
    dp_rank = int(data["dp_rank"])
    pp_rank = int(data["pp_rank"])
    tp_rank = int(data["tp_rank"])
    instance.rank_table[dp_rank][pp_rank][tp_rank] = data["zmq_address"]
    _queue_if_ready(instance, pending, pending_name)
    logger.info("[Router] add %s rank [%s, %s, %s] : %s", ins_type,
                data["dp_rank"], data["pp_rank"], data["tp_rank"],
                data["zmq_address"])


def _on_request_report(data):
    """Handle a ``Req`` message: record one side, enqueue when both are known."""
    request_id = data["request_id"]
    side = data["instance_type"]
    http_address = data["http_address"]
    with request_cv:
        request = running_requests.get(request_id)
        if request is None:
            # First report for this request id: remember the reporting side.
            if side == "P":
                running_requests[request_id] = Request(
                    request_id=request_id, p_http_address=http_address,
                    p_dp_rank=int(data["dp_rank"]))
            elif side == "D":
                running_requests[request_id] = Request(
                    request_id=request_id, d_http_address=http_address,
                    d_dp_rank=int(data["dp_rank"]))
            return
        if side == "P":
            request.p_http_address = http_address
            request.p_dp_rank = int(data["dp_rank"])
        elif side == "D":
            request.d_http_address = http_address
            request.d_dp_rank = int(data["dp_rank"])
        assert request.p_dp_rank >= 0 and request.d_dp_rank >= 0
        with request_queue_cv:
            request_queue.append(request)
            request_queue_cv.notify()


def _listen_for_register(poller, router_socket):
    """Receive registration and request messages forever and update state.

    Each message is a msgpack-encoded dict whose ``"type"`` selects the
    action: ``P_init``/``D_init`` (parallel sizes), ``P_rank``/``D_rank``
    (one rank's ZMQ address), ``P_unique_id``, ``heartbeat`` and ``Req``.
    Any other type is logged as unexpected.

    Args:
        poller: object whose ``poll()`` result reports readable sockets.
        router_socket: socket the multipart messages are received from.
    """
    while True:
        socks = dict(poller.poll())
        if router_socket not in socks:
            continue
        remote_address, message = router_socket.recv_multipart()
        data = msgpack.loads(message)
        msg_type = data["type"]
        if msg_type == "P_init":
            with instance_cv:
                _on_instance_init(prefill_instances, pending_prefill_ins,
                                  "pending_prefill_ins", "P", data)
        elif msg_type == "D_init":
            with instance_cv:
                _on_instance_init(decode_instances, pending_decode_ins,
                                  "pending_decode_ins", "D", data)
        elif msg_type == "P_rank":
            with instance_cv:
                _on_rank_report(prefill_instances, pending_prefill_ins,
                                "pending_prefill_ins", "P", data)
        elif msg_type == "D_rank":
            with instance_cv:
                _on_rank_report(decode_instances, pending_decode_ins,
                                "pending_decode_ins", "D", data)
        elif msg_type == "P_unique_id":
            with instance_cv:
                p_instance = _get_or_create_instance(
                    prefill_instances, "P", data["http_address"])
                assert isinstance(data["unique_id"], bytes)
                p_instance.p_unique_id = data["unique_id"]
                _queue_if_ready(p_instance, pending_prefill_ins,
                                "pending_prefill_ins")
                logger.info("[Router] add P_unique_id %s for %s",
                            str(p_instance.p_unique_id),
                            p_instance.http_address)
        elif msg_type == "heartbeat":
            with health_cv:
                healthy_instances[data["http_address"]] = time.time()
        elif msg_type == "Req":
            _on_request_report(data)
        else:
            # The original print() call passed %-style arguments that were
            # never interpolated; route it through the logger instead.
            logger.warning("Unexpected, Received message from %s, data: %s",
                           remote_address, data)
# ZMQ context shared by the router; assigned in start_service_discovery.
zmq_context = None
# ZMQ address -> DEALER socket already connected to that address.
sock_cache : dict[str, Any] = {}
# The three caches below are keyed by "<pd_pair_id>_<decode dp rank>" and are
# filled by dispatch_to_P on first use of that key.
# prefill tp rank -> ZMQ addresses of the decode ranks it sends to.
tp_mapping_of_pd_pair : dict[str, dict[int, list[str]]] = {}
# prefill tp rank -> communicator ranks of the decode ranks it sends to.
tp_comm_mapping_of_pd_pair : dict[str, dict[int, list[int]]] = {}
# prefill tp ranks that have at least one decode destination.
active_p_tp_rank_of_pd_pair : dict[str, set[int]] = {}
def start_service_discovery(hostname, port):
    """Bind the registration ROUTER socket and start its listener thread.

    Args:
        hostname: host to bind to; a falsy value falls back to
            ``socket.gethostname()``.
        port: TCP port to bind to.

    Returns:
        The started daemon thread running ``_listen_for_register``.

    Raises:
        ValueError: if *port* is 0.
    """
    global zmq_context

    bind_host = hostname or socket.gethostname()
    if port == 0:
        raise ValueError("Port cannot be 0")

    zmq_context = zmq.Context()
    registration_socket = zmq_context.socket(zmq.ROUTER)
    registration_socket.bind(f"tcp://{bind_host}:{port}")

    registration_poller = zmq.Poller()
    registration_poller.register(registration_socket, zmq.POLLIN)

    listener = threading.Thread(
        target=_listen_for_register,
        args=[registration_poller, registration_socket],
        daemon=True,
    )
    listener.start()
    return listener
def _dealer_socket_for(address):
    """Return the cached DEALER socket for *address*, connecting on first use."""
    sock = sock_cache.get(address)
    if sock is None:
        sock = zmq_context.socket(zmq.DEALER)
        sock.setsockopt_string(zmq.IDENTITY, "router")
        sock.connect(f"tcp://{address}")
        sock_cache[address] = sock
    return sock


def dispatch_to_P(request : Request):
    """Tell the prefill ranks of *request* where to send its KV data.

    Every decode tp rank ``d`` of the request's decode dp rank is assigned
    to prefill tp rank ``d % p_tp_size``. Prefill ranks with at least one
    destination receive a ``req_to_transfer`` command listing the decode
    ZMQ addresses and communicator ranks; all remaining prefill tp ranks
    receive ``req_not_to_transfer``. The per-(pair, decode dp rank) mapping
    is computed once and cached in the module-level mapping tables.

    Args:
        request: request whose P and D addresses and dp ranks are all set.

    Raises:
        KeyError: if an instance or rank referenced by *request* is unknown.
        AssertionError: if the decode instance uses pipeline parallelism.
    """
    p_ins = prefill_instances[request.p_http_address]
    d_ins = decode_instances[request.d_http_address]
    pd_pair_id = p_ins.http_address + "_" + d_ins.http_address
    p_dp_rank = request.p_dp_rank
    d_dp_rank = request.d_dp_rank
    tp_dst_id = pd_pair_id + "_" + str(d_dp_rank)
    assert d_ins.pp_size == 1
    d_pp_rank = 0

    if tp_dst_id not in active_p_tp_rank_of_pd_pair:
        active_ranks = set()
        dst_addresses : dict[int, list[str]] = defaultdict(list)
        dst_comm_ranks : dict[int, list[int]] = defaultdict(list)
        for d_tp_rank in range(d_ins.tp_size):
            p_tp_rank = d_tp_rank % p_ins.tp_size
            active_ranks.add(p_tp_rank)
            dst_addresses[p_tp_rank].append(
                d_ins.rank_table[d_dp_rank][d_pp_rank][d_tp_rank])
            dst_comm_ranks[p_tp_rank].append(
                d_ins.comm_rank_table[d_dp_rank][d_pp_rank][d_tp_rank])
        tp_mapping_of_pd_pair[tp_dst_id] = dst_addresses
        tp_comm_mapping_of_pd_pair[tp_dst_id] = dst_comm_ranks
        # Stored last: this table is the "already built" marker checked
        # above, so a failure while building leaves nothing half-cached.
        active_p_tp_rank_of_pd_pair[tp_dst_id] = active_ranks

    p_active_tp_rank = active_p_tp_rank_of_pd_pair[tp_dst_id]
    p_tp_rank_to_dst = tp_mapping_of_pd_pair[tp_dst_id]
    p_tp_rank_to_dst_comm = tp_comm_mapping_of_pd_pair[tp_dst_id]

    for p_pp_rank in range(p_ins.pp_size):
        for p_tp_rank in p_active_tp_rank:
            address = p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank]
            data = {
                "cmd": "req_to_transfer",
                "request_id": request.request_id,
                "dst_num": len(p_tp_rank_to_dst[p_tp_rank]),
                "pd_pair_id": pd_pair_id,
                "remote_address": p_tp_rank_to_dst[p_tp_rank],
                "remote_rank": p_tp_rank_to_dst_comm[p_tp_rank],
            }
            _dealer_socket_for(address).send(msgpack.dumps(data))
            logger.info("[Router] dispatch Request %s [%s, %s, %s] -> [%s, %s]",
                        request.request_id, p_dp_rank, p_pp_rank, p_tp_rank,
                        d_dp_rank, d_pp_rank)

    # The payload for idle prefill ranks does not depend on the rank, so it
    # is serialised once and only if there is an idle rank.
    skip_payload = None
    for p_tp_rank in range(p_ins.tp_size):
        if p_tp_rank in p_active_tp_rank:
            continue
        for p_pp_rank in range(p_ins.pp_size):
            address = p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank]
            if skip_payload is None:
                skip_payload = msgpack.dumps({
                    "cmd": "req_not_to_transfer",
                    "request_id": request.request_id,
                })
            _dealer_socket_for(address).send(skip_payload)
def dp_dispatch():
    """Forever take queued requests and dispatch them to their prefill ranks.

    Blocks on ``request_queue_cv`` while the queue is empty. Requests are
    taken from the left end, i.e. in arrival order, so that older requests
    are not starved by newer ones. A failure while dispatching one request
    is logged and does not stop the loop.
    """
    while True:
        with request_queue_cv:
            while not request_queue:
                request_queue_cv.wait()
            # The listener append()s on the right; popleft() keeps FIFO order.
            request = request_queue.popleft()
        # Dispatch outside the lock so the listener can keep enqueueing.
        try:
            dispatch_to_P(request)
        except Exception:
            # Top-level boundary of the dispatcher thread: an uncaught error
            # here would end the only thread that drains the queue.
            logger.exception("[Router] failed to dispatch request %s", request)
def start_dp_dispatch():
    """Start ``dp_dispatch`` in a daemon thread and return that thread."""
    dispatcher = threading.Thread(target=dp_dispatch, daemon=True)
    dispatcher.start()
    return dispatcher
# Client timeout used for every forwarded HTTP request: a total of
# 6 * 60 * 60 seconds (six hours).
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
# Quart application object the HTTP route handlers are registered on.
app = Quart(__name__)
def random_uuid() -> str:
    """Return a freshly generated UUID4 as its 32-character hex string."""
    generated = uuid.uuid4()
    return generated.hex
def unique_id_dispatch(prefill_instance : Instance,
                       decode_instance : Instance) :
    """Send one NCCL unique id to every rank of a prefill/decode pair.

    Does nothing if the pair (identified by
    ``"<prefill http address>_<decode http address>"``) was initialised
    before. Otherwise a new unique id is generated and a ``comm_init``
    command is sent to each rank. Communicator ranks are assigned
    consecutively: all prefill ranks first (dp, then pp, then tp order),
    followed by all decode ranks. The assigned rank is also stored in the
    instance's ``comm_rank_table``.

    NOTE(review): ``comm_rank_table`` lives on the instance, not on the
    pair, so pairing one instance with several peers overwrites it; the
    decode offsets differ when the prefill peers have different rank
    counts. Confirm whether heterogeneous peers are expected.

    Args:
        prefill_instance: fully registered prefill instance.
        decode_instance: fully registered decode instance.
    """
    pd_pair_id = prefill_instance.http_address + "_" + decode_instance.http_address
    if pd_pair_id in pd_pair:
        logger.info("[Router] pd pair %s already exist", pd_pair_id)
        return
    logger.info("[Router] initing pd pair %s", pd_pair_id)

    unique_id = bytes(router_nccl.ncclGetUniqueId().internal)
    p_rank_num = (prefill_instance.dp_size * prefill_instance.pp_size
                  * prefill_instance.tp_size)
    d_rank_num = (decode_instance.dp_size * decode_instance.pp_size
                  * decode_instance.tp_size)
    world_size = p_rank_num + d_rank_num

    rank = 0
    # Prefill ranks are numbered before decode ranks; the order matters.
    for label, instance in (("P", prefill_instance), ("D", decode_instance)):
        for dp_rank in range(instance.dp_size):
            for pp_rank in range(instance.pp_size):
                for tp_rank in range(instance.tp_size):
                    address = instance.rank_table[dp_rank][pp_rank][tp_rank]
                    sock = sock_cache.get(address)
                    if sock is None:
                        sock = zmq_context.socket(zmq.DEALER)
                        sock.setsockopt_string(zmq.IDENTITY, "router")
                        sock.connect(f"tcp://{address}")
                        sock_cache[address] = sock
                    data = {
                        "cmd": "comm_init",
                        "pd_pair_id": pd_pair_id,
                        "unique_id" : unique_id,
                        "world_size": world_size,
                        "rank": rank
                    }
                    sock.send(msgpack.dumps(data))
                    instance.comm_rank_table[dp_rank][pp_rank][tp_rank] = rank
                    rank += 1
                    logger.info(
                        "[Router] dispatch unique_id of pd pair %s to [%s] "
                        "[%s, %s, %s]",
                        pd_pair_id, label, dp_rank, pp_rank, tp_rank)
    # Recorded only after every rank has been contacted.
    pd_pair[pd_pair_id] = unique_id
def pd_pair_init():
    """Forever pair newly ready instances with the already ready peers.

    Waits on ``instance_cv`` until a prefill or decode address is pending.
    Each pending prefill address is paired (via ``unique_id_dispatch``) with
    every ready decode address and then moved to the ready list; pending
    decode addresses are afterwards handled the same way against the ready
    prefill addresses.
    """
    while True:
        with instance_cv:
            while not pending_prefill_ins and not pending_decode_ins:
                logger.info(f"""[Router] pd_pair_init: waiting for instance_cv""")
                instance_cv.wait()
                logger.info(f"""[Router] pd_pair_init: instance_cv finished waiting""")

            while pending_prefill_ins:
                prefill_addr = pending_prefill_ins[0]
                logger.info(f"""[Router] pd_pair_init: processing {prefill_addr} from pending_prefill_ins""")
                for decode_addr in ready_decode_ins:
                    unique_id_dispatch(prefill_instances[prefill_addr],
                                       decode_instances[decode_addr])
                ready_prefill_ins.append(prefill_addr)
                # The address is the head of the list, so drop the head.
                del pending_prefill_ins[0]

            while pending_decode_ins:
                decode_addr = pending_decode_ins[0]
                logger.info(f"""[Router] pd_pair_init: processing {decode_addr} from pending_decode_ins""")
                for prefill_addr in ready_prefill_ins:
                    unique_id_dispatch(prefill_instances[prefill_addr],
                                       decode_instances[decode_addr])
                ready_decode_ins.append(decode_addr)
                del pending_decode_ins[0]
def start_pd_pair_init():
    """Start ``pd_pair_init`` in a daemon thread and return that thread."""
    pairing_thread = threading.Thread(target=pd_pair_init, daemon=True)
    pairing_thread.start()
    return pairing_thread
async def forward_request(url, data, request_id):
    """POST *data* as JSON to *url* and yield the response body in chunks.

    The request carries an ``Authorization`` bearer header built from the
    ``OPENAI_API_KEY`` environment variable and an ``X-Request-Id`` header.
    Nothing is yielded when the upstream status is not 200; that case is
    logged so the failure is visible.

    Args:
        url: target URL.
        data: JSON-serialisable request body.
        request_id: value sent in the ``X-Request-Id`` header.

    Yields:
        Chunks of the response body, at most 1024 bytes each.
    """
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id,
    }
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        async with session.post(url=url, json=data, headers=headers) as response:
            if response.status != 200:
                logger.warning(
                    "[Router] %s answered HTTP %s for request %s",
                    url, response.status, request_id)
                return
            # The original guarded this loop with `if True:` and kept an
            # unreachable read-whole-body branch; only streaming ever ran.
            async for chunk_bytes in response.content.iter_chunked(1024):
                yield chunk_bytes
# Strong references to in-flight prefill tasks. The event loop keeps only
# weak references, so an otherwise unreferenced task may be collected
# before it finishes.
_background_prefill_tasks: set = set()


@app.route("/v1/completions", methods=["POST"])
async def handle_request():
    """Proxy one completion request through a prefill and a decode instance.

    A prefill and a decode instance are chosen from all registered instances
    using the request counter modulo the number of instances. The request is
    sent to the prefill instance with ``max_tokens`` forced to 1 (as a
    background task whose output is discarded) and, unchanged, to the decode
    instance, whose streamed body becomes the response.

    Returns:
        The streaming decode response; a 503 JSON error when no prefill or
        no decode instance is registered; a 500 JSON error on any other
        failure.
    """
    global count
    json_header = {"Content-Type": "application/json"}
    try:
        original_request_data = await request.get_json()
        prefill_request = original_request_data.copy()
        # change max_tokens = 1 to let it only do prefill
        prefill_request["max_tokens"] = 1

        with instance_cv:
            prefill_list = list(prefill_instances.items())
            decode_list = list(decode_instances.items())
        if not prefill_list or not decode_list:
            # Without this check the modulo below divides by zero.
            logger.warning(
                "[Router] no instance available: %d prefill, %d decode",
                len(prefill_list), len(decode_list))
            return ('{"error": "no prefill or decode instance registered"}',
                    503, json_header)
        prefill_addr, prefill_instance = prefill_list[count % len(prefill_list)]
        decode_addr, decode_instance = decode_list[count % len(decode_list)]

        # TODO: Init nccl comm : dispatch unique_id among PD pair ranks
        pair_id = prefill_instance.http_address + "_" + decode_instance.http_address
        if pair_id not in pd_pair:
            unique_id_dispatch(prefill_instance, decode_instance)

        logger.info("handle_request count: %s, [HTTP:%s, 👉 HTTP:%s]",
                    count, prefill_addr, decode_addr)
        count += 1

        request_id = random_uuid()

        async def run_prefill():
            # Drain the prefill response; only its side effect is wanted.
            async for _ in forward_request(
                f"http://{prefill_addr}/v1/completions", prefill_request,
                request_id
            ):
                pass

        def _prefill_done(task):
            _background_prefill_tasks.discard(task)
            if not task.cancelled() and task.exception() is not None:
                logger.error("[Router] prefill for request %s failed: %r",
                             request_id, task.exception())

        prefill_task = asyncio.create_task(run_prefill())
        _background_prefill_tasks.add(prefill_task)
        prefill_task.add_done_callback(_prefill_done)

        # return decode
        generator = forward_request(
            f"http://{decode_addr}/v1/completions", original_request_data,
            request_id
        )
        response = await make_response(generator)
        response.timeout = None
        return response
    except Exception:
        # Boundary of the HTTP handler: log with traceback and answer with
        # an explicit error instead of returning None.
        logger.exception("Error occurred in disagg prefill proxy server")
        return ('{"error": "internal proxy error"}', 500, json_header)
if __name__ == "__main__":
    # Start the background workers in the same order as before: service
    # discovery listener, request dispatcher, then P/D pair initialiser.
    background_threads = [
        start_service_discovery("0.0.0.0", 30001),
        start_dp_dispatch(),
        start_pd_pair_init(),
    ]
    # Serve HTTP; this call blocks until the application stops.
    app.run(host="0.0.0.0", port=10001)
    for worker in background_threads:
        worker.join()
...@@ -54,6 +54,7 @@ class KVConnectorFactory: ...@@ -54,6 +54,7 @@ class KVConnectorFactory:
cls, cls,
config: "VllmConfig", config: "VllmConfig",
role: KVConnectorRole, role: KVConnectorRole,
dp_rank: int = -1,
) -> KVConnectorBase_V1: ) -> KVConnectorBase_V1:
if not envs.VLLM_USE_V1: if not envs.VLLM_USE_V1:
raise ValueError("Attempting to initialize a V1 Connector, " raise ValueError("Attempting to initialize a V1 Connector, "
...@@ -81,7 +82,7 @@ class KVConnectorFactory: ...@@ -81,7 +82,7 @@ class KVConnectorFactory:
# - Co-locate with worker process # - Co-locate with worker process
# - Should only be used inside the forward context & attention layer # - Should only be used inside the forward context & attention layer
# We build separately to enforce strict separation # We build separately to enforce strict separation
return connector_cls(config, role) return connector_cls(config, role, dp_rank)
# Register various connectors here. # Register various connectors here.
......
...@@ -6,19 +6,23 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -6,19 +6,23 @@ from typing import TYPE_CHECKING, Any, Optional
import regex as re import regex as re
import torch import torch
import os import os
from vllm import envs from vllm import envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import ( from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
P2pNcclEngine) P2pNcclEngine, RemoteAddr)
from vllm.distributed.parallel_state import get_world_group from vllm.distributed.parallel_state import get_world_group, get_dp_group, get_pp_group, get_tp_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.distributed.parallel_state import get_pp_group, get_tp_group, get_dp_group from vllm.distributed.parallel_state import get_pp_group, get_tp_group
import zmq
import msgpack
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
...@@ -78,7 +82,7 @@ class P2pNcclConnectorMetadata(KVConnectorMetadata): ...@@ -78,7 +82,7 @@ class P2pNcclConnectorMetadata(KVConnectorMetadata):
class P2pNcclConnector(KVConnectorBase_V1): class P2pNcclConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole, dp_rank : int = -1):
super().__init__(vllm_config=vllm_config, role=role) super().__init__(vllm_config=vllm_config, role=role)
self._block_size = vllm_config.cache_config.block_size self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Any] = {} self._requests_need_load: dict[str, Any] = {}
...@@ -102,12 +106,17 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -102,12 +106,17 @@ class P2pNcclConnector(KVConnectorBase_V1):
if role == KVConnectorRole.WORKER else 0 if role == KVConnectorRole.WORKER else 0
self._tp_size = get_tp_group().world_size \ self._tp_size = get_tp_group().world_size \
if role == KVConnectorRole.WORKER else 0 if role == KVConnectorRole.WORKER else 0
self.p2p_nccl_engine = P2pNcclEngine( self.p2p_nccl_engine = P2pNcclEngine(
local_rank=self._local_rank, local_rank=self._local_rank,
port_offset=self._rank,
config=self.config, config=self.config,
model_config=vllm_config.model_config, hostname="",
port_offset=self._rank,
dp_rank=self._dp_rank,
pp_rank=self._pp_rank,
tp_rank=self._tp_rank,
dp_size=self._dp_size,
pp_size=self._pp_size,
tp_size=self._tp_size
) if role == KVConnectorRole.WORKER else None ) if role == KVConnectorRole.WORKER else None
self.parallel_config = vllm_config.parallel_config self.parallel_config = vllm_config.parallel_config
...@@ -117,19 +126,9 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -117,19 +126,9 @@ class P2pNcclConnector(KVConnectorBase_V1):
self.pp_size = self.parallel_config.pipeline_parallel_size self.pp_size = self.parallel_config.pipeline_parallel_size
self.tp_size = self.parallel_config.tensor_parallel_size self.tp_size = self.parallel_config.tensor_parallel_size
self.num_card = self.pp_size * self.tp_size self.num_card = self.pp_size * self.tp_size
self.multiple_machines = 1 if self.num_card > 8 else 0
self.remote_tp_size = self.config.get_from_extra_config( if self.is_producer and self.multiple_machines == 1:
"remote_tp_size", self.tp_size)
self.remote_pp_size = self.config.get_from_extra_config(
"remote_pp_size", self.pp_size)
self.enable_asymmetric_p2p = self.config.get_from_extra_config(
"enable_asymmetric_p2p", False)
self.remote_num_card = self.remote_tp_size * self.remote_pp_size
self.multiple_machines_d = 1 if self.remote_num_card > 8 else 0
self.multiple_machines_p = 1 if self.num_card > 8 else 0
if self.is_producer and self.multiple_machines_p == 1:
self.ip_map = {} self.ip_map = {}
self.duplicate_keys = [] self.duplicate_keys = []
config_file = os.getenv('IP_CONFIG_FILE') config_file = os.getenv('IP_CONFIG_FILE')
...@@ -152,10 +151,38 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -152,10 +151,38 @@ class P2pNcclConnector(KVConnectorBase_V1):
print(f"Error: Exception occurred while reading configuration file - {e}") print(f"Error: Exception occurred while reading configuration file - {e}")
if role == KVConnectorRole.SCHEDULER :
self.dp_rank = dp_rank
proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
proxy_port = self.config.get_from_extra_config("proxy_port", "")
if proxy_ip == "" or proxy_port == "":
self.proxy_address = ""
else:
self.proxy_address = proxy_ip + ":" + proxy_port
self.http_address = (
f"{self.config.kv_connector_extra_config['instance_ip']}:"
f"{self.config.kv_connector_extra_config['http_port']}")
self.context = zmq.Context()
req_sock = self.context.socket(zmq.DEALER)
req_sock.setsockopt_string(zmq.IDENTITY, f"{self.http_address}_rank{self.dp_rank}")
req_sock.connect(f"tcp://{self.proxy_address}")
self.req_sock = req_sock
def get_ip_value(self, key): def get_ip_value(self, key):
return self.ip_map.get(key) return self.ip_map.get(key)
def register_req(self, request_id: str) :
data = {
"type": "Req",
"instance_type": "P" if self.config.is_kv_producer else "D",
"http_address": self.http_address,
"request_id": request_id,
"dp_rank" : self.dp_rank
}
self.req_sock.send(msgpack.dumps(data))
# ============================== # ==============================
# Worker-side methods # Worker-side methods
# ============================== # ==============================
...@@ -304,7 +331,13 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -304,7 +331,13 @@ class P2pNcclConnector(KVConnectorBase_V1):
2, num_pages * page_size, -1) 2, num_pages * page_size, -1)
inject_start_index = 0 inject_start_index = 0
for num in range(self.p2p_nccl_engine.tensor_split_num): req_layer = f"{request.request_id}#{layer_name}"
with self.p2p_nccl_engine.recv_store_cv:
while req_layer not in self.p2p_nccl_engine.recv_split_nums:
self.p2p_nccl_engine.recv_store_cv.wait()
split_num = self.p2p_nccl_engine.recv_split_nums.get(req_layer)
for num in range(split_num):
kv_cache = self.p2p_nccl_engine.recv_tensor( kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name + "#" + str(num)) request.request_id + "#" + layer_name + "#" + str(num))
...@@ -332,6 +365,7 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -332,6 +365,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
# inject_kv_into_layer(kv_cache_layer, kv_cache, # inject_kv_into_layer(kv_cache_layer, kv_cache,
# request.slot_mapping, request.request_id) # request.slot_mapping, request.request_id)
tensor_id = request.request_id + "#" + layer_name + "#" + str(num) tensor_id = request.request_id + "#" + layer_name + "#" + str(num)
if tensor_id in self.p2p_nccl_engine.recv_store: if tensor_id in self.p2p_nccl_engine.recv_store:
tensor = self.p2p_nccl_engine.recv_store.pop(tensor_id, None) tensor = self.p2p_nccl_engine.recv_store.pop(tensor_id, None)
self.p2p_nccl_engine.send_request_id_to_tensor_ids.pop( self.p2p_nccl_engine.send_request_id_to_tensor_ids.pop(
...@@ -375,8 +409,6 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -375,8 +409,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert self.p2p_nccl_engine is not None assert self.p2p_nccl_engine is not None
is_mla = isinstance(attn_metadata, MLACommonMetadata)
def extract_kv_from_layer( def extract_kv_from_layer(
layer: torch.Tensor, layer: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
...@@ -400,8 +432,6 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -400,8 +432,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
if envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC: if envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC:
for request in connector_metadata.requests: for request in connector_metadata.requests:
request_id = request.request_id request_id = request.request_id
ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank)
slot_mapping = request.slot_mapping slot_mapping = request.slot_mapping
if request.slot_mapping_device is None: if request.slot_mapping_device is None:
request.slot_mapping_device = \ request.slot_mapping_device = \
...@@ -409,91 +439,46 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -409,91 +439,46 @@ class P2pNcclConnector(KVConnectorBase_V1):
slot_mapping = request.slot_mapping_device slot_mapping = request.slot_mapping_device
tbo_evt = torch.cuda.Event(enable_timing=False) tbo_evt = torch.cuda.Event(enable_timing=False)
tbo_evt.record() tbo_evt.record()
pp_rank = (self.parallel_config.rank //
self.parallel_config.tensor_parallel_size) % \ pending = False
self.parallel_config.pipeline_parallel_size with self.p2p_nccl_engine.req_status_cv:
if (self.pp_size == 1): if request_id not in self.p2p_nccl_engine.req_status:
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name, pending = True
(kv_layer, slot_mapping), remote_address, tbo_evt) if pending:
elif (self.pp_size == 2): self.p2p_nccl_engine.pending_tensor(request_id, layer_name,
if (pp_rank == 0): (kv_layer, slot_mapping), tbo_evt)
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name, logger.info("[%d] pending for request: %s layer: %s", self._rank, request_id, layer_name)
(kv_layer, slot_mapping), remote_address, tbo_evt) else :
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name, req_data = self.p2p_nccl_engine.req_status[request_id]
(kv_layer, slot_mapping), ip + ":" + str(port + self._rank + 4), tbo_evt) assert(req_data.dst_num == len(req_data.zmq_address_and_comm_rank))
else: for i in range(req_data.dst_num):
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name, remote_addr = RemoteAddr(req_data.pd_pair_id, *(req_data.zmq_address_and_comm_rank[i]))
(kv_layer, slot_mapping), remote_address, tbo_evt)
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
(kv_layer, slot_mapping), ip + ":" + str(port + self._rank - 4), tbo_evt)
elif (self.pp_size == 8):
for i in range(8):
self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
(kv_layer, slot_mapping), ip + ":" + str(port + i), tbo_evt) (kv_layer, slot_mapping), remote_addr, tbo_evt)
else:
print("Error: only suppprt pp1 pp2 pp8!!!!!!") # self.p2p_nccl_engine.p2p_async_send_tensor(request_id + "#" + layer_name,
# (kv_layer, slot_mapping), remote_address, tbo_evt)
else: else:
for request in connector_metadata.requests: for request in connector_metadata.requests:
request_id = request.request_id request_id = request.request_id
ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank)
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping) kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size pending = False
) % self.parallel_config.pipeline_parallel_size with self.p2p_nccl_engine.req_status_cv:
if (self.multiple_machines_p and self.multiple_machines_d): if request_id not in self.p2p_nccl_engine.req_status:
ip_second = self.get_ip_value(ip) pending = True
if (self.pp_size == 1): if pending:
if self._rank < 8: self.p2p_nccl_engine.pending_tensor(request_id, layer_name,
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, kv_cache)
kv_cache, remote_address) logger.info("[%d] pending for request: %s layer: %s", self._rank, request_id, layer_name)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, else :
kv_cache, str(ip_second) + ":" + str(port + self._rank + 8)) req_data = self.p2p_nccl_engine.req_status[request_id]
elif (self.pp_size == 2): assert(req_data.dst_num == len(req_data.zmq_address_and_comm_rank))
if (pp_rank == 0): for i in range(req_data.dst_num):
remote_addr = RemoteAddr(req_data.pd_pair_id, *(req_data.zmq_address_and_comm_rank[i]))
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address) kv_cache, remote_addr)
else:
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, str(ip_second) + ":" + str(port + self._rank))
else:
logger.error("Error: multiple machines only suppprt pp1tp16 and pp2tp8!!!!!!")
elif (self.multiple_machines_p and not self.multiple_machines_d):
if (self.pp_size == 2):
remote_address = ip + ":" + str(port + self._tp_rank)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
else:
logger.error("Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!")
elif (not self.multiple_machines_p and not self.multiple_machines_d):
self.p2p_nccl_engine.send_tensor_new(request_id, layer_name, kv_cache,
is_mla)
# if (self.pp_size == 1):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# elif (self.pp_size == 2):
# if (pp_rank == 0):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank + 4))
# else:
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank - 4))
# elif (self.pp_size == 8):
# for i in range(8):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + i))
# elif (self.enable_asymmetric_p2p):
# self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address)
# else:
# logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!")
else:
logger.error("Error: not support!!!!!!")
def wait_for_save(self): def wait_for_save(self):
pass pass
# if self.is_producer: # if self.is_producer:
...@@ -612,9 +597,7 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -612,9 +597,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
num_scheduled_tokens = ( num_scheduled_tokens = (
scheduler_output.num_scheduled_tokens)[req_id] scheduler_output.num_scheduled_tokens)[req_id]
num_tokens = (num_scheduled_tokens + num_computed_tokens) num_tokens = (num_scheduled_tokens + num_computed_tokens)
# assert req_id in self.chunked_prefill assert req_id in self.chunked_prefill
if req_id not in self.chunked_prefill:
continue
block_ids = new_block_ids[0] block_ids = new_block_ids[0]
if not resumed_from_preemption: if not resumed_from_preemption:
block_ids = (self.chunked_prefill[req_id][0] + block_ids) block_ids = (self.chunked_prefill[req_id][0] + block_ids)
......
...@@ -6,14 +6,14 @@ import os ...@@ -6,14 +6,14 @@ import os
import threading import threading
import time import time
import typing import typing
from collections import deque from collections import deque, defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
from dataclasses import dataclass, field
import msgpack import msgpack
import torch import torch
import zmq import zmq
import regex
from vllm.config import KVTransferConfig from vllm.config import KVTransferConfig
from vllm.distributed.device_communicators.pynccl_wrapper import ( from vllm.distributed.device_communicators.pynccl_wrapper import (
...@@ -24,11 +24,6 @@ from vllm.utils import current_stream, get_ip ...@@ -24,11 +24,6 @@ from vllm.utils import current_stream, get_ip
from vllm import envs from vllm import envs
from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from dataclasses import dataclass
from vllm.model_executor.models.utils import extract_layer_index
from vllm.distributed.utils import get_pp_indices
from vllm.config import ModelConfig
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
...@@ -36,11 +31,6 @@ logger = logging.getLogger(__name__) ...@@ -36,11 +31,6 @@ logger = logging.getLogger(__name__)
DEFAULT_MEM_POOL_SIZE_GB = 32 DEFAULT_MEM_POOL_SIZE_GB = 32
# @dataclass
# class SendQueueItem:
# tensor_id: str
# remote_address: str
# tensor: torch.Tensor
@contextmanager @contextmanager
def set_p2p_nccl_context(num_channels: str): def set_p2p_nccl_context(num_channels: str):
...@@ -71,44 +61,50 @@ def set_p2p_nccl_context(num_channels: str): ...@@ -71,44 +61,50 @@ def set_p2p_nccl_context(num_channels: str):
else: else:
os.environ.pop(var, None) os.environ.pop(var, None)
@dataclass
class ReqKVDest:
    """Delivery plan for the KV cache of a single request.

    An instance is stored per request id in ``P2pNcclEngine.req_status`` when
    a ``req_to_transfer`` / ``req_not_to_transfer`` message is received.
    """

    # Number of destination ranks; the "req_not_to_transfer" handler stores 0
    # here, which makes ``unpending_tensor`` drop the queued tensors.
    dst_num: int = 0
    # Key into ``P2pNcclEngine.comms`` selecting the NCCL communicator to use.
    pd_pair_id: str = ""
    # One ``(zmq_address, comm_rank)`` pair per destination, built by zipping
    # the "remote_address" and "remote_rank" lists of the incoming message.
    zmq_address_and_comm_rank: list[tuple[str, int]] = field(
        default_factory=list)
@dataclass
class RemoteAddr:
    """Address of one remote rank that a tensor is sent to."""

    # Key into ``P2pNcclEngine.comms`` (the communicator shared with the peer).
    pd_pair_id: str = ""
    # "host:port" of the peer's ZMQ ROUTER socket; also the key of the cached
    # DEALER socket in ``P2pNcclEngine.socks``.
    zmq_address: str = ""
    # Rank of the peer inside the communicator; passed as the destination
    # when the tensor is sent.
    comm_rank: int = 0
class P2pNcclEngine: class P2pNcclEngine:
def __init__(self, def __init__(self,
local_rank: int, local_rank: int,
port_offset: int,
config: KVTransferConfig, config: KVTransferConfig,
model_config: ModelConfig, hostname: str = "",
port_offset: int = 0,
dp_rank: int = 0,
pp_rank: int = 0,
tp_rank: int = 0,
dp_size: int = 0,
pp_size: int = 0,
tp_size: int = 0,
library_path: Optional[str] = None) -> None: library_path: Optional[str] = None) -> None:
self.config = config self.config = config
self.model_config = model_config
self.rank = port_offset self.rank = port_offset
self.local_rank = local_rank self.local_rank = local_rank
self.dp_rank = dp_rank
self.pp_rank = pp_rank
self.tp_rank = tp_rank
self.dp_size = dp_size
self.pp_size = pp_size
self.tp_size = tp_size
self.device = torch.device(f"cuda:{self.local_rank}") self.device = torch.device(f"cuda:{self.local_rank}")
self.nccl = NCCLLibrary(library_path) self.nccl = NCCLLibrary(library_path)
self.total_num_hidden_layers = getattr(self.model_config.hf_text_config, if not hostname:
"num_hidden_layers", 0) hostname = get_ip()
self.pp_rank = get_pp_group().rank_in_group
self.tp_rank = get_tp_group().rank_in_group
self.pp_size = get_pp_group().world_size
self.tp_size = get_tp_group().world_size
if config.is_kv_producer:
self.remote_tp_size = self.config.get_from_extra_config(
"remote_tp_size", 1)
self.remote_pp_size = self.config.get_from_extra_config(
"remote_pp_size", 1)
self.enable_asymmetric_p2p = self.config.get_from_extra_config(
"enable_asymmetric_p2p", False)
if self.remote_tp_size % self.tp_size != 0:
logger.error(" the Prefill TP size must be less than or equal to the Decode TP size!!!!")
self.multp = int(self.remote_tp_size / self.tp_size)
port = int(self.config.kv_port) + port_offset port = int(self.config.kv_port) + port_offset
if port == 0: if port == 0:
raise ValueError("Port cannot be 0") raise ValueError("Port cannot be 0")
self._hostname = get_ip() self._hostname = hostname
self._port = port self._port = port
# Each card corresponds to a ZMQ address. # Each card corresponds to a ZMQ address.
...@@ -116,7 +112,7 @@ class P2pNcclEngine: ...@@ -116,7 +112,7 @@ class P2pNcclEngine:
# The `http_port` must be consistent with the port of OpenAI. # The `http_port` must be consistent with the port of OpenAI.
self.http_address = ( self.http_address = (
f"{self._hostname}:" f"{self.config.kv_connector_extra_config['instance_ip']}:"
f"{self.config.kv_connector_extra_config['http_port']}") f"{self.config.kv_connector_extra_config['http_port']}")
# If `proxy_ip` or `proxy_port` is `""`, # If `proxy_ip` or `proxy_port` is `""`,
...@@ -128,16 +124,27 @@ class P2pNcclEngine: ...@@ -128,16 +124,27 @@ class P2pNcclEngine:
else: else:
self.proxy_address = proxy_ip + ":" + proxy_port self.proxy_address = proxy_ip + ":" + proxy_port
self.kv_cache_layer_num = 0
self.context = zmq.Context() self.context = zmq.Context()
self.router_socket = self.context.socket(zmq.ROUTER) self.router_socket = self.context.socket(zmq.ROUTER)
self.router_socket.setsockopt(zmq.RCVHWM, 10000)
self.router_socket.setsockopt(zmq.SNDHWM, 5000)
self.router_socket.setsockopt(zmq.LINGER, 0)
self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
self.router_socket.setsockopt(zmq.TCP_KEEPALIVE, 1)
self.router_socket.bind(f"tcp://{self.zmq_address}") self.router_socket.bind(f"tcp://{self.zmq_address}")
self.poller = zmq.Poller() self.poller = zmq.Poller()
self.poller.register(self.router_socket, zmq.POLLIN) self.poller.register(self.router_socket, zmq.POLLIN)
self.req_status: dict[str, ReqKVDest] = {}
self.req_status_cv = threading.Condition()
self.send_store_cv = threading.Condition() self.send_store_cv = threading.Condition()
self.send_queue_cv = threading.Condition() self.send_queue_cv = threading.Condition()
self.recv_store_cv = threading.Condition() self.recv_store_cv = threading.Condition()
self.pending_queue_cv = threading.Condition()
self.send_stream = torch.cuda.Stream() self.send_stream = torch.cuda.Stream()
self.recv_stream = torch.cuda.Stream() self.recv_stream = torch.cuda.Stream()
...@@ -145,6 +152,7 @@ class P2pNcclEngine: ...@@ -145,6 +152,7 @@ class P2pNcclEngine:
self.p2p_async_kv_tokens = envs.VLLM_P2P_BUF_TOKENS self.p2p_async_kv_tokens = envs.VLLM_P2P_BUF_TOKENS
self.p2p_async_buf = None self.p2p_async_buf = None
self.tensor_split_num: int = 0 self.tensor_split_num: int = 0
self.recv_split_nums: dict[str, int] = {}
mem_pool_size_gb = self.config.get_from_extra_config( mem_pool_size_gb = self.config.get_from_extra_config(
"mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB) "mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB)
...@@ -161,11 +169,16 @@ class P2pNcclEngine: ...@@ -161,11 +169,16 @@ class P2pNcclEngine:
# PUT or PUT_ASYNC # PUT or PUT_ASYNC
# tensor_id: torch.Tensor # tensor_id: torch.Tensor
self.send_queue: deque[list[Any]] = deque() self.send_queue: deque[list[Any]] = deque()
self.pending_queue: dict[str, list[list[Any]]] = defaultdict(list)
self.requests_to_release: dict[str, bool] = {}
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {} self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
if self.send_type == "PUT_ASYNC": if self.send_type == "PUT_ASYNC":
self._send_thread = threading.Thread(target=self._send_async, self._send_thread = threading.Thread(target=self._send_async,
daemon=True) daemon=True)
self._send_thread.start() self._send_thread.start()
self._pending_check_thread = threading.Thread(target=self._pending_check,
daemon=True)
self._pending_check_thread.start()
# tensor_id: torch.Tensor/(addr, dtype, shape) # tensor_id: torch.Tensor/(addr, dtype, shape)
self.recv_store: dict[str, Any] = {} self.recv_store: dict[str, Any] = {}
...@@ -184,7 +197,7 @@ class P2pNcclEngine: ...@@ -184,7 +197,7 @@ class P2pNcclEngine:
self._listener_thread.start() self._listener_thread.start()
self._ping_thread = None self._ping_thread = None
if port_offset == 0 and self.proxy_address != "": if self.proxy_address != "":
self._ping_thread = threading.Thread(target=self._ping, self._ping_thread = threading.Thread(target=self._ping,
daemon=True) daemon=True)
self._ping_thread.start() self._ping_thread.start()
...@@ -198,92 +211,24 @@ class P2pNcclEngine: ...@@ -198,92 +211,24 @@ class P2pNcclEngine:
def _create_connect(self, remote_address: typing.Optional[str] = None): def _create_connect(self, remote_address: typing.Optional[str] = None):
assert remote_address is not None assert remote_address is not None
if remote_address not in self.socks: if remote_address not in self.socks:
sock = self.context.socket(zmq.DEALER) sock = self.context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) sock.setsockopt(zmq.SNDHWM, 10000)
sock.setsockopt(zmq.RCVHWM, 5000)
sock.setsockopt(zmq.LINGER, 0)
sock.setsockopt(zmq.TCP_KEEPALIVE, 1)
sock.setsockopt_string(zmq.IDENTITY, f"P-{self.zmq_address}")
sock.connect(f"tcp://{remote_address}") sock.connect(f"tcp://{remote_address}")
self.socks[remote_address] = sock self.socks[remote_address] = sock
if remote_address in self.comms:
logger.info("👋comm exists, remote_address:%s, comms:%s",
remote_address, self.comms)
return sock, self.comms[remote_address]
unique_id = self.nccl.ncclGetUniqueId()
data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)}
sock.send(msgpack.dumps(data))
with torch.cuda.device(self.device):
rank = 0
with set_p2p_nccl_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(
2, unique_id, rank)
self.comms[remote_address] = (comm, rank)
logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s",
self.zmq_address, remote_address, rank)
return self.socks[remote_address], self.comms[remote_address] return self.socks[remote_address]
def get_send_queue_items(self, request_id: str, layer_name: str,
                         tensor: torch.Tensor,
                         is_mla: bool) -> list[Any]:
    """Build the ``(tensor_id, remote_address, tensor)`` items for one layer.

    Args:
        request_id: Request id; the decode side's ip/port are parsed from it.
        layer_name: Name of the layer whose KV cache is being sent.
        tensor: The KV cache tensor to send.
        is_mla: Whether the model uses MLA. The asymmetric path only logs an
            error (and still proceeds) when this is false.

    Returns:
        In symmetric mode a single-element list holding a tuple that targets
        the peer at ``remote_port + self.rank``. In asymmetric mode one
        ``[tensor_id, remote_address, tensor]`` list per remote tp rank this
        rank feeds, in ascending remote tp rank order.
    """
    tensor_id = self.get_tensor_id(request_id, layer_name)
    remote_ip, remote_port = self.parse_request_id(request_id, True)

    if not self.enable_asymmetric_p2p:
        remote_address = remote_ip + ":" + str(remote_port + self.rank)
        return [(tensor_id, remote_address, tensor)]

    if not is_mla:
        logger.error(" P2PNCCL only support mla model symmetric PP/TP!!!!")

    remote_pp_rank = self.compute_remote_pp_rank(layer_name)
    items: list[Any] = []
    # This rank feeds the remote tp ranks tp_rank + k * tp_size; compute them
    # directly instead of scanning every (remote rank, k) combination.
    for mul_tp in range(self.multp):
        d_tp_rank = self.tp_rank + mul_tp * self.tp_size
        if d_tp_rank >= self.remote_tp_size:
            continue
        remote_port_offset = remote_pp_rank * self.remote_tp_size + d_tp_rank
        remote_address = remote_ip + ":" + str(remote_port + remote_port_offset)
        # Log the remote tp rank that is really targeted.
        logger.debug(
            "📥 [PUT] Wait to send: tensor_id:%s, tensor_shape:%s, "
            "(pp=%d, tp=%d) -> remote_address=%s(pp=%d, tp=%d)", tensor_id,
            tensor.shape, self.pp_rank, self.tp_rank, remote_address,
            remote_pp_rank, d_tp_rank)
        items.append([tensor_id, remote_address, tensor])
    return items
def send_tensor_new(
    self,
    request_id: str,
    layer_name: str,
    tensor: torch.Tensor,
    is_mla: bool = False,
) -> bool:
    """Send one layer's KV cache to every remote rank that needs it.

    Args:
        request_id: Id of the request the tensor belongs to.
        layer_name: Name of the layer the tensor belongs to.
        tensor: The KV cache tensor to send.
        is_mla: Forwarded to ``get_send_queue_items``.

    Returns:
        ``PUT``: True only if every synchronous send succeeded (sending stops
        at the first failure). ``PUT_ASYNC``: True once the items are queued.
        ``GET`` or any other send type: False, since this path does not
        support them.
    """
    if self.send_type == "PUT":
        items = self.get_send_queue_items(request_id, layer_name, tensor,
                                          is_mla)
        return all(self.send_sync(item) for item in items)

    if self.send_type == "PUT_ASYNC":
        items = self.get_send_queue_items(request_id, layer_name, tensor,
                                          is_mla)
        if items:
            with self.send_queue_cv:
                self.send_queue.extend(items)
                self.send_queue_cv.notify()
        return True

    if self.send_type == "GET":
        logger.error(" P2PNCCL new not support GET model, please set VLLM_P2PNCCL_NEW=0 use defalut model!!!!")
    # Report failure explicitly rather than falling through with None.
    return False
def send_tensor( def send_tensor(
self, self,
tensor_id: str, tensor_id: str,
tensor: torch.Tensor, tensor: torch.Tensor,
remote_address: typing.Optional[str] = None, remote_address: typing.Optional[RemoteAddr] = None,
tbo_evt = None,
) -> bool: ) -> bool:
if remote_address is None: if remote_address is None:
with self.recv_store_cv: with self.recv_store_cv:
...@@ -310,7 +255,7 @@ class P2pNcclEngine: ...@@ -310,7 +255,7 @@ class P2pNcclEngine:
logger.info( logger.info(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d," "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d", " buffer_size:%d, oldest_tenser_size:%d, rank:%d",
remote_address, tensor_id, tensor_size, remote_address.zmq_address, tensor_id, tensor_size,
self.buffer_size, oldest_tenser_size, self.rank) self.buffer_size, oldest_tenser_size, self.rank)
self.send_store[tensor_id] = tensor self.send_store[tensor_id] = tensor
...@@ -318,7 +263,7 @@ class P2pNcclEngine: ...@@ -318,7 +263,7 @@ class P2pNcclEngine:
logger.debug( logger.debug(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, " "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)", "shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
remote_address, tensor_id, tensor_size, tensor.shape, remote_address.zmq_address, tensor_id, tensor_size, tensor.shape,
self.rank, self.buffer_size, self.rank, self.buffer_size,
self.buffer_size / self.buffer_size_threshold * 100) self.buffer_size / self.buffer_size_threshold * 100)
...@@ -328,13 +273,14 @@ class P2pNcclEngine: ...@@ -328,13 +273,14 @@ class P2pNcclEngine:
self, self,
tensor_id: str, tensor_id: str,
tensor: torch.Tensor, tensor: torch.Tensor,
remote_address: typing.Optional[str] = None, remote_address: typing.Optional[RemoteAddr] = None,
tbo_evt = None, tbo_evt = None,
) -> bool: ) -> bool:
if remote_address is None: if remote_address is None:
with self.recv_store_cv: with self.recv_store_cv:
self.recv_store[tensor_id] = tensor self.recv_store[tensor_id] = tensor
self.recv_store_cv.notify() # self.recv_store_cv.notify()
self.recv_store_cv.notify_all()
return True return True
else: else:
if self.send_type == "PUT": if self.send_type == "PUT":
...@@ -343,7 +289,7 @@ class P2pNcclEngine: ...@@ -343,7 +289,7 @@ class P2pNcclEngine:
with self.send_queue_cv: with self.send_queue_cv:
kv_layer, slot_mapping = tensor # tesor (kv_layer, slot_mapping) kv_layer, slot_mapping = tensor # tesor (kv_layer, slot_mapping)
self.send_queue.append([tensor_id, remote_address, kv_layer, slot_mapping, tbo_evt]) self.send_queue.append([tensor_id, remote_address, kv_layer, slot_mapping, tbo_evt])
self.send_queue_cv.notify() self.send_queue_cv.notify_all()
else: # GET else: # GET
with self.send_store_cv: with self.send_store_cv:
tensor_size = tensor.element_size() * tensor.numel() tensor_size = tensor.element_size() * tensor.numel()
...@@ -357,7 +303,7 @@ class P2pNcclEngine: ...@@ -357,7 +303,7 @@ class P2pNcclEngine:
logger.info( logger.info(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d," "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d", " buffer_size:%d, oldest_tenser_size:%d, rank:%d",
remote_address, tensor_id, tensor_size, remote_address.zmq_address, tensor_id, tensor_size,
self.buffer_size, oldest_tenser_size, self.rank) self.buffer_size, oldest_tenser_size, self.rank)
self.send_store[tensor_id] = tensor self.send_store[tensor_id] = tensor
...@@ -365,12 +311,62 @@ class P2pNcclEngine: ...@@ -365,12 +311,62 @@ class P2pNcclEngine:
logger.debug( logger.debug(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, " "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)", "shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
remote_address, tensor_id, tensor_size, tensor.shape, remote_address.zmq_address, tensor_id, tensor_size, tensor.shape,
self.rank, self.buffer_size, self.rank, self.buffer_size,
self.buffer_size / self.buffer_size_threshold * 100) self.buffer_size / self.buffer_size_threshold * 100)
return True return True
# TODO : support p2p async
def pending_tensor(
    self,
    reuqest_id: str,
    layer_name: str,
    tensor: torch.Tensor,
    tbo_evt = None,
) -> bool:
    """Park a layer tensor until the destination of its request is known.

    The ``[layer_name, tensor, tbo_evt]`` entry is appended to the request's
    list in ``self.pending_queue`` and one waiter on ``pending_queue_cv`` is
    woken. Always returns True.

    Note: the parameter is spelled ``reuqest_id``; the spelling is kept so
    keyword callers keep working.
    """
    entry = [layer_name, tensor, tbo_evt]
    queue_cv = self.pending_queue_cv
    with queue_cv:
        self.pending_queue[reuqest_id].append(entry)
        queue_cv.notify()
    return True
# Flush every tensor parked for `request_id` to the destinations in `req_data`.
# Returns False when req_data.dst_num <= 0 (the parked tensors are dropped),
# True after all sends have been issued.
def unpending_tensor(
self,
request_id: str,
req_data: ReqKVDest,
) -> bool:
with self.pending_queue_cv:
# pop() without a default: raises KeyError when nothing is parked for the id.
tensor_list = self.pending_queue.pop(request_id)
# Mark a tracked request as flushed in self.requests_to_release.
if request_id in self.requests_to_release:
self.requests_to_release[request_id] = True
# NOTE(review): indentation is lost in this view, so it is unclear whether
# this log call sits inside the `if`/`with` above or after them - confirm.
logger.info("[%d] unpending request: %s", self.rank, request_id)
# Nothing to deliver: the tensors popped above are simply discarded.
if req_data.dst_num <= 0:
return False
# Send each parked layer tensor to each of the dst_num destinations.
for layer_name, tensor, tbo_evt in tensor_list:
for i in range(req_data.dst_num) :
# Each entry is unpacked as RemoteAddr's zmq_address and comm_rank.
remote_addr = RemoteAddr(req_data.pd_pair_id, *(req_data.zmq_address_and_comm_rank[i]))
# The async path is taken only when the entry carries an event and one of
# the two env switches is enabled; otherwise send_tensor is used.
if (envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC) and tbo_evt is not None:
self.p2p_async_send_tensor(request_id + "#" + layer_name,
tensor, remote_addr, tbo_evt)
else :
self.send_tensor(request_id + "#" + layer_name,
tensor, remote_addr)
return True
def _pending_check(self) :
while True:
with self.pending_queue_cv:
while not self.pending_queue:
self.pending_queue_cv.wait()
pending_queue = self.pending_queue.copy()
for request_id in pending_queue:
with self.req_status_cv:
if request_id not in self.req_status:
continue
req_data = self.req_status[request_id]
assert(len(req_data.zmq_address_and_comm_rank) == req_data.dst_num)
self.unpending_tensor(request_id, req_data)
def recv_tensor( def recv_tensor(
self, self,
tensor_id: str, tensor_id: str,
...@@ -407,6 +403,7 @@ class P2pNcclEngine: ...@@ -407,6 +403,7 @@ class P2pNcclEngine:
self._create_connect(remote_address) self._create_connect(remote_address)
sock = self.socks[remote_address] sock = self.socks[remote_address]
# TODO: self.comms has changed along with PUT mode
comm, rank = self.comms[remote_address] comm, rank = self.comms[remote_address]
data = {"cmd": "GET", "tensor_id": tensor_id} data = {"cmd": "GET", "tensor_id": tensor_id}
...@@ -429,26 +426,23 @@ class P2pNcclEngine: ...@@ -429,26 +426,23 @@ class P2pNcclEngine:
def _listen_for_requests(self): def _listen_for_requests(self):
while True: while True:
socks = dict(self.poller.poll()) socks = dict(self.poller.poll(5000))
if self.router_socket in socks: if self.router_socket in socks:
remote_address, message = self.router_socket.recv_multipart() remote_address, message = self.router_socket.recv_multipart()
data = msgpack.loads(message) data = msgpack.loads(message)
if data["cmd"] == "NEW": if data["cmd"] == "NEW":
unique_id = self.nccl.unique_id_from_bytes( logger.info(f"unexpected message from {remote_address.decode()}")
bytes(data["unique_id"]))
with torch.cuda.device(self.device):
rank = 1
with set_p2p_nccl_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(
2, unique_id, rank)
self.comms[remote_address.decode()] = (comm, rank)
logger.info(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, remote_address.decode(), rank)
elif data["cmd"] == "PUT": elif data["cmd"] == "PUT":
tensor_id = data["tensor_id"] tensor_id = data["tensor_id"]
if "tensor_split_num" in data: if "tensor_split_num" in data:
self.tensor_split_num = data["tensor_split_num"] # self.tensor_split_num = data["tensor_split_num"]
parts = tensor_id.split('#')
request_id = parts[0]
layer_name = parts[1]
req_layer = f"{request_id}#{layer_name}"
self.recv_split_nums[req_layer] = data["tensor_split_num"]
with self.recv_store_cv:
self.recv_store_cv.notify_all()
try: try:
with torch.cuda.stream(self.recv_stream): with torch.cuda.stream(self.recv_stream):
tensor = torch.empty(data["shape"], tensor = torch.empty(data["shape"],
...@@ -457,8 +451,8 @@ class P2pNcclEngine: ...@@ -457,8 +451,8 @@ class P2pNcclEngine:
device=self.device) device=self.device)
self.router_socket.send_multipart( self.router_socket.send_multipart(
[remote_address, b"0"]) [remote_address, b"0"])
comm, rank = self.comms[remote_address.decode()] comm, rank = self.comms[data["pd_pair_id"]]
self._recv(comm, tensor, rank ^ 1, self.recv_stream) self._recv(comm, tensor, int(data["comm_rank"]), self.recv_stream)
tensor_size = tensor.element_size() * tensor.numel() tensor_size = tensor.element_size() * tensor.numel()
if (self.buffer_size + tensor_size if (self.buffer_size + tensor_size
> self.buffer_size_threshold): > self.buffer_size_threshold):
...@@ -480,7 +474,8 @@ class P2pNcclEngine: ...@@ -480,7 +474,8 @@ class P2pNcclEngine:
with self.recv_store_cv: with self.recv_store_cv:
self.recv_store[tensor_id] = tensor self.recv_store[tensor_id] = tensor
self._have_received_tensor_id(tensor_id) self._have_received_tensor_id(tensor_id)
self.recv_store_cv.notify() #self.recv_store_cv.notify()
self.recv_store_cv.notify_all()
elif data["cmd"] == "GET": elif data["cmd"] == "GET":
tensor_id = data["tensor_id"] tensor_id = data["tensor_id"]
...@@ -503,9 +498,35 @@ class P2pNcclEngine: ...@@ -503,9 +498,35 @@ class P2pNcclEngine:
[remote_address, msgpack.dumps(data)]) [remote_address, msgpack.dumps(data)])
if data["ret"] == 0: if data["ret"] == 0:
# TODO: self.comms has changed along with PUT mode
comm, rank = self.comms[remote_address.decode()] comm, rank = self.comms[remote_address.decode()]
self._send(comm, tensor.to(self.device), rank ^ 1, self._send(comm, tensor.to(self.device), rank ^ 1,
self.send_stream) self.send_stream)
elif data["cmd"] == "comm_init":
unique_id = self.nccl.unique_id_from_bytes(
bytes(data["unique_id"]))
with torch.cuda.device(self.device):
rank = int(data["rank"])
world_size = int(data["world_size"])
with set_p2p_nccl_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(
world_size, unique_id, rank)
self.comms[data["pd_pair_id"]] = (comm, rank)
logger.info(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, data["pd_pair_id"], rank)
elif data["cmd"] == "req_to_transfer":
with self.req_status_cv:
assert(data["request_id"] not in self.req_status)
self.req_status[data["request_id"]] = ReqKVDest(dst_num=int(data["dst_num"]), pd_pair_id=data["pd_pair_id"], zmq_address_and_comm_rank=list(zip(data["remote_address"], data["remote_rank"])))
self.req_status_cv.notify_all()
elif data["cmd"] == "req_not_to_transfer":
with self.req_status_cv:
self.req_status[data["request_id"]] = ReqKVDest(dst_num=0)
self.req_status_cv.notify_all()
else: else:
logger.warning( logger.warning(
"🚧Unexpected, Received message from %s, data:%s", "🚧Unexpected, Received message from %s, data:%s",
...@@ -533,7 +554,7 @@ class P2pNcclEngine: ...@@ -533,7 +554,7 @@ class P2pNcclEngine:
else: else:
tensor_id, remote_address, tensor = self.send_queue.popleft() tensor_id, remote_address, tensor = self.send_queue.popleft()
if not self.send_queue: if not self.send_queue:
self.send_queue_cv.notify() self.send_queue_cv.notify_all()
if (envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC) and tbo_evt is not None: if (envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC) and tbo_evt is not None:
self.send_stream.wait_event(tbo_evt) self.send_stream.wait_event(tbo_evt)
self._send_kv_p2p_sync(tensor_id, kv_layer, slot_mapping, remote_address) self._send_kv_p2p_sync(tensor_id, kv_layer, slot_mapping, remote_address)
...@@ -551,12 +572,13 @@ class P2pNcclEngine: ...@@ -551,12 +572,13 @@ class P2pNcclEngine:
"🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue" "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
" to be empty, rank:%d", duration * 1000, self.rank) " to be empty, rank:%d", duration * 1000, self.rank)
# TODO : support p2p async
def _send_kv_p2p_sync(self, tensor_id: str, kv_layer: torch.Tensor, def _send_kv_p2p_sync(self, tensor_id: str, kv_layer: torch.Tensor,
slot_mapping: torch.Tensor, remote_address: str) -> bool: slot_mapping: torch.Tensor, remote_address: str) -> bool:
if remote_address not in self.socks: if remote_address.zmq_address not in self.socks:
self._create_connect(remote_address) self._create_connect(remote_address.zmq_address)
sock = self.socks[remote_address] sock = self.socks[remote_address.zmq_address]
comm, rank = self.comms[remote_address] comm, rank = self.comms[remote_address.pd_pair_id]
is_mla = (kv_layer.ndim == 3) is_mla = (kv_layer.ndim == 3)
hidden_dim = kv_layer.shape[-1] hidden_dim = kv_layer.shape[-1]
...@@ -600,20 +622,22 @@ class P2pNcclEngine: ...@@ -600,20 +622,22 @@ class P2pNcclEngine:
"pack_idx": pack_idx, "pack_idx": pack_idx,
"tensor_split_num": pack_num, "tensor_split_num": pack_num,
"shape": tx_shape, "shape": tx_shape,
"dtype": str(kv_layer.dtype).replace("torch.", "") "dtype": str(kv_layer.dtype).replace("torch.", ""),
"pd_pair_id": remote_address.pd_pair_id,
"comm_rank": rank
} }
sock.send(msgpack.dumps(header)) sock.send(msgpack.dumps(header))
response = sock.recv() response = sock.recv()
if response != b"0": if response != b"0":
logger.error( logger.error(
"🔴Send Tensor Failed | %s 👉 %s | Rank:%s | shape:%s | size:%.4f GB | response:%s", "🔴Send Tensor Failed | %s 👉 %s | Rank:%s | shape:%s | size:%.4f GB | response:%s",
self.zmq_address, remote_address, rank, self.zmq_address, remote_address.zmq_address, rank,
tuple(send_tensor.shape), send_tensor.element_size() * send_tensor.numel() / 1024**3, tuple(send_tensor.shape), send_tensor.element_size() * send_tensor.numel() / 1024**3,
response.decode() response.decode()
) )
return False return False
self._send(comm, send_tensor, rank ^ 1, self.send_stream) self._send(comm, send_tensor, remote_address.comm_rank, self.send_stream)
if self.send_type == "PUT_ASYNC": if self.send_type == "PUT_ASYNC":
self._have_sent_tensor_id(tensor_id) self._have_sent_tensor_id(tensor_id)
...@@ -624,20 +648,22 @@ class P2pNcclEngine: ...@@ -624,20 +648,22 @@ class P2pNcclEngine:
self, self,
tensor_id: str, tensor_id: str,
tensor: torch.Tensor, tensor: torch.Tensor,
remote_address: typing.Optional[str] = None, remote_address: typing.Optional[RemoteAddr] = None,
) -> bool: ) -> bool:
if remote_address is None: if remote_address is None:
return False return False
if remote_address not in self.socks: if remote_address.zmq_address not in self.socks:
self._create_connect(remote_address) self._create_connect(remote_address.zmq_address)
sock = self.socks[remote_address] sock = self.socks[remote_address.zmq_address]
comm, rank = self.comms[remote_address] comm, rank = self.comms[remote_address.pd_pair_id]
data = { data = {
"cmd": "PUT", "cmd": "PUT",
"tensor_id": tensor_id, "tensor_id": tensor_id,
"shape": tensor.shape, "shape": tensor.shape,
"dtype": str(tensor.dtype).replace("torch.", "") "dtype": str(tensor.dtype).replace("torch.", ""),
"pd_pair_id": remote_address.pd_pair_id,
"comm_rank": rank
} }
sock.send(msgpack.dumps(data)) sock.send(msgpack.dumps(data))
...@@ -646,12 +672,12 @@ class P2pNcclEngine: ...@@ -646,12 +672,12 @@ class P2pNcclEngine:
logger.error( logger.error(
"🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, " "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
"MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s", "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
self.zmq_address, remote_address, rank, data, tensor.shape, self.zmq_address, remote_address.zmq_address, rank, data, tensor.shape,
tensor.element_size() * tensor.numel() / 1024**3, tensor.element_size() * tensor.numel() / 1024**3,
response.decode()) response.decode())
return False return False
self._send(comm, tensor.to(self.device), rank ^ 1, self.send_stream) self._send(comm, tensor.to(self.device), remote_address.comm_rank, self.send_stream)
if self.send_type == "PUT_ASYNC": if self.send_type == "PUT_ASYNC":
self._have_sent_tensor_id(tensor_id) self._have_sent_tensor_id(tensor_id)
...@@ -673,20 +699,46 @@ class P2pNcclEngine: ...@@ -673,20 +699,46 @@ class P2pNcclEngine:
""" """
# Clear the buffer upon request completion. # Clear the buffer upon request completion.
# for request_id in finished_req_ids:
# for layer_name in forward_context.no_compile_layers:
# tensor_id = request_id + "#" + layer_name
# if tensor_id in self.recv_store:
# with self.recv_store_cv:
# tensor = self.recv_store.pop(tensor_id, None)
# self.send_request_id_to_tensor_ids.pop(
# request_id, None)
# self.recv_request_id_to_tensor_ids.pop(
# request_id, None)
# addr = 0
# if isinstance(tensor, tuple):
# addr, _, _ = tensor
# self.pool.free(addr)
requests_to_release : list[str] = []
with self.pending_queue_cv:
for request_id, release in self.requests_to_release.items():
if release :
requests_to_release.append(request_id)
self.requests_to_release.pop(request_id)
for request_id in finished_req_ids: for request_id in finished_req_ids:
for layer_name in forward_context.no_compile_layers: with self.pending_queue_cv:
tensor_id = request_id + "#" + layer_name if request_id in self.pending_queue:
if tensor_id in self.recv_store: self.requests_to_release[request_id] = False
logger.info("[%d] pending request: %s", self.rank, request_id)
continue
requests_to_release.append(request_id)
for request_id in requests_to_release:
ids = self.recv_request_id_to_tensor_ids.pop(request_id, set())
with self.recv_store_cv: with self.recv_store_cv:
for tensor_id in ids:
tensor = self.recv_store.pop(tensor_id, None) tensor = self.recv_store.pop(tensor_id, None)
self.send_request_id_to_tensor_ids.pop(
request_id, None)
self.recv_request_id_to_tensor_ids.pop(
request_id, None)
addr = 0
if isinstance(tensor, tuple): if isinstance(tensor, tuple):
addr, _, _ = tensor addr, _, _ = tensor
self.pool.free(addr) self.pool.free(addr)
self.send_request_id_to_tensor_ids.pop(request_id, None)
# TODO:Retrieve requests that have already sent the KV cache. # TODO:Retrieve requests that have already sent the KV cache.
finished_sending: set[str] = set() finished_sending: set[str] = set()
...@@ -694,19 +746,64 @@ class P2pNcclEngine: ...@@ -694,19 +746,64 @@ class P2pNcclEngine:
# TODO:Retrieve requests that have already received the KV cache. # TODO:Retrieve requests that have already received the KV cache.
finished_recving: set[str] = set() finished_recving: set[str] = set()
if self.kv_cache_layer_num == 0 :
for layer_name in forward_context.no_compile_layers:
layer = forward_context.no_compile_layers[layer_name]
kv_cache = getattr(layer, 'kv_cache', None)
if kv_cache is None:
continue
self.kv_cache_layer_num += 1
with self.recv_store_cv:
for req in self.recv_request_id_to_tensor_ids:
if len(self.recv_request_id_to_tensor_ids[req]) == self.kv_cache_layer_num:
finished_recving.add(req)
return finished_sending or None, finished_recving or None return finished_sending or None, finished_recving or None
def _ping(self):
    """Register this rank with the proxy and keep rank 0's heartbeat alive.

    Messages are msgpack-encoded dicts sent over a DEALER socket connected
    to ``self.proxy_address``:

    * rank 0 only: ``P_init``/``D_init`` carrying the dp/pp/tp sizes;
    * every rank: ``P_rank``/``D_rank`` carrying this rank's dp/pp/tp
      coordinates and its ZMQ address;
    * rank 0 on a KV producer only: ``P_unique_id`` carrying a freshly
      generated NCCL unique id;
    * rank 0 only: a ``heartbeat`` message every 3 seconds, forever.

    Ranks other than 0 return right after their ``*_rank`` message; rank 0
    never returns.
    """
    is_producer = self.config.is_kv_producer

    sock = self.context.socket(zmq.DEALER)
    # The identity carries a "_ping" suffix rather than the bare address;
    # presumably so it cannot clash with another socket that identifies
    # itself with zmq_address -- TODO confirm.
    sock.setsockopt_string(zmq.IDENTITY, f"{self.zmq_address}_ping")
    logger.debug("ping start, zmq_address:%s", self.zmq_address)
    sock.connect(f"tcp://{self.proxy_address}")

    if self.rank == 0:
        # Announce the instance topology before any per-rank message.
        data = {
            "type": "P_init" if is_producer else "D_init",
            "http_address": self.http_address,
            "dp_size": self.dp_size,
            "pp_size": self.pp_size,
            "tp_size": self.tp_size,
        }
        sock.send(msgpack.dumps(data))

    # Every rank reports where it sits in the [dp, pp, tp] grid.
    data = {
        "type": "P_rank" if is_producer else "D_rank",
        "http_address": self.http_address,
        "dp_rank": self.dp_rank,
        "pp_rank": self.pp_rank,
        "tp_rank": self.tp_rank,
        "zmq_address": self.zmq_address,
    }
    sock.send(msgpack.dumps(data))

    if self.rank != 0:
        return

    if is_producer:
        unique_id = self.nccl.ncclGetUniqueId()
        data = {
            "type": "P_unique_id",
            "http_address": self.http_address,
            "unique_id": bytes(unique_id.internal),
        }
        sock.send(msgpack.dumps(data))

    # The heartbeat payload never changes, so encode it once up front.
    heartbeat = msgpack.dumps({
        "type": "heartbeat",
        "http_address": self.http_address,
    })
    while True:
        sock.send(heartbeat)
        time.sleep(3)
...@@ -740,40 +837,6 @@ class P2pNcclEngine: ...@@ -740,40 +837,6 @@ class P2pNcclEngine:
self._listener_thread.join() self._listener_thread.join()
if self.send_type == "PUT_ASYNC": if self.send_type == "PUT_ASYNC":
self._send_thread.join() self._send_thread.join()
self._pending_check_thread.join()
if self._ping_thread is not None: if self._ping_thread is not None:
self._ping_thread.join() self._ping_thread.join()
def compute_remote_pp_rank(self, layer_name: str) -> int:
    """Return the remote pipeline-parallel rank that owns ``layer_name``.

    The layer index is taken from ``layer_name`` via ``extract_layer_index``
    and matched against the ``[start, end)`` range that ``get_pp_indices``
    reports for each remote pipeline rank.

    Args:
        layer_name: Name of the layer whose index is looked up.

    Returns:
        The matching remote pipeline rank; ``self.remote_pp_size - 1`` when
        the index equals ``self.total_num_hidden_layers``; ``-1`` when no
        remote rank's range contains the index.
    """
    current_layer_idx = extract_layer_index(layer_name)
    total_layers = self.total_num_hidden_layers
    remote_pp_size = self.remote_pp_size

    # This check does not depend on the rank being probed, so it is decided
    # once instead of on every loop iteration. An index equal to the layer
    # count is always mapped to the last remote rank (presumably a layer
    # appended after the regular stack -- TODO confirm).
    if current_layer_idx == total_layers:
        return remote_pp_size - 1

    for d_pp_rank in range(remote_pp_size):
        start, end = get_pp_indices(total_layers, d_pp_rank, remote_pp_size)
        logger.info(
            "compute_remote_pp_rank : current_layer_idx:%s start:%s end:%s",
            current_layer_idx, start, end)
        if start <= current_layer_idx < end:
            return d_pp_rank
    return -1
@staticmethod
def get_tensor_id(request_id: str, layer_name: str) -> str:
return request_id + "#" + layer_name
@staticmethod
def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]:
# Regular expression to match the string hostname and integer port
if is_prefill:
pattern = r"___decode_addr_(.*):(\d+)"
else:
pattern = r"___prefill_addr_(.*):(\d+)___"
# Use re.search to find the pattern in the request_id
match = regex.search(pattern, request_id)
if match:
# Extract the ranks
ip = match.group(1)
port = int(match.group(2))
return ip, port
raise ValueError(
f"Request id {request_id} does not contain hostname and port")
\ No newline at end of file
...@@ -86,7 +86,8 @@ class Scheduler(SchedulerInterface): ...@@ -86,7 +86,8 @@ class Scheduler(SchedulerInterface):
"Multiple KV cache groups are not currently supported " "Multiple KV cache groups are not currently supported "
"with KV connectors") "with KV connectors")
self.connector = KVConnectorFactory.create_connector_v1( self.connector = KVConnectorFactory.create_connector_v1(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER) config=self.vllm_config, role=KVConnectorRole.SCHEDULER,
dp_rank=self.parallel_config.data_parallel_rank)
self.kv_event_publisher = EventPublisherFactory.create( self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config, self.kv_events_config,
...@@ -371,8 +372,10 @@ class Scheduler(SchedulerInterface): ...@@ -371,8 +372,10 @@ class Scheduler(SchedulerInterface):
break break
request = self.waiting.peek_request() request = self.waiting.peek_request()
if request.is_finished():
if self.connector and not self.connector.is_producer and request.request_id not in self.finished_recving_kv_req_ids :
self.waiting.pop_request() self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue continue
# KVTransfer: skip request if still waiting for remote kvs. # KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
...@@ -457,6 +460,7 @@ class Scheduler(SchedulerInterface): ...@@ -457,6 +460,7 @@ class Scheduler(SchedulerInterface):
# pooling requests to be chunked # pooling requests to be chunked
if not self.scheduler_config.chunked_prefill_enabled and \ if not self.scheduler_config.chunked_prefill_enabled and \
num_new_tokens > token_budget: num_new_tokens > token_budget:
break
self.waiting.pop_request() self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request) skipped_waiting_requests.prepend_request(request)
continue continue
...@@ -668,6 +672,11 @@ class Scheduler(SchedulerInterface): ...@@ -668,6 +672,11 @@ class Scheduler(SchedulerInterface):
break break
request = self.waiting.peek_request() request = self.waiting.peek_request()
if self.connector and not self.connector.is_producer and request.request_id not in self.finished_recving_kv_req_ids :
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# KVTransfer: skip request if still waiting for remote kvs. # KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request) is_ready = self._update_waiting_for_remote_kv(request)
...@@ -751,6 +760,7 @@ class Scheduler(SchedulerInterface): ...@@ -751,6 +760,7 @@ class Scheduler(SchedulerInterface):
# pooling requests to be chunked # pooling requests to be chunked
if not self.scheduler_config.chunked_prefill_enabled and \ if not self.scheduler_config.chunked_prefill_enabled and \
num_new_tokens > token_budget: num_new_tokens > token_budget:
break
self.waiting.pop_request() self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request) skipped_waiting_requests.prepend_request(request)
continue continue
...@@ -1311,7 +1321,7 @@ class Scheduler(SchedulerInterface): ...@@ -1311,7 +1321,7 @@ class Scheduler(SchedulerInterface):
request.num_nans_in_logits = num_nans_in_logits[req_id] request.num_nans_in_logits = num_nans_in_logits[req_id]
# Add newly generated spec token ids to the request. # Add newly generated spec token ids to the request.
if spec_token_ids is not None: if spec_token_ids is not None and (self.connector is None or not self.connector.is_producer):
if self.structured_output_manager.should_advance(request): if self.structured_output_manager.should_advance(request):
metadata = request.structured_output_request metadata = request.structured_output_request
# Needs to happen after new_token_ids are accepted. # Needs to happen after new_token_ids are accepted.
......
...@@ -763,6 +763,9 @@ class EngineCoreProc(EngineCore): ...@@ -763,6 +763,9 @@ class EngineCoreProc(EngineCore):
# Push to input queue for core busy loop. # Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request)) self.input_queue.put_nowait((request_type, request))
if isinstance(request, EngineCoreRequest) and self.scheduler.connector is not None:
if request_type == EngineCoreRequestType.ADD:
self.scheduler.connector.register_req(request.request_id)
def process_output_sockets(self, output_paths: list[str], def process_output_sockets(self, output_paths: list[str],
coord_output_path: Optional[str], coord_output_path: Optional[str],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment