Commit 61ba33d5 authored by xuxz's avatar xuxz Committed by xuxz
Browse files

[PD][Feat]支持pd分离dp并行

parent ce47a56e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import socket
import threading
import uuid
import aiohttp
import msgpack
import zmq
from typing import Any
from quart import Quart, make_response, request
from dataclasses import dataclass, field
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
import time
import asyncio
from collections import deque, defaultdict
import logging
# Root logging is configured once at import time: INFO level, and every record
# is rendered as "<timestamp> - <logger name> - <level> - <message>".
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by every function in this router script.
logger = logging.getLogger(__name__)
@dataclass
class Request:
    """Router-side record of one request and the two ranks that serve it.

    The prefill ("p_") and decode ("d_") halves are filled in independently;
    an empty address / a rank of -1 means that half has not been set.
    """

    # Identifier of the request.
    request_id: str

    # Prefill side: HTTP address of the instance and its data-parallel rank.
    p_http_address: str = ""
    p_dp_rank: int = -1

    # Decode side: HTTP address of the instance and its data-parallel rank.
    d_http_address: str = ""
    d_dp_rank: int = -1
@dataclass
class Instance:
    """A prefill ("P") or decode ("D") serving instance known to the router.

    The instance is described by its parallel layout (``dp_size`` x
    ``pp_size`` x ``tp_size``) and by two tables indexed ``[dp][pp][tp]``
    that are filled in as the individual ranks are registered.
    """

    # "P" for prefill, "D" for decode.
    ins_type: str = "P"
    http_address: str = ""
    zmq_address: str = ""
    p_unique_id: bytes = b""
    # Parallel sizes; all zero until they have been set.
    dp_size: int = 0
    pp_size: int = 0
    tp_size: int = 0
    # [dp, pp, tp] : zmq_address
    rank_table: dict[int, dict[int, dict[int, str]]] = field(
        default_factory=lambda: defaultdict(lambda: defaultdict(dict))
    )
    # [dp, pp, tp] : global rank
    comm_rank_table: dict[int, dict[int, dict[int, int]]] = field(
        default_factory=lambda: defaultdict(lambda: defaultdict(dict))
    )

    def count_rank_table_elements(self) -> int:
        """Return how many ``[dp][pp][tp]`` entries ``rank_table`` holds."""
        return sum(
            len(tp_entries)
            for pp_entries in self.rank_table.values()
            for tp_entries in pp_entries.values()
        )

    def is_ready(self) -> bool:
        """Return True once every rank of the declared layout is registered.

        An instance whose sizes are still zero is never ready. The result is
        always a real ``bool`` (the previous implementation leaked the integer
        ``0`` when the world size was unknown).
        """
        world_size = self.dp_size * self.pp_size * self.tp_size
        inited_rank = self.count_rank_table_elements()
        # Anything that is not explicitly "P" is reported as "D", as before.
        label = "P" if self.ins_type == "P" else "D"
        logger.info(
            "[Router] %s is_ready? : %s world_size = %s inited_rank = %s",
            label, self.http_address, world_size, inited_rank,
        )
        return world_size > 0 and inited_rank == world_size
# ---- Shared router state (module-level, accessed from several threads) ----
# Request counter; handle_request uses `count % len(...)` to pick the prefill
# and decode instance and increments it per request.
count = 0
# Registered instances, keyed by the instance's http_address.
prefill_instances: dict[str, Instance] = {}
decode_instances: dict[str, Instance] = {}
# Requests keyed by request_id, filled in as each side sends a "Req" message.
running_requests: dict[str, Request] = {}
# http_address -> time.time() of the last "heartbeat" message received.
healthy_instances: dict[str, float] = {}
# http_address of instances for which is_ready() returned True and that are
# waiting to be processed by pd_pair_init.
pending_prefill_ins: list[str] = []
pending_decode_ins: list[str] = []
# http_address of instances that pd_pair_init has already processed.
ready_prefill_ins: list[str] = []
ready_decode_ins: list[str] = []
# "<prefill http_address>_<decode http_address>" -> NCCL unique id bytes sent
# to that pair by unique_id_dispatch.
pd_pair : dict[str, bytes] = {}
# NCCL library wrapper; unique_id_dispatch calls ncclGetUniqueId() on it.
router_nccl = NCCLLibrary()
# instance_cv is held while touching the instance tables and pending lists.
instance_cv = threading.Condition()
# request_cv is held while touching running_requests.
request_cv = threading.Condition()
# health_cv is held while touching healthy_instances.
health_cv = threading.Condition()
# request_queue_cv is held while touching request_queue; dp_dispatch waits on it.
request_queue_cv = threading.Condition()
# NOTE(review): annotated as deque[list[Any]], but the listener appends
# Request objects to it — confirm and tighten the annotation.
request_queue: deque[list[Any]] = deque()
# zmq address ("ip:port") -> connected DEALER socket, reused across sends.
sock_cache : dict[str, Any] = {}
def _instance_tables(ins_type):
    """Return ``(instances, pending, pending_name)`` for side "P" or "D"."""
    if ins_type == "P":
        return prefill_instances, pending_prefill_ins, "pending_prefill_ins"
    return decode_instances, pending_decode_ins, "pending_decode_ins"


def _queue_if_ready(instance, pending, pending_name):
    """Append a fully registered instance to its pending list and wake waiters.

    The caller must already hold ``instance_cv``.
    """
    if not instance.is_ready():
        return
    pending.append(instance.http_address)
    logger.info("[Router] %s appended %s ZMQ:%s",
                pending_name, instance.http_address, instance.zmq_address)
    instance_cv.notify()


def _register_rank(data, ins_type):
    """Handle a "P"/"D" message: record one rank's zmq address."""
    instances, pending, pending_name = _instance_tables(ins_type)
    http_address = data["http_address"]
    with instance_cv:
        instance = instances.get(http_address)
        if instance is None:
            # A rank may report before the instance's *_init message arrives.
            instance = Instance(ins_type=ins_type, http_address=http_address)
            instances[http_address] = instance
        dp_rank, pp_rank, tp_rank = (
            int(data["dp_rank"]), int(data["pp_rank"]), int(data["tp_rank"]))
        instance.rank_table[dp_rank][pp_rank][tp_rank] = data["zmq_address"]
        _queue_if_ready(instance, pending, pending_name)
    logger.info("[Router] add %s rank [%s, %s, %s] : %s",
                ins_type, data["dp_rank"], data["pp_rank"], data["tp_rank"],
                data["zmq_address"])


def _register_init(data, ins_type):
    """Handle a "P_init"/"D_init" message: record the parallel layout."""
    instances, pending, pending_name = _instance_tables(ins_type)
    http_address = data["http_address"]
    with instance_cv:
        instance = instances.get(http_address)
        is_new = instance is None
        if is_new:
            instance = Instance(ins_type=ins_type, http_address=http_address)
            instances[http_address] = instance
        instance.dp_size = int(data["dp_size"])
        instance.pp_size = int(data["pp_size"])
        instance.tp_size = int(data["tp_size"])
        instance.zmq_address = data["zmq_address"]
        # A brand-new instance has no ranks yet, so (as before) readiness is
        # only evaluated for instances that were already known.
        if not is_new:
            _queue_if_ready(instance, pending, pending_name)


def _record_request(data):
    """Handle a "Req" message: pair the P and D reports of one request.

    The first report for a request id is stored; when a later report for the
    same id arrives and both sides are known, the request is queued for
    ``dp_dispatch``.
    """
    request_id = data["request_id"]
    side = data["instance_type"]
    with request_cv:
        request = running_requests.get(request_id)
        if request is None:
            if side == "P":
                running_requests[request_id] = Request(
                    request_id=request_id,
                    p_http_address=data["http_address"],
                    p_dp_rank=int(data["dp_rank"]))
            elif side == "D":
                running_requests[request_id] = Request(
                    request_id=request_id,
                    d_http_address=data["http_address"],
                    d_dp_rank=int(data["dp_rank"]))
            return

        if side == "P":
            request.p_http_address = data["http_address"]
            request.p_dp_rank = int(data["dp_rank"])
        elif side == "D":
            request.d_http_address = data["http_address"]
            request.d_dp_rank = int(data["dp_rank"])

        # Previously an `assert`: it killed the listener thread when the same
        # side reported twice and vanished entirely under `python -O`.
        if request.p_dp_rank < 0 or request.d_dp_rank < 0:
            logger.warning(
                "[Router] Request %s reported again by %s but the other side "
                "is still unknown; not dispatching yet", request_id, side)
            return

        # NOTE(review): the entry is left in running_requests after queuing,
        # exactly as before — confirm whether it should be removed here.
        with request_queue_cv:
            request_queue.append(request)
            request_queue_cv.notify()


def _listen_for_register(poller, router_socket):
    """Receive and process registration traffic forever (thread target).

    Args:
        poller: poller on which ``router_socket`` is registered.
        router_socket: socket yielding ``(remote_address, message)`` multipart
            frames whose message is a msgpack-encoded dict with a "type" key.

    Message types handled: "P"/"D" (one rank's zmq address), "P_init"/"D_init"
    (instance layout), "heartbeat" and "Req". Anything else is logged.
    """
    while True:
        socks = dict(poller.poll())
        if router_socket not in socks:
            continue
        remote_address, message = router_socket.recv_multipart()
        # data: {"type": "P", "http_address": "ip:port",
        #        "zmq_address": "ip:port"}
        data = msgpack.loads(message)
        msg_type = data["type"]

        if msg_type in ("P", "D"):
            _register_rank(data, msg_type)
        elif msg_type in ("P_init", "D_init"):
            _register_init(data, msg_type[0])
        elif msg_type == "heartbeat":
            with health_cv:
                healthy_instances[data["http_address"]] = time.time()
        elif msg_type == "Req":
            _record_request(data)
        else:
            # The original passed %-style arguments to print(), which never
            # interpolates them; use the logger so the message is formatted.
            logger.warning("Unexpected, Received message from %s, data: %s",
                           remote_address, data)
# ZMQ context shared by the router; assigned in start_service_discovery and
# used afterwards to create the DEALER sockets stored in sock_cache.
zmq_context = None
# Caches built lazily by dispatch_to_P, keyed by
# "<pd_pair_id>_<decode dp rank>":
#   prefill tp rank -> zmq addresses of the decode ranks it sends to
tp_mapping_of_pd_pair : dict[str, dict[int, list[str]]] = {}
#   prefill tp rank -> comm ranks (from comm_rank_table) of those decode ranks
tp_comm_mapping_of_pd_pair : dict[str, dict[int, list[int]]] = {}
#   set of prefill tp ranks that have at least one decode destination
active_p_tp_rank_of_pd_pair : dict[str, set[int]] = {}
def start_service_discovery(hostname, port):
    """Bind the registration ROUTER socket and start its listener thread.

    Args:
        hostname: interface to bind; a falsy value falls back to
            ``socket.gethostname()``.
        port: TCP port to bind; must not be 0.

    Returns:
        The started daemon thread running ``_listen_for_register``.

    Raises:
        ValueError: if ``port`` is 0.
    """
    global zmq_context

    bind_host = hostname if hostname else socket.gethostname()
    if port == 0:
        raise ValueError("Port cannot be 0")

    # The context is published module-wide so other functions can open
    # sockets on it.
    zmq_context = zmq.Context()
    registry_socket = zmq_context.socket(zmq.ROUTER)
    registry_socket.bind(f"tcp://{bind_host}:{port}")

    registry_poller = zmq.Poller()
    registry_poller.register(registry_socket, zmq.POLLIN)

    listener = threading.Thread(
        target=_listen_for_register,
        args=[registry_poller, registry_socket],
        daemon=True,
    )
    listener.start()
    return listener
def dispatch_to_P(request : Request):
    """Tell every rank of the chosen prefill dp replica what to do with a request.

    Prefill tp ranks that have decode destinations receive a
    "req_to_transfer" command listing those destinations; the remaining tp
    ranks receive "req_not_to_transfer". The tp-rank mapping for a given
    (PD pair, decode dp rank) is computed once and cached in the module-level
    mapping tables.

    Args:
        request: request whose prefill and decode addresses / dp ranks are set.
    """
    prefill = prefill_instances[request.p_http_address]
    decode = decode_instances[request.d_http_address]
    pd_pair_id = prefill.http_address + "_" + decode.http_address
    src_dp = request.p_dp_rank
    dst_dp = request.d_dp_rank
    mapping_key = pd_pair_id + "_" + str(dst_dp)
    assert(decode.pp_size == 1)
    dst_pp = 0

    def send_to(address, payload):
        # Lazily connect one DEALER socket per destination and reuse it.
        if address not in sock_cache:
            dealer = zmq_context.socket(zmq.DEALER)
            dealer.setsockopt_string(zmq.IDENTITY, "router")
            dealer.connect(f"tcp://{address}")
            sock_cache[address] = dealer
        sock_cache[address].send(msgpack.dumps(payload))

    if mapping_key not in active_p_tp_rank_of_pd_pair:
        # Decode tp rank i is served by prefill tp rank i % prefill.tp_size.
        fresh_active = set()
        fresh_addresses : dict[int, list[str]] = defaultdict(list)
        fresh_comm_ranks : dict[int, list[int]] = defaultdict(list)
        for d_tp_rank in range(decode.tp_size):
            owner = d_tp_rank % prefill.tp_size
            fresh_active.add(owner)
            fresh_addresses[owner].append(
                decode.rank_table[dst_dp][dst_pp][d_tp_rank])
            fresh_comm_ranks[owner].append(
                decode.comm_rank_table[dst_dp][dst_pp][d_tp_rank])
        tp_mapping_of_pd_pair[mapping_key] = fresh_addresses
        tp_comm_mapping_of_pd_pair[mapping_key] = fresh_comm_ranks
        active_p_tp_rank_of_pd_pair[mapping_key] = fresh_active

    active_ranks = active_p_tp_rank_of_pd_pair[mapping_key]
    dst_addresses = tp_mapping_of_pd_pair[mapping_key]
    dst_comm_ranks = tp_comm_mapping_of_pd_pair[mapping_key]

    # Ranks that must ship KV to one or more decode ranks.
    for p_pp_rank in range(prefill.pp_size):
        for p_tp_rank in active_ranks:
            target = prefill.rank_table[src_dp][p_pp_rank][p_tp_rank]
            send_to(target, {
                "cmd": "req_to_transfer",
                "request_id": request.request_id,
                "dst_num": len(dst_addresses[p_tp_rank]),
                "pd_pair_id": pd_pair_id,
                "remote_address": dst_addresses[p_tp_rank],
                "remote_rank": dst_comm_ranks[p_tp_rank],
            })
            logger.info(f"""[Router] dispatch Request {request.request_id} [{src_dp}, {p_pp_rank}, {p_tp_rank}] -> [{dst_dp}, {dst_pp}]""")

    # Ranks with no destination are told explicitly not to transfer.
    idle_ranks = [r for r in range(prefill.tp_size) if r not in active_ranks]
    for p_tp_rank in idle_ranks:
        for p_pp_rank in range(prefill.pp_size):
            target = prefill.rank_table[src_dp][p_pp_rank][p_tp_rank]
            send_to(target, {
                "cmd": "req_not_to_transfer",
                "request_id": request.request_id,
            })
def dp_dispatch():
    """Forward queued requests to their prefill ranks forever (thread target).

    Blocks on ``request_queue_cv`` until the listener has queued a request,
    then hands it to ``dispatch_to_P``.
    """
    while True:
        with request_queue_cv:
            while not request_queue:
                request_queue_cv.wait()
            # The listener appends on the right, so take from the left to
            # serve requests in arrival order (pop() served the newest first).
            request = request_queue.popleft()
        # Dispatch outside the lock so the listener can keep queuing while
        # the ZMQ sends are in progress.
        try:
            dispatch_to_P(request)
        except Exception:
            # One bad request must not take down the dispatcher thread and
            # strand every request that follows.
            logger.exception("[Router] failed to dispatch Request %s",
                             request.request_id)
def start_dp_dispatch():
    """Start ``dp_dispatch`` on a daemon thread and return that thread."""
    dispatcher = threading.Thread(target=dp_dispatch, daemon=True)
    dispatcher.start()
    return dispatcher
# Total client-side timeout for proxied upstream calls: 6 * 60 * 60 seconds.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
# Quart application object; the HTTP route below is registered on it and it
# is run from the __main__ section.
app = Quart(__name__)
def random_uuid() -> str:
    """Return a freshly generated UUID4 as a 32-character hex string."""
    generated = uuid.uuid4()
    return str(generated.hex)
async def forward_request(url, data, request_id):
    """POST ``data`` as JSON to ``url`` and stream the reply body.

    Args:
        url: upstream endpoint to call.
        data: JSON-serialisable request payload.
        request_id: value sent in the ``X-Request-Id`` header.

    Yields:
        Chunks of at most 1024 bytes of the response body when the upstream
        answers with HTTP 200. For any other status nothing is yielded (as
        before), but the status is now logged instead of being dropped
        silently.
    """
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id,
    }
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        async with session.post(url=url, json=data, headers=headers) as response:
            if response.status != 200:
                logger.warning(
                    "[Router] %s answered request %s with HTTP %s; "
                    "nothing forwarded", url, request_id, response.status)
                return
            # The former `if True: ... else: read()` wrapper had an
            # unreachable else branch; only the streaming path ever ran.
            async for chunk_bytes in response.content.iter_chunked(1024):
                yield chunk_bytes
def unique_id_dispatch(prefill_instance : Instance,
                       decode_instance : Instance) :
    """Send one NCCL unique id to every rank of a prefill/decode pair.

    Does nothing (beyond logging) if the pair was already initialised.
    Otherwise each rank receives a "comm_init" command carrying the shared
    unique id, the combined world size and its own rank; ranks are numbered
    consecutively, all prefill ranks first (dp, then pp, then tp order)
    followed by all decode ranks, and are recorded in ``comm_rank_table``.

    Args:
        prefill_instance: the prefill side of the pair.
        decode_instance: the decode side of the pair.
    """
    pd_pair_id = prefill_instance.http_address + "_" + decode_instance.http_address
    if pd_pair_id in pd_pair:
        logger.info(f"""[Router] pd pair {pd_pair_id} already exist""")
        return
    logger.info(f"""[Router] initing pd pair {pd_pair_id}""")

    raw_unique_id = router_nccl.ncclGetUniqueId()
    unique_id = bytes(raw_unique_id.internal)

    p_rank_num = prefill_instance.dp_size * prefill_instance.pp_size * prefill_instance.tp_size
    d_rank_num = decode_instance.dp_size * decode_instance.pp_size * decode_instance.tp_size
    world_size = p_rank_num + d_rank_num

    next_rank = 0
    for role, instance in (("P", prefill_instance), ("D", decode_instance)):
        for dp_rank in range(instance.dp_size):
            for pp_rank in range(instance.pp_size):
                for tp_rank in range(instance.tp_size):
                    address = instance.rank_table[dp_rank][pp_rank][tp_rank]
                    # One cached DEALER socket per rank address.
                    if address not in sock_cache:
                        dealer = zmq_context.socket(zmq.DEALER)
                        dealer.setsockopt_string(zmq.IDENTITY, "router")
                        dealer.connect(f"tcp://{address}")
                        sock_cache[address] = dealer
                    payload = {
                        "cmd": "comm_init",
                        "pd_pair_id": pd_pair_id,
                        "unique_id" : unique_id,
                        "world_size": world_size,
                        "rank": next_rank
                    }
                    sock_cache[address].send(msgpack.dumps(payload))
                    instance.comm_rank_table[dp_rank][pp_rank][tp_rank] = next_rank
                    next_rank += 1
                    logger.info(f"""[Router] dispatch unique_id of pd pair {pd_pair_id} to [{role}] [{dp_rank}, {pp_rank}, {tp_rank}]""")

    # Mark the pair as initialised only after every rank has been messaged.
    pd_pair[pd_pair_id] = unique_id
def pd_pair_init():
    """Pair newly ready instances with the opposite side forever (thread target).

    Sleeps on ``instance_cv`` until a prefill or decode instance is pending.
    Each pending prefill instance is paired with every ready decode instance
    (and vice versa) through ``unique_id_dispatch``, then moved from the
    pending list to the ready list.
    """
    while True:
        with instance_cv:
            while not (pending_prefill_ins or pending_decode_ins):
                logger.info("[Router] pd_pair_init: waiting for instance_cv")
                instance_cv.wait()
                logger.info("[Router] pd_pair_init: instance_cv finished waiting")

            while pending_prefill_ins:
                prefill_addr = pending_prefill_ins[0]
                logger.info(f"""[Router] pd_pair_init: processing {prefill_addr} from pending_prefill_ins""")
                for decode_addr in ready_decode_ins:
                    unique_id_dispatch(prefill_instances[prefill_addr],
                                       decode_instances[decode_addr])
                ready_prefill_ins.append(prefill_addr)
                # Drop the entry only after it has been fully processed.
                pending_prefill_ins.remove(prefill_addr)

            while pending_decode_ins:
                decode_addr = pending_decode_ins[0]
                logger.info(f"""[Router] pd_pair_init: processing {decode_addr} from pending_decode_ins""")
                for prefill_addr in ready_prefill_ins:
                    unique_id_dispatch(prefill_instances[prefill_addr],
                                       decode_instances[decode_addr])
                ready_decode_ins.append(decode_addr)
                pending_decode_ins.remove(decode_addr)
def start_pd_pair_init():
    """Start ``pd_pair_init`` on a daemon thread and return that thread."""
    pairing_thread = threading.Thread(target=pd_pair_init, daemon=True)
    pairing_thread.start()
    return pairing_thread
# Strong references to in-flight prefill tasks. The event loop only keeps weak
# references to tasks, so an otherwise unreferenced task may be garbage
# collected before it finishes.
_prefill_tasks: set = set()


def _on_prefill_done(task):
    """Release a finished prefill task and surface any error it raised."""
    _prefill_tasks.discard(task)
    if task.cancelled():
        return
    error = task.exception()
    if error is not None:
        logger.error("[Router] prefill request failed: %r", error)


@app.route("/v1/completions", methods=["POST"])
async def handle_request():
    """Proxy a completion request through one prefill and one decode instance.

    The request is sent to the selected prefill instance with
    ``max_tokens`` forced to 1 (prefill only, run in the background) and,
    unchanged, to the selected decode instance whose streamed reply is
    returned to the client. Instances are chosen round-robin via the global
    ``count``.

    Returns:
        The streaming response from the decode instance, or a JSON error body
        with status 500 if anything fails (previously the handler returned
        ``None``, which is not a valid response).
    """
    global count
    try:
        original_request_data = await request.get_json()
        prefill_request = original_request_data.copy()
        # change max_tokens = 1 to let it only do prefill
        prefill_request["max_tokens"] = 1

        with instance_cv:
            prefill_list = list(prefill_instances.items())
            decode_list = list(decode_instances.items())
        if not prefill_list or not decode_list:
            # Avoids an opaque ZeroDivisionError from the modulo below.
            raise RuntimeError("No prefill or decode instance registered")

        prefill_addr, prefill_instance = prefill_list[count % len(prefill_list)]
        decode_addr, decode_instance = decode_list[count % len(decode_list)]

        if prefill_instance.http_address + "_" + decode_instance.http_address not in pd_pair:
            raise RuntimeError("Selected PD pair was not inited")

        logger.info(
            f"handle_request count: {count}, [HTTP:{prefill_addr}, 👉 HTTP:{decode_addr}]"
        )
        count += 1

        request_id = f"{random_uuid()}"

        async def run_prefill():
            # The prefill output itself is not needed; just drain it.
            async for _ in forward_request(
                f"http://{prefill_addr}/v1/completions", prefill_request, request_id
            ):
                pass

        prefill_task = asyncio.create_task(run_prefill())
        _prefill_tasks.add(prefill_task)
        prefill_task.add_done_callback(_on_prefill_done)

        # return decode
        generator = forward_request(
            f"http://{decode_addr}/v1/completions", original_request_data, request_id
        )
        response = await make_response(generator)
        response.timeout = None
        return response
    except Exception:
        logger.exception("Error occurred in disagg prefill proxy server")
        return {"error": "Error occurred in disagg prefill proxy server"}, 500
if __name__ == "__main__":
    # Start the three background workers first: registration listener,
    # PD-pair initialiser and request dispatcher.
    background_threads = [
        start_service_discovery("0.0.0.0", 30001),
        start_pd_pair_init(),
        start_dp_dispatch(),
    ]
    # Blocks while the HTTP proxy is serving.
    app.run(host="0.0.0.0", port=10001)
    for worker in background_threads:
        worker.join()
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib import importlib
from vllm import envs
from collections.abc import Callable from collections.abc import Callable
from typing import TYPE_CHECKING, Optional, cast from typing import TYPE_CHECKING, Optional, cast
...@@ -45,6 +46,7 @@ class KVConnectorFactory: ...@@ -45,6 +46,7 @@ class KVConnectorFactory:
config: "VllmConfig", config: "VllmConfig",
role: KVConnectorRole, role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None, kv_cache_config: Optional["KVCacheConfig"] = None,
dp_rank: int = -1,
) -> KVConnectorBase: ) -> KVConnectorBase:
kv_transfer_config = config.kv_transfer_config kv_transfer_config = config.kv_transfer_config
if kv_transfer_config is None: if kv_transfer_config is None:
...@@ -77,6 +79,8 @@ class KVConnectorFactory: ...@@ -77,6 +79,8 @@ class KVConnectorFactory:
if compat_sig: if compat_sig:
# Old signature: __init__(self, vllm_config, role) # Old signature: __init__(self, vllm_config, role)
return connector_cls(config, role) return connector_cls(config, role)
elif envs.VLLM_USE_DP_CONNECTOR:
return connector_cls(config, role, kv_cache_config, dp_rank)
else: else:
# New signature: __init__(self, vllm_config, role, kv_cache_config) # New signature: __init__(self, vllm_config, role, kv_cache_config)
return connector_cls(config, role, kv_cache_config) return connector_cls(config, role, kv_cache_config)
...@@ -160,6 +164,11 @@ KVConnectorFactory.register_connector( ...@@ -160,6 +164,11 @@ KVConnectorFactory.register_connector(
"vllm.distributed.kv_transfer.kv_connector.v1.du.du_swift_connector", "vllm.distributed.kv_transfer.kv_connector.v1.du.du_swift_connector",
"DuSwiftConnector") "DuSwiftConnector")
KVConnectorFactory.register_connector(
"DuSwiftConnectorDp",
"vllm.distributed.kv_transfer.kv_connector.v1.du.du_swift_connector_dp",
"DuSwiftConnectorDp")
KVConnectorFactory.register_connector( KVConnectorFactory.register_connector(
"LMCacheConnectorV1", "LMCacheConnectorV1",
"vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector", "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import regex as re
import torch
import os
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.du.du_swift_engine_dp import (
DuSwiftEngineDp, RemoteAddr)
from vllm.distributed.parallel_state import get_world_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.distributed.parallel_state import get_pp_group, get_tp_group, get_dp_group
import zmq
import msgpack
if TYPE_CHECKING:
from vllm.v1.attention.backend import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
logger = init_logger(__name__)
@dataclass
class ReqMeta:
    """Per-request metadata handed from the scheduler to the worker."""

    # Request Id
    request_id: str
    # Request tokens
    token_ids: torch.Tensor
    # Slot mappings, should have the same length as token_ids
    slot_mapping: torch.Tensor
    # Copy of slot_mapping on the KV-cache device; None until it is created.
    slot_mapping_device: Optional[torch.Tensor] = None

    @staticmethod
    def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
                  block_size: int) -> "ReqMeta":
        """Build the metadata, expanding block ids into per-token slot ids.

        Slot ``block_id * block_size + offset`` is produced for every offset
        in each block, in block order, and the result is truncated to
        ``len(token_ids)`` entries.

        Args:
            request_id: identifier of the request.
            token_ids: token ids of the request.
            block_ids: ids of the KV-cache blocks holding the request.
            block_size: number of slots per block.

        Returns:
            A ``ReqMeta`` whose tensors are always integer (``torch.long``),
            including for empty inputs, which previously yielded float
            tensors that cannot be used for indexing.
        """
        valid_num_tokens = len(token_ids)
        token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
        block_ids_tensor = torch.tensor(block_ids, dtype=torch.long)
        num_blocks = block_ids_tensor.shape[0]
        block_offsets = torch.arange(0, block_size)
        # (num_blocks, 1) * block_size + (1, block_size) -> (num_blocks, block_size)
        slot_mapping = (
            block_ids_tensor.reshape((num_blocks, 1)) * block_size
            + block_offsets.reshape((1, block_size))
        )
        slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
        return ReqMeta(
            request_id=request_id,
            token_ids=token_ids_tensor,
            slot_mapping=slot_mapping,
        )
@dataclass
class DuSwiftConnectorMetadata(KVConnectorMetadata):
    """Connector metadata: the list of per-request ``ReqMeta`` entries."""

    requests: list[ReqMeta]

    def __init__(self):
        # Starts empty; entries are added through add_request().
        self.requests = []

    def add_request(
        self,
        request_id: str,
        token_ids: list[int],
        block_ids: list[int],
        block_size: int,
    ) -> None:
        """Create a ``ReqMeta`` from the arguments and append it."""
        req_meta = ReqMeta.make_meta(request_id, token_ids, block_ids,
                                     block_size)
        self.requests.append(req_meta)
class DuSwiftConnectorDp(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole, kv_cache_config: Optional["KVCacheConfig"] = None, dp_rank : int = -1):
    """Set up the connector for either the worker or the scheduler role.

    Worker role: reads this process's world/dp/pp/tp ranks and sizes and
    creates the ``DuSwiftEngineDp`` transfer engine. Any other role gets
    zeros and no engine.

    Scheduler role: additionally opens a DEALER socket to the proxy given by
    the ``proxy_ip`` / ``proxy_port`` extra-config entries.

    Args:
        vllm_config: the vLLM configuration.
        role: role this connector instance plays.
        kv_cache_config: forwarded unchanged to the base class.
        dp_rank: data-parallel rank, only stored for the scheduler role.
    """
    super().__init__(vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config,)
    self._block_size = vllm_config.cache_config.block_size
    self._requests_need_load: dict[str, Any] = {}
    self.config = vllm_config.kv_transfer_config
    self.is_producer = self.config.is_kv_producer
    self.chunked_prefill: dict[str, Any] = {}

    if role == KVConnectorRole.WORKER:
        world_group = get_world_group()
        dp_group = get_dp_group()
        pp_group = get_pp_group()
        tp_group = get_tp_group()
        self._rank = world_group.rank
        self._local_rank = world_group.local_rank
        self._dp_rank = dp_group.rank_in_group
        self._pp_rank = pp_group.rank_in_group
        self._tp_rank = tp_group.rank_in_group
        self._dp_size = dp_group.world_size
        self._pp_size = pp_group.world_size
        self._tp_size = tp_group.world_size
        self.du_swift_engine = DuSwiftEngineDp(
            local_rank=self._local_rank,
            port_offset=self._rank,
            config=self.config,
            model_config=vllm_config.model_config,
            dp_rank=self._dp_rank,
            pp_rank=self._pp_rank,
            tp_rank=self._tp_rank,
            dp_size=self._dp_size,
            pp_size=self._pp_size,
            tp_size=self._tp_size
        )
    else:
        self._rank = 0
        self._local_rank = 0
        self._dp_rank = 0
        self._pp_rank = 0
        self._tp_rank = 0
        self._dp_size = 0
        self._pp_size = 0
        self._tp_size = 0
        self.du_swift_engine = None

    self.parallel_config = vllm_config.parallel_config
    self.model_config = vllm_config.model_config
    self.total_num_hidden_layers = getattr(self.model_config.hf_text_config,
                                           "num_hidden_layers", 0)
    self.pp_size = self.parallel_config.pipeline_parallel_size
    self.tp_size = self.parallel_config.tensor_parallel_size
    self.num_card = self.pp_size * self.tp_size
    self.remote_tp_size = self.config.get_from_extra_config(
        "remote_tp_size", self.tp_size)
    self.remote_pp_size = self.config.get_from_extra_config(
        "remote_pp_size", self.pp_size)
    self.enable_asymmetric_p2p = self.config.get_from_extra_config(
        "enable_asymmetric_p2p", False)
    self.remote_num_card = self.remote_tp_size * self.remote_pp_size
    # More than 8 cards on a side is treated as spanning several machines.
    self.multiple_machines_d = 1 if self.remote_num_card > 8 else 0
    self.multiple_machines_p = 1 if self.num_card > 8 else 0

    if self.is_producer and self.multiple_machines_p == 1:
        # Optional "first_ip second_ip" table read from $IP_CONFIG_FILE.
        self.ip_map = {}
        self.duplicate_keys = []
        config_file = os.getenv('IP_CONFIG_FILE')
        if not config_file:
            # This branch used to `return`, which aborted the whole
            # constructor and left a scheduler-role connector without its
            # proxy socket. Now only the table load is skipped.
            print("Warning: Please set the IPVNet FILE environment variable for cross machine recognition of the second IP address")
        else:
            try:
                with open(config_file, 'r', encoding='utf-8') as file:
                    for line_num, line in enumerate(file, 1):
                        line = line.strip()
                        if line and not line.startswith('#'):
                            ips = line.split()
                            if len(ips) == 2:
                                first_ip, second_ip = ips
                                # The first mapping for an IP wins.
                                if first_ip not in self.ip_map:
                                    self.ip_map[first_ip] = second_ip
                            else:
                                print(f"warning: num {line_num} Incorrect format : {line}")
            except Exception as e:
                print(f"Error: Exception occurred while reading configuration file - {e}")

    if role == KVConnectorRole.SCHEDULER :
        self.dp_rank = dp_rank
        proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
        proxy_port = self.config.get_from_extra_config("proxy_port", "")
        if proxy_ip == "" or proxy_port == "":
            self.proxy_address = ""
        else:
            # f-string so an integer proxy_port does not raise TypeError.
            self.proxy_address = f"{proxy_ip}:{proxy_port}"
        self.http_address = (
            f"{self.config.kv_connector_extra_config['instance_ip']}:"
            f"{self.config.kv_connector_extra_config['http_port']}")
        self.context = zmq.Context()
        req_sock = self.context.socket(zmq.DEALER)
        req_sock.setsockopt_string(zmq.IDENTITY, f"{self.http_address}_rank{self.dp_rank}")
        # NOTE(review): with an empty proxy_address this connects to
        # "tcp://" — confirm whether a missing proxy should be rejected here.
        req_sock.connect(f"tcp://{self.proxy_address}")
        self.req_sock = req_sock
def get_ip_value(self, key):
return self.ip_map.get(key)
def register_req(self, request_id: str) :
    """Send a msgpack-encoded "Req" message for ``request_id`` on ``req_sock``.

    The message carries this instance's HTTP address, its dp rank and
    whether it is the producer ("P") or not ("D").
    """
    if self.config.is_kv_producer:
        instance_type = "P"
    else:
        instance_type = "D"
    payload = {
        "type": "Req",
        "instance_type": instance_type,
        "http_address": self.http_address,
        "request_id": request_id,
        "dp_rank" : self.dp_rank
    }
    encoded = msgpack.dumps(payload)
    self.req_sock.send(encoded)
# ==============================
# Worker-side methods
# ==============================
# Consumer-side load: for every request in the connector metadata and every
# layer that exposes a `kv_cache` attribute, receive the KV tensor(s) from
# du_swift_engine and write them into the paged KV cache at the request's
# slot_mapping positions. Returns immediately on the producer side or when the
# forward context has no attention metadata.
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
"""Start loading the KV cache from the connector buffer to vLLM's
paged KV buffer.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
# Only consumer/decode loads KV Cache
if self.is_producer:
return
assert self.du_swift_engine is not None
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
# Helper used by the synchronous path below: scatter one received tensor into
# one layer's paged cache.
def inject_kv_into_layer(
dst_kv_cache_layer: torch.Tensor,
src_kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
request_id: str,
) -> None:
"""Inject the KV cache into the layer.
Args:
dst_kv_cache_layer (torch.Tensor): the destination KV cache
layer. In shape [2, num_pages, page_size, xxx] if not
using MLA, [num_pages, page_size, xxx] otherwise.
src_kv_cache (torch.Tensor): the source KV cache. In shape
[2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
otherwise.
slot_mapping (torch.Tensor): the slot mapping. In shape
[num_tokens].
request_id (str): request id for log
"""
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
# MLA layout (or any 3-D cache): pages and page slots are flattened into one
# token axis (dim 0) before indexing.
# NOTE(review): `attn_metadata.values()` assumes a dict whenever attn_metadata
# is not itself an MLACommonMetadata — confirm against callers.
if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()) or dst_kv_cache_layer.ndim == 3:
num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[1]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
num_pages * page_size, -1)
self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
0)
num_token = src_kv_cache.shape[0]
if len(slot_mapping) == num_token:
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
else:
# Length mismatch: only the first num_token slots are written, and a warning
# is logged.
dst_kv_cache_layer[slot_mapping[:num_token],
...] = src_kv_cache
logger.warning(
"🚧src_kv_cache does not match, num_slot:%d, "
"num_token:%d, request_id:%s", len(slot_mapping),
num_token, request_id)
# NOTE(review): the result of this reshape is discarded, so the statement has
# no effect on dst_kv_cache_layer — confirm whether it can be removed.
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
else:
# Non-MLA layout: leading dim of size 2 is kept; pages and page slots are
# flattened into the token axis (dim 1).
num_pages = dst_kv_cache_layer_shape[1]
page_size = dst_kv_cache_layer_shape[2]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
2, num_pages * page_size, -1)
self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
1)
num_token = src_kv_cache.shape[1]
if len(slot_mapping) == num_token:
dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
else:
dst_kv_cache_layer[:, slot_mapping[:num_token],
...] = src_kv_cache
logger.warning(
"🚧src_kv_cache does not match, num_slot:%d, "
"num_token:%d, request_id:%s", len(slot_mapping),
num_token, request_id)
# NOTE(review): reshape result discarded here as well (no effect).
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
# Get the metadata
metadata: KVConnectorMetadata = \
self._get_connector_metadata()
assert isinstance(metadata, DuSwiftConnectorMetadata)
# NOTE(review): with assertions enabled, the isinstance assert above already
# fails for None, so this check can only trigger under `python -O`.
if metadata is None:
return
# Load the KV for each request each layer
for request in metadata.requests:
for layer_name in forward_context.no_compile_layers:
layer = forward_context.no_compile_layers[layer_name]
# Only process layers that have kv_cache
# attribute (attention layers) Skip non-attention
# layers like FusedMoE
kv_cache = getattr(layer, 'kv_cache', None)
if kv_cache is None:
continue
kv_cache_layer = kv_cache[ \
forward_context.virtual_engine]
# Synchronous path (VLLM_P2P_ASYNC off): one tensor per layer, keyed
# "<request_id>#<layer_name>". Note that `kv_cache` is rebound here from the
# layer attribute to the received tensor.
if not envs.VLLM_P2P_ASYNC:
kv_cache = self.du_swift_engine.recv_tensor(
request.request_id + "#" + layer_name)
if kv_cache is None:
logger.warning("🚧src_kv_cache is None, %s",
request.request_id)
continue
inject_kv_into_layer(kv_cache_layer, kv_cache,
request.slot_mapping, request.request_id)
# Drop the engine's bookkeeping for this tensor; when the stored entry is a
# tuple its first element is passed to pool.free().
tensor_id = request.request_id + "#" + layer_name
if tensor_id in self.du_swift_engine.recv_store:
tensor = self.du_swift_engine.recv_store.pop(tensor_id, None)
self.du_swift_engine.send_request_id_to_tensor_ids.pop(
request.request_id, None)
self.du_swift_engine.recv_request_id_to_tensor_ids.pop(
request.request_id, None)
addr = 0
if isinstance(tensor, tuple):
addr, _, _ = tensor
self.du_swift_engine.pool.free(addr)
else:
# Asynchronous path: the layer's KV arrives as `tensor_split_num` pieces keyed
# "<request_id>#<layer_name>#<index>"; each piece is written to the next
# consecutive run of slots.
dst_kv_cache_layer_shape = kv_cache_layer.shape
if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()):
num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[1]
assert kv_cache_layer.is_contiguous()
dst_kv_cache_layer = kv_cache_layer.reshape(
num_pages * page_size, -1)
else:
num_pages = dst_kv_cache_layer_shape[1]
page_size = dst_kv_cache_layer_shape[2]
assert kv_cache_layer.is_contiguous()
dst_kv_cache_layer = kv_cache_layer.reshape(
2, num_pages * page_size, -1)
# Offset into slot_mapping of the next piece to inject.
inject_start_index = 0
for num in range(self.du_swift_engine.tensor_split_num):
kv_cache = self.du_swift_engine.recv_tensor(
request.request_id + "#" + layer_name + "#" + str(num))
if kv_cache is None:
logger.warning("🚧src_kv_cache is None, %s",
request.request_id)
# NOTE(review): a missing piece is skipped without advancing
# inject_start_index — confirm later pieces still land on the right slots.
continue
if isinstance(attn_metadata, MLACommonMetadata) or all(isinstance(value, MLACommonMetadata) for value in attn_metadata.values()):
num_token = kv_cache.shape[0]
if len(request.slot_mapping) == num_token:
dst_kv_cache_layer[request.slot_mapping, ...] = kv_cache
else:
dst_kv_cache_layer[request.slot_mapping[inject_start_index:inject_start_index + num_token],
...] = kv_cache
else:
num_token = kv_cache.shape[1]
if len(request.slot_mapping) == num_token:
dst_kv_cache_layer[:, request.slot_mapping, ...] = kv_cache
else:
dst_kv_cache_layer[:, request.slot_mapping[inject_start_index:inject_start_index + num_token],
...] = kv_cache
inject_start_index += num_token
# inject_kv_into_layer(kv_cache_layer, kv_cache,
# request.slot_mapping, request.request_id)
# Same bookkeeping cleanup as in the synchronous path, per piece.
tensor_id = request.request_id + "#" + layer_name + "#" + str(num)
if tensor_id in self.du_swift_engine.recv_store:
tensor = self.du_swift_engine.recv_store.pop(tensor_id, None)
self.du_swift_engine.send_request_id_to_tensor_ids.pop(
request.request_id, None)
self.du_swift_engine.recv_request_id_to_tensor_ids.pop(
request.request_id, None)
addr = 0
if isinstance(tensor, tuple):
addr, _, _ = tensor
self.du_swift_engine.pool.free(addr)
# NOTE(review): reshape result discarded (no effect).
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
def wait_for_layer_load(self, layer_name: str) -> None:
    """No-op hook for per-layer load synchronisation.

    The connector interface calls this to block until the KV cache of one
    layer has been loaded into vLLM's paged buffer (useful for
    layer-by-layer pipelining). This implementation performs no waiting
    and returns immediately.

    Args:
        layer_name: the name of the layer to wait for (unused here).
    """
    return None
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                  attn_metadata: "AttentionMetadata", **kwargs) -> None:
    """Start saving the KV cache of the layer from vLLM's paged buffer
    to the connector.

    Nothing is sent unless this instance is a producer (prefill). Two
    paths exist:

    * TBO / async P2P (``envs.VLLM_ENABLE_TBO`` or
      ``envs.VLLM_P2P_ASYNC``): the whole paged layer plus the request's
      slot mapping are handed to ``p2p_async_send_tensor`` together with
      a CUDA event recorded on the current stream.
    * Default: the request's KV entries are gathered from the paged layer
      and either parked with ``pending_tensor`` (no destination known yet
      for the request) or sent to every destination listed in the
      engine's ``req_status`` entry.

    Args:
        layer_name (str): the name of the layer.
        kv_layer (torch.Tensor): the paged KV buffer of the current
            layer in vLLM.
        attn_metadata (AttentionMetadata): the attention metadata.
        **kwargs: additional arguments for the save operation.
    """
    # Only producer/prefill saves KV Cache
    if not self.is_producer:
        return

    engine = self.du_swift_engine
    assert engine is not None

    def extract_kv_from_layer(
        layer: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> torch.Tensor:
        """Extract the KV cache from the layer.

        Assume the shape of the layer is (2, num_pages, page_size, xxx)
        if MLA is not used, and (num_pages, page_size, xxx) otherwise.
        """
        if isinstance(attn_metadata, MLACommonMetadata) or layer.ndim == 3:
            num_pages, page_size = layer.shape[0], layer.shape[1]
            return layer.reshape(num_pages * page_size, -1)[slot_mapping,
                                                            ...]
        num_pages, page_size = layer.shape[1], layer.shape[2]
        return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
                                                           ...]

    connector_metadata = self._get_connector_metadata()
    assert isinstance(connector_metadata, DuSwiftConnectorMetadata)

    if envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC:
        # The pipeline rank does not depend on the request, so derive it
        # once instead of once per request.
        pp_rank = (self.parallel_config.rank //
                   self.parallel_config.tensor_parallel_size) % \
            self.parallel_config.pipeline_parallel_size

        for request in connector_metadata.requests:
            request_id = request.request_id
            tensor_id = request_id + "#" + layer_name
            ip, port = self.parse_request_id(request_id, True)
            remote_address = ip + ":" + str(port + self._rank)

            # Copy the slot mapping to the layer's device once per request
            # and reuse the cached copy for subsequent layers.
            if request.slot_mapping_device is None:
                request.slot_mapping_device = \
                    request.slot_mapping.pin_memory().to(
                        device=kv_layer.device, non_blocking=True)
            slot_mapping = request.slot_mapping_device

            tbo_evt = torch.cuda.Event(enable_timing=False)
            tbo_evt.record()

            # Destination list per supported pipeline-parallel layout.
            if self.pp_size == 1:
                targets = [remote_address]
            elif self.pp_size == 2:
                peer_shift = 4 if pp_rank == 0 else -4
                targets = [
                    remote_address,
                    ip + ":" + str(port + self._rank + peer_shift),
                ]
            elif self.pp_size == 8:
                targets = [ip + ":" + str(port + i) for i in range(8)]
            else:
                # Report through the module logger like every other
                # diagnostic in this file instead of printing to stdout.
                logger.error("Error: only suppprt pp1 pp2 pp8!!!!!!")
                targets = []

            for target in targets:
                engine.p2p_async_send_tensor(tensor_id,
                                             (kv_layer, slot_mapping),
                                             target, tbo_evt)
        return

    for request in connector_metadata.requests:
        request_id = request.request_id
        kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)

        # Look the destination up while holding the lock so that the
        # membership test and the read observe the same state.
        with engine.req_status_cv:
            req_data = engine.req_status.get(request_id)

        if req_data is None:
            # Destination not known yet: park the tensor in the engine's
            # pending queue.
            engine.pending_tensor(request_id, layer_name, kv_cache)
            logger.info("[%d] pending for request: %s layer: %s",
                        self._rank, request_id, layer_name)
            continue

        assert req_data.dst_num == len(req_data.zmq_address_and_comm_rank)
        for i in range(req_data.dst_num):
            remote_addr = RemoteAddr(
                req_data.pd_pair_id,
                *(req_data.zmq_address_and_comm_rank[i]))
            engine.send_tensor(request_id + "#" + layer_name, kv_cache,
                               remote_addr)
def wait_for_save(self):
    """No-op: saving is not awaited here.

    The call that would wait on the engine
    (``self.du_swift_engine.wait_for_sent()`` on a producer) is disabled,
    so this method returns immediately.
    """
    return None
def get_finished(
        self, finished_req_ids: set[str],
        **kwargs) -> tuple[Optional[set[str]], Optional[set[str]]]:
    """Report which requests have finished their asynchronous transfer.

    The work is delegated to the engine, which also receives the current
    forward context.

    Args:
        finished_req_ids: ids of requests that have finished generating
            tokens.
        **kwargs: accepted for interface compatibility; unused here.

    Returns:
        A tuple of (sending/saving ids, recving/loading ids). The finished
        saves/sends req ids must belong to a set provided in a call to
        this method (this call or a prior one).
    """
    assert self.du_swift_engine is not None
    current_context: ForwardContext = get_forward_context()
    engine = self.du_swift_engine
    return engine.get_finished(finished_req_ids, current_context)
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
    self,
    request: "Request",
    num_computed_tokens: int,
) -> tuple[int, bool]:
    """Count the prompt tokens whose KV can come from the external cache.

    Args:
        request (Request): the request object.
        num_computed_tokens (int): the number of locally
            computed tokens for this request

    Returns:
        A ``(count, flag)`` pair. ``count`` is 0 on a producer; otherwise
        it is the prompt length minus one minus ``num_computed_tokens``,
        clamped at 0. ``flag`` is always ``False``.
    """
    if self.is_producer:
        return 0, False

    prompt_len = len(request.prompt_token_ids)
    loadable = prompt_len - 1 - num_computed_tokens
    return max(loadable, 0), False
def update_state_after_alloc(self, request: "Request",
                             blocks: "KVCacheBlocks",
                             num_external_tokens: int):
    """Update KVConnector state after block allocation.

    On a consumer with at least one external token, remember the request
    together with the first entry of ``blocks.get_block_ids()`` so that a
    later ``build_connector_meta`` call can schedule the load. Producers
    and requests without external tokens are ignored.
    """
    if self.is_producer or num_external_tokens <= 0:
        return

    first_group_block_ids = blocks.get_block_ids()[0]
    self._requests_need_load[request.request_id] = (request,
                                                    first_group_block_ids)
def build_connector_meta(
    self,
    scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
    """Build the connector metadata for this step.

    This function should NOT modify any fields in the scheduler_output.
    Also, calling this function will reset the state of the connector.

    Producer side: a request is added to the metadata only in the step
    that completes its prompt; until then its block ids and prompt tokens
    are accumulated in ``self.chunked_prefill``.
    Consumer side: a request is added when it was registered in
    ``self._requests_need_load`` (new requests) or when it is resumed
    from preemption and still registered there.

    Args:
        scheduler_output (SchedulerOutput): the scheduler output object.

    Returns:
        The ``DuSwiftConnectorMetadata`` describing this step's requests.
    """
    meta = DuSwiftConnectorMetadata()

    # ---- newly scheduled requests -------------------------------------
    for fresh in scheduler_output.scheduled_new_reqs:
        if not self.is_producer:
            if fresh.req_id in self._requests_need_load:
                meta.add_request(request_id=fresh.req_id,
                                 token_ids=fresh.prompt_token_ids,
                                 block_ids=fresh.block_ids[0],
                                 block_size=self._block_size)
                self._requests_need_load.pop(fresh.req_id)
            continue

        scheduled_now = (scheduler_output.num_scheduled_tokens)[fresh.req_id]
        covered = scheduled_now + fresh.num_computed_tokens
        if covered < len(fresh.prompt_token_ids):
            # The prompt is being prefilled in chunks: keep what the later
            # steps need ('CachedRequestData' has no attribute
            # 'prompt_token_ids').
            self.chunked_prefill[fresh.req_id] = (fresh.block_ids[0],
                                                  fresh.prompt_token_ids)
        else:
            # The whole prompt is covered by this step.
            meta.add_request(request_id=fresh.req_id,
                             token_ids=fresh.prompt_token_ids,
                             block_ids=fresh.block_ids[0],
                             block_size=self._block_size)

    # ---- requests already known to the scheduler ----------------------
    cached = scheduler_output.scheduled_cached_reqs
    for idx, rid in enumerate(cached.req_ids):
        computed_so_far = cached.num_computed_tokens[idx]
        step_block_ids = cached.new_block_ids[idx]
        was_resumed = rid in cached.resumed_req_ids

        if self.is_producer:
            scheduled_now = (scheduler_output.num_scheduled_tokens)[rid]
            covered = scheduled_now + computed_so_far
            if rid not in self.chunked_prefill:
                continue
            merged_block_ids = step_block_ids[0]
            if not was_resumed:
                merged_block_ids = (self.chunked_prefill[rid][0] +
                                    merged_block_ids)
            prompt_ids = self.chunked_prefill[rid][1]
            if covered < len(prompt_ids):
                # Still a partial prefill: keep accumulating.
                self.chunked_prefill[rid] = (merged_block_ids, prompt_ids)
            else:
                # The prompt is fully prefilled with this step.
                meta.add_request(request_id=rid,
                                 token_ids=prompt_ids,
                                 block_ids=merged_block_ids,
                                 block_size=self._block_size)
                self.chunked_prefill.pop(rid, None)
            continue

        # NOTE(rob): here we rely on the resumed requests being
        # the first N requests in the list scheduled_cache_reqs.
        if not was_resumed:
            break
        if rid not in self._requests_need_load:
            continue

        waiting_request, _ = self._requests_need_load.pop(rid)
        # NOTE(rob): For resumed req, new_block_ids is all
        # of the block_ids for the request.
        meta.add_request(
            request_id=rid,
            token_ids=waiting_request.all_token_ids[:computed_so_far + 1],
            block_ids=step_block_ids[0],
            block_size=self._block_size)

    # Anything still registered was not scheduled in this step; drop it.
    self._requests_need_load.clear()
    return meta
def request_finished(
    self,
    request: "Request",
    block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
    """Forget chunked-prefill bookkeeping for a finished request.

    Called when a request has finished, before its blocks are freed.

    Returns:
        ``(False, None)``: the blocks never need to be kept alive for an
        asynchronous save, and no KVTransferParams are attached to the
        request outputs.
    """
    finished_id = request.request_id
    self.chunked_prefill.pop(finished_id, None)
    keep_blocks, transfer_params = False, None
    return keep_blocks, transfer_params
# ==============================
# Static methods
# ==============================
@staticmethod
def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]:
# Regular expression to match the string hostname and integer port
if is_prefill:
pattern = r"___decode_addr_(.*):(\d+)"
else:
pattern = r"___prefill_addr_(.*):(\d+)___"
# Use re.search to find the pattern in the request_id
match = re.search(pattern, request_id)
if match:
# Extract the ranks
ip = match.group(1)
port = int(match.group(2))
return ip, port
raise ValueError(
f"Request id {request_id} does not contain hostname and port")
@staticmethod
def check_tensors_except_dim(tensor1, tensor2, dim):
shape1 = tensor1.size()
shape2 = tensor2.size()
if len(shape1) != len(shape2) or not all(
s1 == s2
for i, (s1, s2) in enumerate(zip(shape1, shape2)) if i != dim):
raise NotImplementedError(
"Currently, only symmetric TP is supported. Asymmetric TP, PP,"
"and others will be supported in future PRs.")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import os
import threading
import time
import typing
from collections import deque, defaultdict
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional
import msgpack
import torch
import zmq
import regex
from vllm.config import KVTransferConfig
from vllm.distributed.device_communicators.pynccl_wrapper import (
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum)
from vllm.distributed.kv_transfer.kv_connector.v1.du.tensor_memory_pool import ( # noqa: E501
TensorMemoryPool)
from vllm.utils.torch_utils import current_stream
from vllm.utils.network_utils import get_ip
from vllm import envs
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from dataclasses import dataclass, field
from vllm.model_executor.models.utils import extract_layer_index
from vllm.distributed.utils import get_pp_indices
from vllm.config import ModelConfig
if TYPE_CHECKING:
from vllm.forward_context import ForwardContext
logger = logging.getLogger(__name__)
DEFAULT_MEM_POOL_SIZE_GB = 32
# @dataclass
# class SendQueueItem:
# tensor_id: str
# remote_address: str
# tensor: torch.Tensor
@contextmanager
def set_du_swift_context(num_channels: str):
    """Temporarily pin the NCCL channel settings in ``os.environ``.

    While the context is active ``NCCL_MAX_NCHANNELS`` and
    ``NCCL_MIN_NCHANNELS`` are set to ``num_channels`` and
    ``NCCL_CUMEM_ENABLE`` is set to ``'1'``. On exit every variable in the
    managed list is restored to the value it had on entry, or removed if
    it was unset.

    Args:
        num_channels: channel count. A non-string value (for example an
            int coming from a JSON config) is converted with ``str()``,
            because ``os.environ`` only accepts strings.
    """
    managed_vars = (
        'NCCL_MAX_NCHANNELS',
        'NCCL_MIN_NCHANNELS',
        'NCCL_CUMEM_ENABLE',
        'NCCL_BUFFSIZE',
        'NCCL_PROTO',  # LL,LL128,SIMPLE
        'NCCL_ALGO',  # RING,TREE
    )
    original_values: dict[str, Any] = {
        var: os.environ.get(var)
        for var in managed_vars
    }
    logger.info("set_du_swift_context, original_values: %s", original_values)

    # os.environ rejects non-str values with TypeError; coerce up front so
    # a failure cannot happen half-way through the assignments.
    channels = str(num_channels)
    try:
        os.environ['NCCL_MAX_NCHANNELS'] = channels
        os.environ['NCCL_MIN_NCHANNELS'] = channels
        os.environ['NCCL_CUMEM_ENABLE'] = '1'
        yield
    finally:
        for var, previous in original_values.items():
            if previous is None:
                os.environ.pop(var, None)
            else:
                os.environ[var] = previous
@dataclass
class ReqKVDest:
    """Where the KV cache of one request has to be sent."""

    # Number of destinations; callers assert that it equals
    # len(zmq_address_and_comm_rank).
    dst_num: int = 0
    # Identifier of the prefill/decode pair, passed on to RemoteAddr.
    pd_pair_id: str = ""
    # One (zmq_address, comm_rank) entry per destination.
    zmq_address_and_comm_rank: list[tuple[str, int]] = field(
        default_factory=list)
@dataclass
class RemoteAddr:
    """Address of one remote peer that a tensor is sent to."""

    # Identifier of the prefill/decode pair this peer belongs to.
    pd_pair_id: str = ""
    # "host:port" of the peer's ZMQ socket.
    zmq_address: str = ""
    # Rank of the peer inside the communicator.
    comm_rank: int = 0
class DuSwiftEngineDp:
def __init__(self,
local_rank: int,
port_offset: int,
config: KVTransferConfig,
model_config: ModelConfig,
dp_rank: int = 0,
pp_rank: int = 0,
tp_rank: int = 0,
dp_size: int = 0,
pp_size: int = 0,
tp_size: int = 0,
library_path: Optional[str] = None) -> None:
"""Set up the per-rank KV transfer engine.

Records the rank / parallel layout, loads the NCCL library, binds a ZMQ
ROUTER socket on ``get_ip():kv_port + port_offset``, creates the CUDA
send/recv streams and the tensor memory pool, initialises the send /
receive bookkeeping and starts the background threads (listener, and
depending on configuration the async sender, pending checker and ping).

Args:
    local_rank: CUDA device index; the engine uses ``cuda:{local_rank}``.
    port_offset: added to ``config.kv_port`` to obtain this rank's ZMQ
        port; also stored as ``self.rank``.
    config: KV transfer configuration (including the extra config dict).
    model_config: model configuration; ``num_hidden_layers`` is read from
        its ``hf_text_config``.
    dp_rank, pp_rank, tp_rank, dp_size, pp_size, tp_size: parallel layout.
        The pp/tp values are stored and then overwritten with the values
        reported by ``get_pp_group()`` / ``get_tp_group()``.
    library_path: forwarded to ``NCCLLibrary``.

Raises:
    ValueError: if the computed ZMQ port is 0.
"""
# NOTE(review): the leading indentation of this block is not present in
# the reviewed text, so the comments below do not assert how deeply each
# statement is nested.
self.config = config
self.model_config = model_config
# The engine's "rank" is the port offset of this worker.
self.rank = port_offset
self.local_rank = local_rank
self.dp_rank = dp_rank
self.pp_rank = pp_rank
self.tp_rank = tp_rank
self.dp_size = dp_size
self.pp_size = pp_size
self.tp_size = tp_size
self.device = torch.device(f"cuda:{self.local_rank}")
self.nccl = NCCLLibrary(library_path)
# Falls back to 0 when the HF text config has no num_hidden_layers.
self.total_num_hidden_layers = getattr(self.model_config.hf_text_config,
"num_hidden_layers", 0)
# The pp/tp values passed as arguments are replaced by the values of the
# initialised process groups.
self.pp_rank = get_pp_group().rank_in_group
self.tp_rank = get_tp_group().rank_in_group
self.pp_size = get_pp_group().world_size
self.tp_size = get_tp_group().world_size
# Producer-side settings describing the remote (decode) layout.
if config.is_kv_producer:
self.remote_tp_size = self.config.get_from_extra_config(
"remote_tp_size", 1)
self.remote_pp_size = self.config.get_from_extra_config(
"remote_pp_size", 1)
self.enable_asymmetric_p2p = self.config.get_from_extra_config(
"enable_asymmetric_p2p", False)
# NOTE(review): the condition tests divisibility while the message talks
# about "less than or equal"; the problem is only logged, not raised.
if self.remote_tp_size % self.tp_size != 0:
logger.error(" the Prefill TP size must be less than or equal to the Decode TP size!!!!")
# Number of remote TP ranks served by each local TP rank.
self.multp = int(self.remote_tp_size / self.tp_size)
self.multiple_machines = self.config.get_from_extra_config(
"enable_multiple_machines", False)
self.instance_ip = self.config.get_from_extra_config(
"instance_ip", None)
# An explicit instance_ip switches multiple-machines mode off.
if self.instance_ip :
self.multiple_machines = False
port = int(self.config.kv_port) + port_offset
if port == 0:
raise ValueError("Port cannot be 0")
self._hostname = get_ip()
self._port = port
# Each card corresponds to a ZMQ address.
self.zmq_address = f"{self._hostname}:{self._port}"
# The `http_port` must be consistent with the port of OpenAI.
if self.instance_ip:
self.http_address = (
f"{self.config.kv_connector_extra_config['instance_ip']}:"
f"{self.config.kv_connector_extra_config['http_port']}")
else:
self.http_address = (
f"{self._hostname}:"
f"{self.config.kv_connector_extra_config['http_port']}")
# If `proxy_ip` or `proxy_port` is `""`,
# then the ping thread will not be enabled.
proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
proxy_port = self.config.get_from_extra_config("proxy_port", "")
if proxy_ip == "" or proxy_port == "":
self.proxy_address = ""
else:
# NOTE(review): string concatenation; presumably both values are
# configured as strings -- confirm.
self.proxy_address = proxy_ip + ":" + proxy_port
self.kv_cache_layer_num = 0
# ROUTER socket on which this rank receives requests from peers.
self.context = zmq.Context()
self.router_socket = self.context.socket(zmq.ROUTER)
self.router_socket.setsockopt(zmq.RCVHWM, 10000)
self.router_socket.setsockopt(zmq.SNDHWM, 5000)
self.router_socket.setsockopt(zmq.LINGER, 0)
self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
self.router_socket.setsockopt(zmq.TCP_KEEPALIVE, 1)
self.router_socket.bind(f"tcp://{self.zmq_address}")
self.poller = zmq.Poller()
self.poller.register(self.router_socket, zmq.POLLIN)
# request_id -> destinations of that request's KV cache.
self.req_status: dict[str, ReqKVDest] = {}
# One condition variable per shared structure.
self.req_status_cv = threading.Condition()
self.send_store_cv = threading.Condition()
self.send_queue_cv = threading.Condition()
self.recv_store_cv = threading.Condition()
self.pending_queue_cv = threading.Condition()
# Dedicated CUDA streams for sending and receiving.
self.send_stream = torch.cuda.Stream()
self.recv_stream = torch.cuda.Stream()
self.p2p_async_kv_tokens = envs.VLLM_P2P_BUF_TOKENS
self.p2p_async_buf = None
self.tensor_split_num: int = 0
mem_pool_size_gb = self.config.get_from_extra_config(
"mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB)
self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb) *
1024**3) # GB
# The sending type includes three mutually exclusive options:
# PUT, GET, PUT_ASYNC.
self.send_type = self.config.get_from_extra_config("send_type", "PUT")
if self.send_type == "GET":
# tensor_id: torch.Tensor
self.send_store: dict[str, torch.Tensor] = {}
else:
# PUT or PUT_ASYNC
# tensor_id: torch.Tensor
self.send_queue: deque[list[Any]] = deque()
# request_id -> list of [layer_name, tensor, tbo_evt] waiting for a
# destination.
self.pending_queue: dict[str, list[list[Any]]] = defaultdict(list)
self.requests_to_release: dict[str, bool] = {}
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
if self.send_type == "PUT_ASYNC":
self._send_thread = threading.Thread(target=self._send_async,
daemon=True)
self._send_thread.start()
# Background thread that runs _pending_check.
self._pending_check_thread = threading.Thread(target=self._pending_check,
daemon=True)
self._pending_check_thread.start()
# tensor_id: torch.Tensor/(addr, dtype, shape)
self.recv_store: dict[str, Any] = {}
self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {}
self.socks: dict[str, Any] = {} # remote_address: client socket
self.comms: dict[str, Any] = {} # remote_address: (ncclComm_t, rank)
# Running byte counter compared against buffer_size_threshold.
self.buffer_size = 0
self.buffer_size_threshold = float(self.config.kv_buffer_size)
self.nccl_num_channels = self.config.get_from_extra_config(
"nccl_num_channels", "8")
# Background thread serving the ROUTER socket.
self._listener_thread = threading.Thread(
target=self._listen_for_requests, daemon=True)
self._listener_thread.start()
self._ping_thread = None
# A ping thread is started only when a proxy address is configured; in
# multiple-machines mode only the worker with port_offset 0 starts it.
if self.multiple_machines:
if port_offset == 0 and self.proxy_address != "":
self._ping_thread = threading.Thread(target=self._ping,
daemon=True)
self._ping_thread.start()
else:
if self.proxy_address != "":
self._ping_thread = threading.Thread(target=self._ping_new,
daemon=True)
self._ping_thread.start()
logger.info(
"💯DuSwiftEngine init, rank:%d, local_rank:%d, http_address:%s, "
"zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_"
"threshold:%.2f, nccl_num_channels:%s", self.rank, self.local_rank,
self.http_address, self.zmq_address, self.proxy_address,
self.send_type, self.buffer_size_threshold, self.nccl_num_channels)
def _create_connect_new(self, remote_address: typing.Optional[str] = None):
assert remote_address is not None
if remote_address not in self.socks:
sock = self.context.socket(zmq.DEALER)
sock.setsockopt(zmq.SNDHWM, 10000)
sock.setsockopt(zmq.RCVHWM, 5000)
sock.setsockopt(zmq.LINGER, 0)
sock.setsockopt(zmq.TCP_KEEPALIVE, 1)
sock.setsockopt_string(zmq.IDENTITY, f"P-{self.zmq_address}")
sock.connect(f"tcp://{remote_address}")
self.socks[remote_address] = sock
return self.socks[remote_address]
def _create_connect(self, remote_address: typing.Optional[str] = None):
assert remote_address is not None
if remote_address not in self.socks:
sock = self.context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
sock.connect(f"tcp://{remote_address}")
self.socks[remote_address] = sock
if remote_address in self.comms:
logger.info("👋comm exists, remote_address:%s, comms:%s",
remote_address, self.comms)
return sock, self.comms[remote_address]
unique_id = self.nccl.ncclGetUniqueId()
data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)}
sock.send(msgpack.dumps(data))
with torch.cuda.device(self.device):
rank = 0
with set_du_swift_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(
2, unique_id, rank)
self.comms[remote_address] = (comm, rank)
logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s",
self.zmq_address, remote_address, rank)
return self.socks[remote_address], self.comms[remote_address]
def get_send_queue_items(self, request_id: str, layer_name: str,
                         tensor: torch.Tensor,
                         is_mla: bool) -> list[Any]:
    """Build the send-queue entries for one layer tensor of a request.

    Every entry is ``(tensor_id, RemoteAddr, tensor)``. The remote and
    local addresses are parsed out of ``request_id``.

    * Symmetric layout (``enable_asymmetric_p2p`` is false): exactly one
      entry, addressed to the remote port offset by this rank.
    * Asymmetric layout: one entry for every remote TP rank
      ``tp_rank + k * tp_size`` (``k`` in ``range(multp)``) that is below
      ``remote_tp_size``, on the remote PP rank returned by
      ``compute_remote_pp_rank(layer_name)``.

    Args:
        request_id: request id containing the prefill and decode addresses.
        layer_name: name of the layer the tensor belongs to.
        tensor: the KV tensor to send.
        is_mla: whether the model uses MLA; an error is logged for
            asymmetric layouts when it does not.

    Returns:
        The list of entries (a one-element list of a tuple in the
        symmetric case, a list of 3-element lists otherwise).
    """
    tensor_id = self.get_tensor_id(request_id, layer_name)
    remote_ip, remote_port = self.parse_request_id(request_id, True)
    p_ip, p_port = self.parse_request_id(request_id, False)
    pd_pair_id = p_ip + ":" + str(p_port) + "_" + remote_ip + ":" + str(
        remote_port)
    # Remote communicator ranks are numbered after all local ranks.
    local_world_size = self.pp_size * self.tp_size

    if not self.enable_asymmetric_p2p:
        remote_address = remote_ip + ":" + str(remote_port + self.rank)
        remote_addr = RemoteAddr(pd_pair_id, remote_address,
                                 self.rank + local_world_size)
        return [(tensor_id, remote_addr, tensor)]

    if not is_mla:
        logger.error(" DuSwift only support mla model symmetric PP/TP!!!!")

    remote_pp_rank = self.compute_remote_pp_rank(layer_name)
    items: list[Any] = []
    # Enumerate the matching remote TP ranks directly, in ascending order,
    # instead of scanning every (remote rank, multiplier) combination.
    for mul_tp in range(self.multp):
        d_tp_rank = self.tp_rank + mul_tp * self.tp_size
        if d_tp_rank >= self.remote_tp_size:
            continue
        remote_port_offset = (remote_pp_rank * self.remote_tp_size +
                              d_tp_rank)
        remote_comm_rank = remote_port_offset + local_world_size
        remote_address = remote_ip + ":" + str(remote_port +
                                               remote_port_offset)
        remote_addr = RemoteAddr(pd_pair_id, remote_address,
                                 remote_comm_rank)
        logger.debug(
            "Wait to send::%s, tensor_shape:%s, "
            "(pp=%d, tp=%d) -> remote_address=%s(pp=%d, tp=%d) comm_rank (%d -> %d)",
            tensor_id, tensor.shape, self.pp_rank, self.tp_rank,
            remote_address, remote_pp_rank, d_tp_rank, self.rank,
            remote_comm_rank)
        items.append([tensor_id, remote_addr, tensor])
    return items
def send_tensor_new(
    self,
    request_id: str,
    layer_name: str,
    tensor: torch.Tensor,
    is_mla: bool = False,
) -> bool:
    """Send one layer tensor to every destination of ``request_id``.

    The destinations come from ``get_send_queue_items``.

    * ``PUT``: each item is sent synchronously.
    * ``PUT_ASYNC``: all items are appended to the send queue and the
      queue's condition variable is notified.
    * ``GET`` (or any other value): not supported by this entry point.

    Args:
        request_id: id of the request the tensor belongs to.
        layer_name: name of the layer the tensor belongs to.
        tensor: the KV tensor to send.
        is_mla: forwarded to ``get_send_queue_items``.

    Returns:
        ``True`` when every synchronous send succeeded or the items were
        queued; ``False`` when a synchronous send failed or the send type
        is not supported (previously the unsupported case fell through and
        returned ``None`` despite the ``bool`` annotation).
    """
    if self.send_type == "PUT":
        return all(
            self._send_sync_new(item)
            for item in self.get_send_queue_items(request_id, layer_name,
                                                  tensor, is_mla))

    if self.send_type == "PUT_ASYNC":
        with self.send_queue_cv:
            self.send_queue.extend(
                self.get_send_queue_items(request_id, layer_name, tensor,
                                          is_mla))
            self.send_queue_cv.notify()
        return True

    if self.send_type == "GET":
        logger.error(" DuSwift new not support GET model, please set VLLM_P2PNCCL_NEW=0 use defalut model!!!!")
    return False
def send_tensor(
    self,
    tensor_id: str,
    tensor: torch.Tensor,
    remote_address: typing.Optional["RemoteAddr"] = None,
    tbo_evt=None,
) -> bool:
    """Send (or stage) one tensor according to the configured send type.

    * ``remote_address is None``: the tensor is placed directly into the
      local ``recv_store``.
    * ``PUT``: sent synchronously via ``_send_sync``.
    * ``PUT_ASYNC``: ``[tensor_id, remote_address, tensor]`` is appended to
      the send queue.
    * otherwise (``GET``): the tensor is kept in ``send_store``; the oldest
      stored tensors are evicted while the byte counter would exceed
      ``buffer_size_threshold``.

    Args:
        tensor_id: identifier under which the tensor is sent / stored.
        tensor: the tensor to send.
        remote_address: destination, or ``None`` for a local hand-over.
        tbo_evt: accepted for signature parity with
            ``p2p_async_send_tensor``; unused here.

    Returns:
        The result of ``_send_sync`` for ``PUT``; ``False`` when a GET
        tensor is larger than the whole buffer threshold and is therefore
        not stored; ``True`` otherwise.
    """
    if remote_address is None:
        with self.recv_store_cv:
            self.recv_store[tensor_id] = tensor
            self.recv_store_cv.notify()
        return True

    if self.send_type == "PUT":
        return self._send_sync(tensor_id, tensor, remote_address)

    if self.send_type == "PUT_ASYNC":
        with self.send_queue_cv:
            self.send_queue.append([tensor_id, remote_address, tensor])
            self.send_queue_cv.notify()
        return True

    # GET
    with self.send_store_cv:
        tensor_size = tensor.element_size() * tensor.numel()
        if tensor_size > self.buffer_size_threshold:
            # No amount of eviction can make room; the old loop emptied the
            # store and then failed on next(iter({})).
            logger.warning(
                "❗[GET]tensor_id:%s, tensor_size:%d is greater than "
                "buffer_size_threshold:%d, skip send to %s, rank:%d",
                tensor_id, tensor_size, self.buffer_size_threshold,
                remote_address.zmq_address, self.rank)
            return False
        # Stop evicting once the store is empty: the byte counter is also
        # updated by the receive path, so it can exceed the threshold
        # while nothing is left to evict.
        while (self.send_store and self.buffer_size + tensor_size
               > self.buffer_size_threshold):
            oldest_tensor_id = next(iter(self.send_store))
            oldest_tensor = self.send_store.pop(oldest_tensor_id)
            oldest_tensor_size = (oldest_tensor.element_size() *
                                  oldest_tensor.numel())
            self.buffer_size -= oldest_tensor_size
            logger.info(
                "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
                " buffer_size:%d, oldest_tenser_size:%d, rank:%d",
                remote_address.zmq_address, tensor_id, tensor_size,
                self.buffer_size, oldest_tensor_size, self.rank)
        self.send_store[tensor_id] = tensor
        self.buffer_size += tensor_size
        logger.debug(
            "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
            "shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
            remote_address.zmq_address, tensor_id, tensor_size, tensor.shape,
            self.rank, self.buffer_size,
            self.buffer_size / self.buffer_size_threshold * 100)
    return True
def p2p_async_send_tensor(
    self,
    tensor_id: str,
    tensor: torch.Tensor,
    remote_address: typing.Optional[str] = None,
    tbo_evt=None,
) -> bool:
    """Send (or stage) a paged KV layer for the async P2P path.

    * ``remote_address is None``: ``tensor`` is placed directly into the
      local ``recv_store``.
    * ``PUT``: sent synchronously via ``_send_sync``.
    * ``PUT_ASYNC``: ``tensor`` is unpacked as ``(kv_layer, slot_mapping)``
      and ``[tensor_id, remote_address, kv_layer, slot_mapping, tbo_evt]``
      is appended to the send queue.
    * otherwise (``GET``): the tensor is kept in ``send_store``; the oldest
      stored tensors are evicted while the byte counter would exceed
      ``buffer_size_threshold``.

    Args:
        tensor_id: identifier under which the tensor is sent / stored.
        tensor: the payload; a ``(kv_layer, slot_mapping)`` pair on the
            ``PUT_ASYNC`` path.
        remote_address: destination, or ``None`` for a local hand-over.
        tbo_evt: event queued together with the payload on the
            ``PUT_ASYNC`` path.

    Returns:
        The result of ``_send_sync`` for ``PUT``; ``False`` when a GET
        tensor is larger than the whole buffer threshold and is therefore
        not stored; ``True`` otherwise.
    """
    if remote_address is None:
        with self.recv_store_cv:
            self.recv_store[tensor_id] = tensor
            self.recv_store_cv.notify()
        return True

    if self.send_type == "PUT":
        return self._send_sync(tensor_id, tensor, remote_address)

    if self.send_type == "PUT_ASYNC":
        with self.send_queue_cv:
            kv_layer, slot_mapping = tensor  # (kv_layer, slot_mapping)
            self.send_queue.append(
                [tensor_id, remote_address, kv_layer, slot_mapping, tbo_evt])
            self.send_queue_cv.notify()
        return True

    # GET
    with self.send_store_cv:
        tensor_size = tensor.element_size() * tensor.numel()
        if tensor_size > self.buffer_size_threshold:
            # No amount of eviction can make room; the old loop emptied the
            # store and then failed on next(iter({})).
            logger.warning(
                "❗[GET]tensor_id:%s, tensor_size:%d is greater than "
                "buffer_size_threshold:%d, skip send to %s, rank:%d",
                tensor_id, tensor_size, self.buffer_size_threshold,
                remote_address, self.rank)
            return False
        # Stop evicting once the store is empty: the byte counter is also
        # updated by the receive path, so it can exceed the threshold
        # while nothing is left to evict.
        while (self.send_store and self.buffer_size + tensor_size
               > self.buffer_size_threshold):
            oldest_tensor_id = next(iter(self.send_store))
            oldest_tensor = self.send_store.pop(oldest_tensor_id)
            oldest_tensor_size = (oldest_tensor.element_size() *
                                  oldest_tensor.numel())
            self.buffer_size -= oldest_tensor_size
            logger.info(
                "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
                " buffer_size:%d, oldest_tenser_size:%d, rank:%d",
                remote_address, tensor_id, tensor_size,
                self.buffer_size, oldest_tensor_size, self.rank)
        self.send_store[tensor_id] = tensor
        self.buffer_size += tensor_size
        logger.debug(
            "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
            "shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
            remote_address, tensor_id, tensor_size, tensor.shape,
            self.rank, self.buffer_size,
            self.buffer_size / self.buffer_size_threshold * 100)
    return True
def pending_tensor(
    self,
    reuqest_id: str,
    layer_name: str,
    tensor: torch.Tensor,
    tbo_evt=None,
) -> bool:
    """Park a layer tensor until the request's destination is known.

    Appends ``[layer_name, tensor, tbo_evt]`` to the pending queue entry of
    the request and notifies the pending-queue condition variable.

    Args:
        reuqest_id: id of the request the tensor belongs to.
        layer_name: name of the layer the tensor belongs to.
        tensor: the tensor to park.
        tbo_evt: optional event stored alongside the tensor.

    Returns:
        Always ``True``.
    """
    parked_entry = [layer_name, tensor, tbo_evt]
    with self.pending_queue_cv:
        waiting = self.pending_queue[reuqest_id]
        waiting.append(parked_entry)
        self.pending_queue_cv.notify()
    return True
def unpending_tensor(
    self,
    request_id: str,
    req_data: ReqKVDest,
) -> bool:
    """Send every tensor that was parked for ``request_id``.

    The request's pending entry is removed, and if the request is tracked
    in ``requests_to_release`` its flag is set to ``True``. Each parked
    ``(layer_name, tensor, tbo_evt)`` is then sent to every destination in
    ``req_data``: through ``p2p_async_send_tensor`` when TBO / async P2P is
    enabled and an event was stored, through ``send_tensor`` otherwise.

    Args:
        request_id: id of the request to release.
        req_data: destinations of the request's KV cache.

    Returns:
        ``False`` when ``req_data`` lists no destination, ``True`` after
        all parked tensors have been handed over.
    """
    with self.pending_queue_cv:
        parked = self.pending_queue.pop(request_id)
        if request_id in self.requests_to_release:
            self.requests_to_release[request_id] = True
    logger.info("[%d] unpending request: %s", self.rank, request_id)

    if req_data.dst_num <= 0:
        return False

    for layer_name, tensor, tbo_evt in parked:
        tensor_id = request_id + "#" + layer_name
        for dest_idx in range(req_data.dst_num):
            destination = req_data.zmq_address_and_comm_rank[dest_idx]
            remote_addr = RemoteAddr(req_data.pd_pair_id, *destination)
            async_path = ((envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC)
                          and tbo_evt is not None)
            if async_path:
                self.p2p_async_send_tensor(tensor_id, tensor, remote_addr,
                                           tbo_evt)
            else:
                self.send_tensor(tensor_id, tensor, remote_addr)
    return True
def _pending_check(self) :
while True:
with self.pending_queue_cv:
while not self.pending_queue:
self.pending_queue_cv.wait()
pending_queue = self.pending_queue.copy()
for request_id in pending_queue:
with self.req_status_cv:
if request_id not in self.req_status:
continue
req_data = self.req_status[request_id]
assert(len(req_data.zmq_address_and_comm_rank) == req_data.dst_num)
self.unpending_tensor(request_id, req_data)
def recv_tensor(
    self,
    tensor_id: str,
    remote_address: typing.Optional[str] = None,
) -> torch.Tensor:
    """Obtain the tensor registered under ``tensor_id``.

    * ``PUT`` / ``PUT_ASYNC``: block until the listener has placed an entry
      for ``tensor_id`` in ``recv_store``. A ``(addr, dtype, shape)`` entry
      is loaded back from the memory pool; a plain tensor reduces the
      buffer byte counter by its size; a ``None`` entry is logged and
      returned as ``None``. The entry itself is left in ``recv_store``.
    * ``GET``: request the tensor from ``remote_address`` over ZMQ and
      receive its payload through the peer's NCCL communicator.

    Args:
        tensor_id: identifier of the tensor to receive.
        remote_address: peer to fetch from (``GET`` only).

    Returns:
        The tensor, or ``None`` when it is unavailable (failed PUT, missing
        ``remote_address`` for GET, or a non-zero reply code).
    """
    if self.send_type in ("PUT", "PUT_ASYNC"):
        wait_started = time.time()
        with self.recv_store_cv:
            while tensor_id not in self.recv_store:
                self.recv_store_cv.wait()
            received = self.recv_store[tensor_id]

        if received is None:
            elapsed = time.time() - wait_started
            logger.warning(
                "🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, "
                "rank:%d", remote_address, tensor_id, elapsed * 1000,
                self.rank)
        elif isinstance(received, tuple):
            # The payload was moved to the memory pool; bring it back.
            addr, dtype, shape = received
            received = self.pool.load_tensor(addr, dtype, shape,
                                             self.device)
        else:
            self.buffer_size -= (received.element_size() *
                                 received.numel())
        return received

    # GET
    if remote_address is None:
        return None

    if remote_address not in self.socks:
        self._create_connect(remote_address)
    sock = self.socks[remote_address]
    comm, rank = self.comms[remote_address]

    sock.send(msgpack.dumps({"cmd": "GET", "tensor_id": tensor_id}))
    reply = msgpack.loads(sock.recv())
    if reply["ret"] != 0:
        logger.warning("🔴[GET]Recv From %s, tensor_id: %s, ret: %d",
                       remote_address, tensor_id, reply["ret"])
        return None

    tensor = torch.empty(reply["shape"],
                         dtype=getattr(torch, reply["dtype"]),
                         device=self.device)
    # The peer is the other rank of the two-rank communicator.
    self._recv(comm, tensor, rank ^ 1, self.recv_stream)
    return tensor
def _listen_for_requests(self):
    """Serve messages arriving on ``self.router_socket``; never returns.

    Every message is read as a two-frame multipart
    ``[remote_address, msgpack payload]`` and dispatched on ``data["cmd"]``:

    * ``"NEW"``: only logged as unexpected.
    * ``"PUT"``: allocate a buffer, acknowledge with ``b"0"``, receive the
      tensor over the communicator keyed by the sender address, then publish
      it in ``self.recv_store``.
    * ``"PUT_NEW"``: same as ``"PUT"`` but the communicator is looked up by
      ``data["pd_pair_id"]`` and the source rank is ``data["comm_rank"]``.
    * ``"comm_init"``: create an NCCL communicator and register it under
      ``data["pd_pair_id"]``.
    * ``"req_to_transfer"`` / ``"req_not_to_transfer"``: record a
      ``ReqKVDest`` for the request in ``self.req_status`` and wake waiters.
    * ``"GET"``: reply with shape/dtype of a tensor in ``self.send_store``
      and, if present, send it back to the requester.
    * anything else: logged as a warning.

    NOTE(review): the original indentation was lost; the nesting below is a
    reconstruction - confirm against the real file.
    """
    while True:
        # Poll with a timeout argument of 5000 so the loop wakes periodically.
        socks = dict(self.poller.poll(5000))
        if self.router_socket in socks:
            remote_address, message = self.router_socket.recv_multipart()
            data = msgpack.loads(message)
            if data["cmd"] == "NEW":
                logger.info(f"unexpected message from {remote_address.decode()}")
            elif data["cmd"] == "PUT":
                tensor_id = data["tensor_id"]
                if "tensor_split_num" in data:
                    self.tensor_split_num = data["tensor_split_num"]
                try:
                    # Allocate the landing buffer on the receive stream.
                    with torch.cuda.stream(self.recv_stream):
                        tensor = torch.empty(data["shape"],
                                             dtype=getattr(
                                                 torch, data["dtype"]),
                                             device=self.device)
                    # Ack b"0": the sender only proceeds to send after this.
                    self.router_socket.send_multipart(
                        [remote_address, b"0"])
                    comm, rank = self.comms[remote_address.decode()]
                    # rank ^ 1 flips the low bit of our rank; presumably the
                    # peer in a two-rank communicator - TODO confirm.
                    self._recv(comm, tensor, rank ^ 1, self.recv_stream)
                    tensor_size = tensor.element_size() * tensor.numel()
                    if (self.buffer_size + tensor_size
                            > self.buffer_size_threshold):
                        # Store Tensor in memory pool
                        addr = self.pool.store_tensor(tensor)
                        tensor = (addr, tensor.dtype, tensor.shape)
                    else:
                        self.buffer_size += tensor_size
                except torch.cuda.OutOfMemoryError:
                    # Reply b"1" and publish None so waiters are not blocked.
                    self.router_socket.send_multipart(
                        [remote_address, b"1"])
                    tensor = None
                    logger.warning(
                        "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
                        "data:%s", self.zmq_address,
                        remote_address.decode(), data)
                with self.recv_store_cv:
                    self.recv_store[tensor_id] = tensor
                    self._have_received_tensor_id(tensor_id)
                    self.recv_store_cv.notify()
            elif data["cmd"] == "PUT_NEW":
                tensor_id = data["tensor_id"]
                if "tensor_split_num" in data:
                    self.tensor_split_num = data["tensor_split_num"]
                try:
                    with torch.cuda.stream(self.recv_stream):
                        tensor = torch.empty(data["shape"],
                                             dtype=getattr(
                                                 torch, data["dtype"]),
                                             device=self.device)
                    self.router_socket.send_multipart(
                        [remote_address, b"0"])
                    # comm, rank = self.comms[remote_address.decode()]
                    # self._recv(comm, tensor, rank ^ 1, self.recv_stream)
                    # Communicator is keyed by pd_pair_id here, and the source
                    # rank is the one the sender put in the message.
                    comm, rank = self.comms[data["pd_pair_id"]]
                    self._recv(comm, tensor, int(data["comm_rank"]), self.recv_stream)
                    tensor_size = tensor.element_size() * tensor.numel()
                    if (self.buffer_size + tensor_size
                            > self.buffer_size_threshold):
                        # Store Tensor in memory pool
                        addr = self.pool.store_tensor(tensor)
                        tensor = (addr, tensor.dtype, tensor.shape)
                    else:
                        self.buffer_size += tensor_size
                except torch.cuda.OutOfMemoryError:
                    self.router_socket.send_multipart(
                        [remote_address, b"1"])
                    tensor = None
                    logger.warning(
                        "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
                        "data:%s", self.zmq_address,
                        remote_address.decode(), data)
                with self.recv_store_cv:
                    self.recv_store[tensor_id] = tensor
                    self._have_received_tensor_id(tensor_id)
                    self.recv_store_cv.notify()
            elif data["cmd"] == "comm_init":
                unique_id = self.nccl.unique_id_from_bytes(
                    bytes(data["unique_id"]))
                with torch.cuda.device(self.device):
                    # Rank and world size are dictated by the message.
                    rank = int(data["rank"])
                    world_size = int(data["world_size"])
                    with set_du_swift_context(self.nccl_num_channels):
                        comm: ncclComm_t = self.nccl.ncclCommInitRank(
                            world_size, unique_id, rank)
                    self.comms[data["pd_pair_id"]] = (comm, rank)
                    logger.info(
                        "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
                        self.zmq_address, data["pd_pair_id"], rank)
            elif data["cmd"] == "req_to_transfer":
                with self.req_status_cv:
                    # A request is expected to be announced only once.
                    assert(data["request_id"] not in self.req_status)
                    self.req_status[data["request_id"]] = ReqKVDest(dst_num=int(data["dst_num"]), pd_pair_id=data["pd_pair_id"], zmq_address_and_comm_rank=list(zip(data["remote_address"], data["remote_rank"])))
                    self.req_status_cv.notify_all()
            elif data["cmd"] == "req_not_to_transfer":
                with self.req_status_cv:
                    # dst_num=0 records that this request has no destination.
                    self.req_status[data["request_id"]] = ReqKVDest(dst_num=0)
                    self.req_status_cv.notify_all()
            elif data["cmd"] == "GET":
                tensor_id = data["tensor_id"]
                with self.send_store_cv:
                    tensor = self.send_store.pop(tensor_id, None)
                    if tensor is not None:
                        data = {
                            "ret": 0,
                            "shape": tensor.shape,
                            "dtype":
                            str(tensor.dtype).replace("torch.", "")
                        }
                        # LRU
                        # Pop + re-insert moves the entry to the end of the
                        # store's insertion order.
                        self.send_store[tensor_id] = tensor
                        self._have_sent_tensor_id(tensor_id)
                    else:
                        data = {"ret": 1}
                self.router_socket.send_multipart(
                    [remote_address, msgpack.dumps(data)])
                if data["ret"] == 0:
                    comm, rank = self.comms[remote_address.decode()]
                    self._send(comm, tensor.to(self.device), rank ^ 1,
                               self.send_stream)
            else:
                logger.warning(
                    "🚧Unexpected, Received message from %s, data:%s",
                    remote_address, data)
def _have_sent_tensor_id(self, tensor_id: str):
request_id = tensor_id.split('#')[0]
if request_id not in self.send_request_id_to_tensor_ids:
self.send_request_id_to_tensor_ids[request_id] = set()
self.send_request_id_to_tensor_ids[request_id].add(tensor_id)
def _have_received_tensor_id(self, tensor_id: str):
request_id = tensor_id.split('#')[0]
if request_id not in self.recv_request_id_to_tensor_ids:
self.recv_request_id_to_tensor_ids[request_id] = set()
self.recv_request_id_to_tensor_ids[request_id].add(tensor_id)
def _send_async(self):
while True:
with self.send_queue_cv:
while not self.send_queue:
self.send_queue_cv.wait()
if envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC:
tensor_id, remote_address, kv_layer, slot_mapping, tbo_evt = self.send_queue.popleft()
else:
tensor_id, remote_address, tensor = self.send_queue.popleft()
if not self.send_queue:
self.send_queue_cv.notify()
if (envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC) and tbo_evt is not None:
self.send_stream.wait_event(tbo_evt)
self._send_kv_p2p_sync(tensor_id, kv_layer, slot_mapping, remote_address)
else:
if self.multiple_machines:
self._send_sync(tensor_id, tensor, remote_address)
else:
# logger.info(f"""=============xiabo tensor_id:{tensor_id} remote_address:{remote_address}""")
self._send_sync_new(tensor_id, tensor, remote_address)
def wait_for_sent(self):
    """Block until ``self.send_queue`` is empty when in "PUT_ASYNC" mode.

    Does nothing for any other send type. The time spent waiting is logged
    at debug level.
    """
    if self.send_type != "PUT_ASYNC":
        return

    wait_began = time.time()
    with self.send_queue_cv:
        while self.send_queue:
            self.send_queue_cv.wait()
    elapsed = time.time() - wait_began
    logger.debug(
        "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
        " to be empty, rank:%d", elapsed * 1000, self.rank)
def _send_kv_p2p_sync(self, tensor_id: str, kv_layer: torch.Tensor,
                      slot_mapping: torch.Tensor, remote_address: str) -> bool:
    """Gather the KV rows named by ``slot_mapping`` and send them in packs.

    The rows are copied into ``self.p2p_async_buf`` in chunks of at most
    ``self.p2p_async_kv_tokens`` entries. For each chunk a "PUT" header is
    sent over the peer's socket, and after the peer answers ``b"0"`` the
    chunk is sent with ``self._send``.

    Returns ``False`` as soon as the peer answers anything other than
    ``b"0"``, otherwise ``True``.

    NOTE(review): the original indentation was lost. The placement of the
    final ``_have_sent_tensor_id`` call (after the loop) is a reconstruction
    - confirm against the real file.
    NOTE(review): ``self.p2p_async_buf`` is allocated once from the first
    layer's dtype/hidden size and reused afterwards; this presumably assumes
    all layers share them - verify.
    """
    if remote_address not in self.socks:
        self._create_connect(remote_address)
    sock = self.socks[remote_address]
    comm, rank = self.comms[remote_address]
    # A 3-D cache is treated as the MLA layout; otherwise the leading
    # dimension of size 2 is indexed (presumably K and V - TODO confirm).
    is_mla = (kv_layer.ndim == 3)
    hidden_dim = kv_layer.shape[-1]
    if self.p2p_async_buf is None:
        # Lazily allocate the staging buffer on first use.
        if is_mla:
            self.p2p_async_buf = torch.empty((self.p2p_async_kv_tokens, hidden_dim),
                                             dtype=kv_layer.dtype, device=kv_layer.device)
        else:
            self.p2p_async_buf = torch.empty((2, self.p2p_async_kv_tokens, hidden_dim),
                                             dtype=kv_layer.dtype, device=kv_layer.device)
    # Ceiling division: number of chunks needed to cover slot_mapping.
    pack_num = (slot_mapping.shape[0] - 1) // self.p2p_async_kv_tokens + 1
    self.tensor_split_num = pack_num
    with torch.cuda.stream(self.send_stream):
        for pack_idx in range(pack_num):
            start = pack_idx * self.p2p_async_kv_tokens
            end = min((pack_idx + 1) * self.p2p_async_kv_tokens, slot_mapping.shape[0])
            sub_index = slot_mapping[start:end]
            if is_mla:
                # Flatten (pages, page_size, ...) to one row per slot.
                num_pages, page_size = kv_layer.shape[0], kv_layer.shape[1]
                data = kv_layer.reshape(num_pages * page_size, -1)
                torch.index_select(data, dim=0, index=sub_index, out=self.p2p_async_buf[:end-start])
                tx_shape = (end - start, hidden_dim)
            else:
                num_pages, page_size = kv_layer.shape[1], kv_layer.shape[2]
                data = kv_layer.reshape(2, num_pages * page_size, -1)
                torch.index_select(data, dim=1, index=sub_index, out=self.p2p_async_buf[:, :end-start])
                tx_shape = (2, end - start, hidden_dim)
            # View of the filled part of the staging buffer for this chunk.
            if is_mla:
                send_tensor = self.p2p_async_buf[:end-start]
            else:
                send_tensor = self.p2p_async_buf[:, :end-start]
            header = {
                "cmd": "PUT",
                "tensor_id": tensor_id + "#" + str(pack_idx), # append pack_idx
                "pack_idx": pack_idx,
                "tensor_split_num": pack_num,
                "shape": tx_shape,
                "dtype": str(kv_layer.dtype).replace("torch.", "")
            }
            sock.send(msgpack.dumps(header))
            response = sock.recv()
            # Anything other than b"0" means the peer refused this chunk.
            if response != b"0":
                logger.error(
                    "🔴Send Tensor Failed | %s 👉 %s | Rank:%s | shape:%s | size:%.4f GB | response:%s",
                    self.zmq_address, remote_address, rank,
                    tuple(send_tensor.shape), send_tensor.element_size() * send_tensor.numel() / 1024**3,
                    response.decode()
                )
                return False
            # _send synchronizes the stream, so the staging buffer can be
            # overwritten by the next chunk.
            self._send(comm, send_tensor, rank ^ 1, self.send_stream)
    if self.send_type == "PUT_ASYNC":
        self._have_sent_tensor_id(tensor_id)
    return True
def _send_sync_new(
    self,
    tensor_id: str,
    tensor: torch.Tensor,
    remote_address: typing.Optional[RemoteAddr] = None,
) -> bool:
    """Send one tensor to a peer using the "PUT_NEW" handshake.

    A header carrying the tensor id, shape, dtype, ``pd_pair_id`` and this
    side's communicator rank is sent to ``remote_address.zmq_address``.
    When the peer answers ``b"0"`` the tensor is sent over the communicator
    registered under ``remote_address.pd_pair_id`` to rank
    ``remote_address.comm_rank``.

    Returns ``False`` when no address is given or the peer answers anything
    other than ``b"0"``; ``True`` after a successful send.
    """
    if remote_address is None:
        return False

    peer_zmq = remote_address.zmq_address
    if peer_zmq not in self.socks:
        self._create_connect_new(peer_zmq)
    sock = self.socks[peer_zmq]
    comm, rank = self.comms[remote_address.pd_pair_id]

    dtype_name = str(tensor.dtype).replace("torch.", "")
    data = {
        "cmd": "PUT_NEW",
        "tensor_id": tensor_id,
        "shape": tensor.shape,
        "dtype": dtype_name,
        "pd_pair_id": remote_address.pd_pair_id,
        "comm_rank": rank
    }
    logger.info(f"""_send_sync_new:{data}""")
    sock.send(msgpack.dumps(data))

    response = sock.recv()
    if response == b"0":
        self._send(comm, tensor.to(self.device), remote_address.comm_rank, self.send_stream)
        if self.send_type == "PUT_ASYNC":
            self._have_sent_tensor_id(tensor_id)
        return True

    # Peer refused the transfer.
    logger.error(
        "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
        "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
        self.zmq_address, remote_address.zmq_address, rank, data, tensor.shape,
        tensor.element_size() * tensor.numel() / 1024**3,
        response.decode())
    return False
def _send_sync(
self,
tensor_id: str,
tensor: torch.Tensor,
remote_address: typing.Optional[str] = None,
) -> bool:
if remote_address is None:
return False
if remote_address not in self.socks:
self._create_connect(remote_address)
sock = self.socks[remote_address]
comm, rank = self.comms[remote_address]
data = {
"cmd": "PUT",
"tensor_id": tensor_id,
"shape": tensor.shape,
"dtype": str(tensor.dtype).replace("torch.", "")
}
sock.send(msgpack.dumps(data))
response = sock.recv()
if response != b"0":
logger.error(
"🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
"MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
self.zmq_address, remote_address, rank, data, tensor.shape,
tensor.element_size() * tensor.numel() / 1024**3,
response.decode())
return False
self._send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)
if self.send_type == "PUT_ASYNC":
self._have_sent_tensor_id(tensor_id)
return True
def get_finished(
    self, finished_req_ids: set[str], forward_context: "ForwardContext"
) -> tuple[Optional[set[str]], Optional[set[str]]]:
    """
    Notifies worker-side connector ids of requests that have
    finished generating tokens.

    Buffers of finished requests are released unless the request is still
    in ``self.pending_queue``; such requests are parked in
    ``self.requests_to_release`` (flag ``False``) and released by a later
    call once their flag has been set to true elsewhere.

    Args:
        finished_req_ids: ids of requests that finished generating.
        forward_context: only ``no_compile_layers`` is read, to count the
            layers that expose a ``kv_cache`` attribute (done once, while
            ``self.kv_cache_layer_num`` is still 0).

    Returns:
        ids of requests that have finished asynchronous transfer,
        tuple of (sending/saving ids, recving/loading ids).
        The finished saves/sends req ids must belong to a set provided in a
        call to this method (this call or a prior one).
        An empty set is reported as ``None``. The sending side is currently
        always ``None``.
    """
    # Clear the buffer upon request completion.
    requests_to_release: list[str] = []
    with self.pending_queue_cv:
        # Iterate over a snapshot: entries are removed inside the loop, and
        # popping from a dict while iterating its live view raises
        # "RuntimeError: dictionary changed size during iteration".
        for request_id, release in list(self.requests_to_release.items()):
            if release:
                requests_to_release.append(request_id)
                self.requests_to_release.pop(request_id)

    for request_id in finished_req_ids:
        with self.pending_queue_cv:
            if request_id in self.pending_queue:
                # Still pending: defer the release to a later call.
                self.requests_to_release[request_id] = False
                logger.info("[%d] pending request: %s", self.rank, request_id)
                continue
        requests_to_release.append(request_id)

    for request_id in requests_to_release:
        ids = self.recv_request_id_to_tensor_ids.pop(request_id, set())
        with self.recv_store_cv:
            for tensor_id in ids:
                tensor = self.recv_store.pop(tensor_id, None)
                # Tuple entries live in the memory pool and must be freed.
                if isinstance(tensor, tuple):
                    addr, _, _ = tensor
                    self.pool.free(addr)
        self.send_request_id_to_tensor_ids.pop(request_id, None)

    # TODO:Retrieve requests that have already sent the KV cache.
    finished_sending: set[str] = set()
    # TODO:Retrieve requests that have already received the KV cache.
    finished_recving: set[str] = set()

    if self.kv_cache_layer_num == 0:
        # Count the layers that carry a KV cache.
        for layer_name in forward_context.no_compile_layers:
            layer = forward_context.no_compile_layers[layer_name]
            kv_cache = getattr(layer, 'kv_cache', None)
            if kv_cache is None:
                continue
            self.kv_cache_layer_num += 1

    with self.recv_store_cv:
        # A request counts as received once one tensor id per KV layer has
        # been recorded for it.
        for req in self.recv_request_id_to_tensor_ids:
            if len(self.recv_request_id_to_tensor_ids[req]) == self.kv_cache_layer_num:
                finished_recving.add(req)

    return finished_sending or None, finished_recving or None
def _ping(self):
    """Announce this instance to the proxy every 3 seconds; never returns.

    Opens a DEALER socket whose identity is ``self.zmq_address``, connects
    it to ``self.proxy_address`` and keeps sending a msgpack message with
    the instance type ("P" for a KV producer, otherwise "D") and its HTTP
    and ZMQ addresses.
    """
    proxy_sock = self.context.socket(zmq.DEALER)
    proxy_sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
    logger.debug("ping start, zmq_address:%s", self.zmq_address)
    proxy_sock.connect(f"tcp://{self.proxy_address}")

    role = "P" if self.config.is_kv_producer else "D"
    announcement = {
        "type": role,
        "http_address": self.http_address,
        "zmq_address": self.zmq_address
    }
    while True:
        proxy_sock.send(msgpack.dumps(announcement))
        time.sleep(3)
def _ping_new(self):
    """Register this rank with the proxy once.

    Opens a DEALER socket whose identity is ``self.zmq_address`` and
    connects it to ``self.proxy_address``. Rank 0 first sends an
    "P_init"/"D_init" message carrying the dp/pp/tp sizes; every rank then
    sends one "P"/"D" message carrying its own dp/pp/tp rank and addresses.
    """
    proxy_sock = self.context.socket(zmq.DEALER)
    proxy_sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
    logger.debug("ping start, zmq_address:%s", self.zmq_address)
    proxy_sock.connect(f"tcp://{self.proxy_address}")

    is_producer = self.config.is_kv_producer

    if self.rank == 0:
        # Only rank 0 describes the parallel layout of the whole instance.
        layout = {
            "type": "P_init" if is_producer else "D_init",
            "http_address": self.http_address,
            "zmq_address": self.zmq_address,
            "dp_size" : self.dp_size,
            "pp_size" : self.pp_size,
            "tp_size" : self.tp_size
        }
        proxy_sock.send(msgpack.dumps(layout))

    # Every rank reports its own coordinates and ZMQ address.
    registration = {
        "type": "P" if is_producer else "D",
        "http_address": self.http_address,
        "dp_rank" : self.dp_rank,
        "pp_rank" : self.pp_rank,
        "tp_rank" : self.tp_rank,
        "zmq_address": self.zmq_address
    }
    proxy_sock.send(msgpack.dumps(registration))
def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
    """Send ``tensor`` to rank ``dst`` over ``comm`` and wait for the stream.

    ``tensor`` must live on ``self.device``. When ``stream`` is ``None`` the
    stream returned by ``current_stream()`` is used. The stream is
    synchronized before returning.
    """
    assert tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the input tensor is on {tensor.device}")

    active_stream = current_stream() if stream is None else stream

    with torch.cuda.stream(active_stream):
        send_buffer = buffer_type(tensor.data_ptr())
        element_count = tensor.numel()
        nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
        raw_stream = cudaStream_t(active_stream.cuda_stream)
        self.nccl.ncclSend(send_buffer, element_count, nccl_dtype, dst,
                           comm, raw_stream)
    active_stream.synchronize()
def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
    """Receive into ``tensor`` from rank ``src`` over ``comm``.

    ``tensor`` must live on ``self.device``. When ``stream`` is ``None`` the
    stream returned by ``current_stream()`` is used. The stream is
    synchronized before returning.
    """
    assert tensor.device == self.device, (
        f"this nccl communicator is created to work on {self.device}, "
        f"but the input tensor is on {tensor.device}")

    active_stream = current_stream() if stream is None else stream

    with torch.cuda.stream(active_stream):
        recv_buffer = buffer_type(tensor.data_ptr())
        element_count = tensor.numel()
        nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
        raw_stream = cudaStream_t(active_stream.cuda_stream)
        self.nccl.ncclRecv(recv_buffer, element_count, nccl_dtype, src,
                           comm, raw_stream)
    active_stream.synchronize()
def close(self) -> None:
    """Join the engine's background threads.

    Joins the listener thread, then (in "PUT_ASYNC" mode) the sender
    thread, the pending-check thread, and finally the ping thread when one
    exists.

    NOTE(review): the original indentation was lost, so it cannot be told
    from this view whether the pending-check join belongs inside the
    "PUT_ASYNC" guard; it is placed inside here - confirm against the real
    file.
    NOTE(review): the thread targets visible in this file loop forever, so
    these joins presumably block until the threads die - verify.
    """
    self._listener_thread.join()
    if self.send_type == "PUT_ASYNC":
        self._send_thread.join()
        self._pending_check_thread.join()
    # The ping thread is optional.
    if self._ping_thread is not None:
        self._ping_thread.join()
def get_pp_indices_d(self, num_hidden_layers: int, pp_rank: int,
                     pp_size: int) -> tuple[int, int]:
    """Return the ``(start_layer, end_layer)`` range owned by ``pp_rank``.

    When ``envs.VLLM_PP_LAYER_PARTITION_D`` is set it is parsed as a
    comma-separated list of per-stage layer counts, which must have
    ``pp_size`` entries summing to ``num_hidden_layers``. Otherwise layers
    are split evenly and each leftover layer is added to one stage, counting
    backwards from the second-to-last stage.

    Raises:
        ValueError: if the partition string is not a list of integers, has
            the wrong number of entries, or does not sum to
            ``num_hidden_layers``.
    """
    partition_list_str = envs.VLLM_PP_LAYER_PARTITION_D

    if partition_list_str is None:
        # Even split; leftovers go to the stages just before the last one.
        partitions = [num_hidden_layers // pp_size] * pp_size
        remaining_layers = num_hidden_layers % pp_size
        if remaining_layers:
            for offset in range(2, remaining_layers + 2):
                partitions[-offset] += 1
            logger.info(
                "Hidden layers were unevenly partitioned: [%s]. "
                "This can be manually overridden using the "
                "VLLM_PP_LAYER_PARTITION_D environment variable",
                ",".join(str(p) for p in partitions))
    else:
        try:
            partitions = list(map(int, partition_list_str.split(",")))
        except ValueError as err:
            raise ValueError("Invalid partition string: {}".format(
                partition_list_str)) from err
        if len(partitions) != pp_size:
            raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
        if sum(partitions) != num_hidden_layers:
            raise ValueError(
                f"{sum(partitions)=} does not match {num_hidden_layers=}.")

    start_layer = sum(partitions[:pp_rank])
    return (start_layer, start_layer + partitions[pp_rank])
def compute_remote_pp_rank(self, layer_name: str) -> int:
    """Return the remote pipeline stage that owns ``layer_name``'s layer.

    The layer index comes from ``extract_layer_index(layer_name)``. An index
    equal to ``self.total_num_hidden_layers`` maps to the last remote stage;
    otherwise the stage whose range from ``get_pp_indices_d`` contains the
    index is returned. Returns ``-1`` when no stage matches.
    """
    layer_idx = extract_layer_index(layer_name)
    remote_stages = range(self.remote_pp_size)
    for d_pp_rank in remote_stages:
        first, past_last = self.get_pp_indices_d(
            self.total_num_hidden_layers, d_pp_rank, self.remote_pp_size)
        # An index one past the last hidden layer goes to the final stage.
        if layer_idx == self.total_num_hidden_layers:
            return self.remote_pp_size - 1
        if first <= layer_idx < past_last:
            return d_pp_rank
    return -1
@staticmethod
def get_tensor_id(request_id: str, layer_name: str) -> str:
return request_id + "#" + layer_name
@staticmethod
def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]:
    """Extract the peer ``(host, port)`` embedded in ``request_id``.

    With ``is_prefill`` true the ``___decode_addr_<host>:<port>`` marker is
    searched for, otherwise ``___prefill_addr_<host>:<port>___``.

    Raises:
        ValueError: if the expected marker is not found.
    """
    pattern = (r"___decode_addr_(.*):(\d+)"
               if is_prefill else r"___prefill_addr_(.*):(\d+)___")

    found = regex.search(pattern, request_id)
    if not found:
        raise ValueError(
            f"Request id {request_id} does not contain hostname and port")

    # Group 1 is the host, group 2 the numeric port.
    return found.group(1), int(found.group(2))
...@@ -1841,6 +1841,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1841,6 +1841,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vllm will use rmsquant fused op # vllm will use rmsquant fused op
"USE_FUSED_RMS_QUANT": "USE_FUSED_RMS_QUANT":
lambda: bool(int(os.getenv("USE_FUSED_RMS_QUANT", "0"))), lambda: bool(int(os.getenv("USE_FUSED_RMS_QUANT", "0"))),
#vllm use dp connector
"VLLM_USE_DP_CONNECTOR":
lambda: bool(int(os.getenv("VLLM_USE_DP_CONNECTOR", "0"))),
# vllm pd separation will be used async # vllm pd separation will be used async
"VLLM_P2P_ASYNC": "VLLM_P2P_ASYNC":
lambda: bool(int(os.getenv("VLLM_P2P_ASYNC", "0"))), lambda: bool(int(os.getenv("VLLM_P2P_ASYNC", "0"))),
......
...@@ -121,7 +121,7 @@ class Scheduler(SchedulerInterface): ...@@ -121,7 +121,7 @@ class Scheduler(SchedulerInterface):
config=self.vllm_config, config=self.vllm_config,
role=KVConnectorRole.SCHEDULER, role=KVConnectorRole.SCHEDULER,
kv_cache_config=self.kv_cache_config, kv_cache_config=self.kv_cache_config,
) dp_rank=self.parallel_config.data_parallel_rank)
if self.log_stats: if self.log_stats:
self.connector_prefix_cache_stats = PrefixCacheStats() self.connector_prefix_cache_stats = PrefixCacheStats()
kv_load_failure_policy = ( kv_load_failure_policy = (
...@@ -556,6 +556,12 @@ class Scheduler(SchedulerInterface): ...@@ -556,6 +556,12 @@ class Scheduler(SchedulerInterface):
+ len(scheduled_running_reqs) >= max_batch_running): + len(scheduled_running_reqs) >= max_batch_running):
break break
request = self.waiting.peek_request() request = self.waiting.peek_request()
if self.connector and not self.connector.is_producer and \
request.request_id not in self.finished_recving_kv_req_ids and \
envs.VLLM_USE_DP_CONNECTOR:
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# KVTransfer: skip request if still waiting for remote kvs. # KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request) is_ready = self._update_waiting_for_remote_kv(request)
......
...@@ -66,6 +66,7 @@ from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder ...@@ -66,6 +66,7 @@ from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import compute_iteration_details from vllm.v1.utils import compute_iteration_details
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
from vllm import envs
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -1155,6 +1156,11 @@ class EngineCoreProc(EngineCore): ...@@ -1155,6 +1156,11 @@ class EngineCoreProc(EngineCore):
# Push to input queue for core busy loop. # Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request)) self.input_queue.put_nowait((request_type, request))
if isinstance(request, tuple) and self.scheduler.connector is not None \
and envs.VLLM_USE_DP_CONNECTOR:
req, _ = request
if request_type == EngineCoreRequestType.ADD:
self.scheduler.connector.register_req(req.request_id)
def process_output_sockets( def process_output_sockets(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment