Commit fa3bae2e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-dp' into 'v0.9.2-dev'

Merge v0.9.2-dev-dp into v0.9.2-dev

See merge request dcutoolkit/deeplearing/vllm!445
parents ffd26247 6f866c45
...@@ -13,6 +13,8 @@ from typing import Any ...@@ -13,6 +13,8 @@ from typing import Any
from quart import Quart, make_response, request from quart import Quart, make_response, request
from dataclasses import dataclass, field from dataclasses import dataclass, field
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
import time
import asyncio
from collections import deque, defaultdict from collections import deque, defaultdict
import logging import logging
logging.basicConfig( logging.basicConfig(
...@@ -21,13 +23,13 @@ logging.basicConfig( ...@@ -21,13 +23,13 @@ logging.basicConfig(
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# @dataclass @dataclass
# class Request: class Request:
# request_id: str request_id: str
# p_http_address: str = "" p_http_address: str = ""
# p_dp_rank: int = -1 p_dp_rank: int = -1
# d_http_address: str = "" d_http_address: str = ""
# d_dp_rank: int = -1 d_dp_rank: int = -1
@dataclass @dataclass
class Instance: class Instance:
...@@ -60,17 +62,16 @@ class Instance: ...@@ -60,17 +62,16 @@ class Instance:
all_ranks_ready = world_size and inited_rank == world_size all_ranks_ready = world_size and inited_rank == world_size
if self.ins_type == "P" : if self.ins_type == "P" :
logger.info(f"""[Router] P is_ready? : {self.http_address} world_size = {world_size} inited_rank = {inited_rank}""") logger.info(f"""[Router] P is_ready? : {self.http_address} world_size = {world_size} inited_rank = {inited_rank}""")
# return all_ranks_ready and self.p_unique_id != b""
return all_ranks_ready return all_ranks_ready
else : else :
logger.info(f"""[Router] D is_ready? : {self.http_address} world_size = {world_size} inited_rank = {inited_rank}""") logger.info(f"""[Router] D is_ready? : {self.http_address} world_size = {world_size} inited_rank = {inited_rank}""")
return all_ranks_ready return all_ranks_ready
count = 0 count = 0
# prefill_instances: dict[str, str] = {} # http_address: zmq_address
# decode_instances: dict[str, str] = {} # http_address: zmq_address
prefill_instances: dict[str, Instance] = {} prefill_instances: dict[str, Instance] = {}
decode_instances: dict[str, Instance] = {} decode_instances: dict[str, Instance] = {}
running_requests: dict[str, Request] = {}
healthy_instances: dict[str, float] = {}
pending_prefill_ins: list[str] = [] pending_prefill_ins: list[str] = []
pending_decode_ins: list[str] = [] pending_decode_ins: list[str] = []
...@@ -80,10 +81,13 @@ ready_decode_ins: list[str] = [] ...@@ -80,10 +81,13 @@ ready_decode_ins: list[str] = []
pd_pair : dict[str, bytes] = {} pd_pair : dict[str, bytes] = {}
router_nccl = NCCLLibrary() router_nccl = NCCLLibrary()
prefill_cv = threading.Condition()
decode_cv = threading.Condition()
instance_cv = threading.Condition() instance_cv = threading.Condition()
request_cv = threading.Condition()
health_cv = threading.Condition()
request_queue_cv = threading.Condition()
request_queue: deque[list[Any]] = deque()
sock_cache : dict[str, Any] = {} sock_cache : dict[str, Any] = {}
def _listen_for_register(poller, router_socket): def _listen_for_register(poller, router_socket):
...@@ -149,6 +153,35 @@ def _listen_for_register(poller, router_socket): ...@@ -149,6 +153,35 @@ def _listen_for_register(poller, router_socket):
pending_decode_ins.append(d_instance.http_address) pending_decode_ins.append(d_instance.http_address)
logger.info(f"""[Router] pending_decode_ins appended {d_instance.http_address} ZMQ:{d_instance.zmq_address}""") logger.info(f"""[Router] pending_decode_ins appended {d_instance.http_address} ZMQ:{d_instance.zmq_address}""")
instance_cv.notify() instance_cv.notify()
elif data["type"] == "heartbeat":
global healthy_instances
global health_cv
with health_cv:
healthy_instances[data["http_address"]] = time.time()
elif data["type"] == "Req":
# logger.info(f"""[Router] recv Request {data["request_id"]} : {data["instance_type"]}""")
global running_requests
global request_cv
with request_cv:
if data["request_id"] in running_requests:
request = running_requests[data["request_id"]]
if data["instance_type"] == "P":
request.p_http_address = data["http_address"]
request.p_dp_rank = int(data["dp_rank"])
elif data["instance_type"] == "D":
request.d_http_address = data["http_address"]
request.d_dp_rank = int(data["dp_rank"])
assert(request.p_dp_rank >= 0 and request.d_dp_rank >=0)
with request_queue_cv:
request_queue.append(request)
# logger.info(f"""[Router] add Request {data["request_id"]} [{request.p_http_address}:{request.p_dp_rank}, {request.d_http_address}:{request.d_dp_rank}]""")
request_queue_cv.notify()
else:
if data["instance_type"] == "P":
running_requests[data["request_id"]] = Request(request_id=data["request_id"], p_http_address=data["http_address"], p_dp_rank=int(data["dp_rank"]))
elif data["instance_type"] == "D":
running_requests[data["request_id"]] = Request(request_id=data["request_id"], d_http_address=data["http_address"], d_dp_rank=int(data["dp_rank"]))
else: else:
print( print(
"Unexpected, Received message from %s, data: %s", "Unexpected, Received message from %s, data: %s",
...@@ -157,6 +190,9 @@ def _listen_for_register(poller, router_socket): ...@@ -157,6 +190,9 @@ def _listen_for_register(poller, router_socket):
) )
zmq_context = None zmq_context = None
tp_mapping_of_pd_pair : dict[str, dict[int, list[str]]] = {}
tp_comm_mapping_of_pd_pair : dict[str, dict[int, list[int]]] = {}
active_p_tp_rank_of_pd_pair : dict[str, set[int]] = {}
def start_service_discovery(hostname, port): def start_service_discovery(hostname, port):
if not hostname: if not hostname:
...@@ -180,6 +216,91 @@ def start_service_discovery(hostname, port): ...@@ -180,6 +216,91 @@ def start_service_discovery(hostname, port):
_listener_thread.start() _listener_thread.start()
return _listener_thread return _listener_thread
def dispatch_to_P(request : Request):
global prefill_instances
global decode_instances
p_ins = prefill_instances[request.p_http_address]
d_ins = decode_instances[request.d_http_address]
global zmq_context
global sock_cache
pd_pair_id = p_ins.http_address + "_" + d_ins.http_address
p_dp_rank = request.p_dp_rank
d_dp_rank = request.d_dp_rank
tp_dst_id = pd_pair_id + "_" + str(d_dp_rank)
assert(d_ins.pp_size == 1)
d_pp_rank = 0
global tp_mapping_of_pd_pair
global tp_comm_mapping_of_pd_pair
global active_p_tp_rank_of_pd_pair
if tp_dst_id not in active_p_tp_rank_of_pd_pair:
p_active_tp_rank = set()
p_tp_rank_to_dst : dict[int, list[str]] = defaultdict(list)
p_tp_rank_to_dst_comm : dict[int, list[int]] = defaultdict(list)
for d_tp_rank in range(d_ins.tp_size):
p_tp_rank = d_tp_rank % p_ins.tp_size
p_active_tp_rank.add(p_tp_rank)
p_tp_rank_to_dst[p_tp_rank].append(d_ins.rank_table[d_dp_rank][d_pp_rank][d_tp_rank])
p_tp_rank_to_dst_comm[p_tp_rank].append(d_ins.comm_rank_table[d_dp_rank][d_pp_rank][d_tp_rank])
tp_mapping_of_pd_pair[tp_dst_id] = p_tp_rank_to_dst
tp_comm_mapping_of_pd_pair[tp_dst_id] = p_tp_rank_to_dst_comm
active_p_tp_rank_of_pd_pair[tp_dst_id] = p_active_tp_rank
p_active_tp_rank = active_p_tp_rank_of_pd_pair[tp_dst_id]
p_tp_rank_to_dst = tp_mapping_of_pd_pair[tp_dst_id]
p_tp_rank_to_dst_comm = tp_comm_mapping_of_pd_pair[tp_dst_id]
for p_pp_rank in range(p_ins.pp_size):
for p_tp_rank in p_active_tp_rank:
if p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank] not in sock_cache:
sock = zmq_context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, "router")
sock.connect(f"tcp://{p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank]}")
sock_cache[p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank]] = sock
data = {
"cmd": "req_to_transfer",
"request_id": request.request_id,
"dst_num": len(p_tp_rank_to_dst[p_tp_rank]),
"pd_pair_id": pd_pair_id,
"remote_address": p_tp_rank_to_dst[p_tp_rank],
"remote_rank": p_tp_rank_to_dst_comm[p_tp_rank],
}
sock_cache[p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank]].send(msgpack.dumps(data))
logger.info(f"""[Router] dispatch Request {request.request_id} [{p_dp_rank}, {p_pp_rank}, {p_tp_rank}] -> [{d_dp_rank}, {d_pp_rank}]""")
for p_tp_rank in range(p_ins.tp_size):
if p_tp_rank not in p_active_tp_rank:
for p_pp_rank in range(p_ins.pp_size):
if p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank] not in sock_cache:
sock = zmq_context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, "router")
sock.connect(f"tcp://{p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank]}")
sock_cache[p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank]] = sock
data = {
"cmd": "req_not_to_transfer",
"request_id": request.request_id,
}
sock_cache[p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank]].send(msgpack.dumps(data))
def dp_dispatch():
global request_queue_cv
global request_queue
while True:
with request_queue_cv:
while not request_queue:
request_queue_cv.wait()
request = request_queue.pop()
dispatch_to_P(request)
def start_dp_dispatch():
_thread = threading.Thread(
target=dp_dispatch, daemon=True
)
_thread.start()
return _thread
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
...@@ -206,14 +327,14 @@ async def forward_request(url, data, request_id): ...@@ -206,14 +327,14 @@ async def forward_request(url, data, request_id):
yield content yield content
def unique_id_dispatch(prefill_instance : str, def unique_id_dispatch(prefill_instance : Instance,
decode_instance : str) : decode_instance : Instance) :
global zmq_context global zmq_context
global sock_cache global sock_cache
global router_nccl global router_nccl
global pd_pair global pd_pair
pd_pair_id = prefill_instance.zmq_address + "_" + decode_instance.zmq_address pd_pair_id = prefill_instance.http_address + "_" + decode_instance.http_address
if pd_pair_id in pd_pair: if pd_pair_id in pd_pair:
logger.info(f"""[Router] pd pair {pd_pair_id} already exist""") logger.info(f"""[Router] pd pair {pd_pair_id} already exist""")
...@@ -320,35 +441,34 @@ async def handle_request(): ...@@ -320,35 +441,34 @@ async def handle_request():
global count global count
global prefill_instances global prefill_instances
global prefill_cv global instance_cv
with prefill_cv: with instance_cv:
prefill_list = list(prefill_instances.items()) prefill_list = list(prefill_instances.items())
prefill_addr, prefill_instance = prefill_list[count % len(prefill_list)] prefill_addr, prefill_instance = prefill_list[count % len(prefill_list)]
global decode_instances global decode_instances
global decode_cv with instance_cv:
with decode_cv:
decode_list = list(decode_instances.items()) decode_list = list(decode_instances.items())
decode_addr, decode_instance = decode_list[count % len(decode_list)] decode_addr, decode_instance = decode_list[count % len(decode_list)]
print( global pd_pair
f"handle_request count: {count}, [HTTP:{prefill_addr}, " if prefill_instance.http_address + "_" + decode_instance.http_address not in pd_pair:
f"ZMQ:{prefill_instance.zmq_address}] 👉 [HTTP:{decode_addr}, " raise RuntimeError("Selected PD pair was not inited")
f"ZMQ:{decode_instance.zmq_address}]" logger.info(
f"handle_request count: {count}, [HTTP:{prefill_addr}, 👉 HTTP:{decode_addr}]"
) )
count += 1 count += 1
request_id = ( request_id = f"{random_uuid()}"
f"___prefill_addr_{prefill_instance.zmq_address}___decode_addr_"
f"{decode_instance.zmq_address}_{random_uuid()}"
)
# finish prefill async def run_prefill():
async for _ in forward_request( async for _ in forward_request(
f"http://{prefill_addr}/v1/completions", prefill_request, request_id f"http://{prefill_addr}/v1/completions", prefill_request, request_id
): ):
continue pass
prefill_task = asyncio.create_task(run_prefill())
# return decode # return decode
generator = forward_request( generator = forward_request(
...@@ -372,6 +492,8 @@ async def handle_request(): ...@@ -372,6 +492,8 @@ async def handle_request():
if __name__ == "__main__": if __name__ == "__main__":
t = start_service_discovery("0.0.0.0", 30001) t = start_service_discovery("0.0.0.0", 30001)
t_1 = start_pd_pair_init() t_1 = start_pd_pair_init()
t_2 = start_dp_dispatch()
app.run(host="0.0.0.0", port=10001) app.run(host="0.0.0.0", port=10001)
t.join() t.join()
t_1.join() t_1.join()
t_2.join()
...@@ -54,6 +54,7 @@ class KVConnectorFactory: ...@@ -54,6 +54,7 @@ class KVConnectorFactory:
cls, cls,
config: "VllmConfig", config: "VllmConfig",
role: KVConnectorRole, role: KVConnectorRole,
dp_rank: int = -1,
) -> KVConnectorBase_V1: ) -> KVConnectorBase_V1:
if not envs.VLLM_USE_V1: if not envs.VLLM_USE_V1:
raise ValueError("Attempting to initialize a V1 Connector, " raise ValueError("Attempting to initialize a V1 Connector, "
...@@ -81,7 +82,7 @@ class KVConnectorFactory: ...@@ -81,7 +82,7 @@ class KVConnectorFactory:
# - Co-locate with worker process # - Co-locate with worker process
# - Should only be used inside the forward context & attention layer # - Should only be used inside the forward context & attention layer
# We build separately to enforce strict separation # We build separately to enforce strict separation
return connector_cls(config, role) return connector_cls(config, role, dp_rank)
# Register various connectors here. # Register various connectors here.
......
...@@ -20,6 +20,9 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata ...@@ -20,6 +20,9 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.distributed.parallel_state import get_pp_group, get_tp_group, get_dp_group from vllm.distributed.parallel_state import get_pp_group, get_tp_group, get_dp_group
import zmq
import msgpack
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
...@@ -78,7 +81,7 @@ class DuSwiftConnectorMetadata(KVConnectorMetadata): ...@@ -78,7 +81,7 @@ class DuSwiftConnectorMetadata(KVConnectorMetadata):
class DuSwiftConnector(KVConnectorBase_V1): class DuSwiftConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole, dp_rank : int = -1):
super().__init__(vllm_config=vllm_config, role=role) super().__init__(vllm_config=vllm_config, role=role)
self._block_size = vllm_config.cache_config.block_size self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Any] = {} self._requests_need_load: dict[str, Any] = {}
...@@ -157,9 +160,39 @@ class DuSwiftConnector(KVConnectorBase_V1): ...@@ -157,9 +160,39 @@ class DuSwiftConnector(KVConnectorBase_V1):
except Exception as e: except Exception as e:
print(f"Error: Exception occurred while reading configuration file - {e}") print(f"Error: Exception occurred while reading configuration file - {e}")
if role == KVConnectorRole.SCHEDULER :
self.dp_rank = dp_rank
proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
proxy_port = self.config.get_from_extra_config("proxy_port", "")
if proxy_ip == "" or proxy_port == "":
self.proxy_address = ""
else:
self.proxy_address = proxy_ip + ":" + proxy_port
self.http_address = (
f"{self.config.kv_connector_extra_config['instance_ip']}:"
f"{self.config.kv_connector_extra_config['http_port']}")
self.context = zmq.Context()
req_sock = self.context.socket(zmq.DEALER)
req_sock.setsockopt_string(zmq.IDENTITY, f"{self.http_address}_rank{self.dp_rank}")
req_sock.connect(f"tcp://{self.proxy_address}")
self.req_sock = req_sock
def get_ip_value(self, key): def get_ip_value(self, key):
return self.ip_map.get(key) return self.ip_map.get(key)
def register_req(self, request_id: str) :
data = {
"type": "Req",
"instance_type": "P" if self.config.is_kv_producer else "D",
"http_address": self.http_address,
"request_id": request_id,
"dp_rank" : self.dp_rank
}
self.req_sock.send(msgpack.dumps(data))
# ============================== # ==============================
...@@ -438,43 +471,57 @@ class DuSwiftConnector(KVConnectorBase_V1): ...@@ -438,43 +471,57 @@ class DuSwiftConnector(KVConnectorBase_V1):
else: else:
for request in connector_metadata.requests: for request in connector_metadata.requests:
request_id = request.request_id request_id = request.request_id
ip, port = self.parse_request_id(request_id, True) # ip, port = self.parse_request_id(request_id, True)
p_ip, p_port = self.parse_request_id(request_id, False) # p_ip, p_port = self.parse_request_id(request_id, False)
remote_address = ip + ":" + str(port + self._rank) # remote_address = ip + ":" + str(port + self._rank)
# pd_pair_id = p_ip + ":" + p_port + "_" + ip + ":" + port # pd_pair_id = p_ip + ":" + p_port + "_" + ip + ":" + port
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping) kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size # pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size
) % self.parallel_config.pipeline_parallel_size # ) % self.parallel_config.pipeline_parallel_size
if (self.multiple_machines_p and self.multiple_machines_d): # if (self.multiple_machines_p and self.multiple_machines_d):
ip_second = self.get_ip_value(ip) # ip_second = self.get_ip_value(ip)
if (self.pp_size == 1): # if (self.pp_size == 1):
if self._rank < 8: # if self._rank < 8:
self.du_swift_engine.send_tensor(request_id + "#" + layer_name, # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address) # kv_cache, remote_address)
self.du_swift_engine.send_tensor(request_id + "#" + layer_name, # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, str(ip_second) + ":" + str(port + self._rank + 8)) # kv_cache, str(ip_second) + ":" + str(port + self._rank + 8))
elif (self.pp_size == 2): # elif (self.pp_size == 2):
if (pp_rank == 0): # if (pp_rank == 0):
self.du_swift_engine.send_tensor(request_id + "#" + layer_name, # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address) # kv_cache, remote_address)
else: # else:
self.du_swift_engine.send_tensor(request_id + "#" + layer_name, # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, str(ip_second) + ":" + str(port + self._rank)) # kv_cache, str(ip_second) + ":" + str(port + self._rank))
else: # else:
logger.error("Error: multiple machines only suppprt pp1tp16 and pp2tp8!!!!!!") # logger.error("Error: multiple machines only suppprt pp1tp16 and pp2tp8!!!!!!")
elif (self.multiple_machines_p and not self.multiple_machines_d): # elif (self.multiple_machines_p and not self.multiple_machines_d):
if (self.pp_size == 2): # if (self.pp_size == 2):
remote_address = ip + ":" + str(port + self._tp_rank) # remote_address = ip + ":" + str(port + self._tp_rank)
pending = False
with self.du_swift_engine.req_status_cv:
if request_id not in self.du_swift_engine.req_status:
pending = True
if pending:
self.du_swift_engine.pending_tensor(request_id, layer_name,
kv_cache)
logger.info("[%d] pending for request: %s layer: %s", self._rank, request_id, layer_name)
else :
req_data = self.du_swift_engine.req_status[request_id]
assert(req_data.dst_num == len(req_data.zmq_address_and_comm_rank))
for i in range(req_data.dst_num):
remote_addr = RemoteAddr(req_data.pd_pair_id, *(req_data.zmq_address_and_comm_rank[i]))
self.du_swift_engine.send_tensor(request_id + "#" + layer_name, self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address) kv_cache, remote_addr)
else: # kv_cache, remote_address)
logger.error("Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!") # else:
# logger.error("Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!")
elif (not self.multiple_machines_p and not self.multiple_machines_d):
# remote_addr = RemoteAddr(pd_pair_id, remote_address, self._rank + self.num_card) # elif (not self.multiple_machines_p and not self.multiple_machines_d):
self.du_swift_engine.send_tensor_new(request_id, layer_name, kv_cache, # # remote_addr = RemoteAddr(pd_pair_id, remote_address, self._rank + self.num_card)
is_mla) # self.du_swift_engine.send_tensor_new(request_id, layer_name, kv_cache,
# is_mla)
# if (self.pp_size == 1): # if (self.pp_size == 1):
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name, # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address) # kv_cache, remote_address)
...@@ -498,8 +545,8 @@ class DuSwiftConnector(KVConnectorBase_V1): ...@@ -498,8 +545,8 @@ class DuSwiftConnector(KVConnectorBase_V1):
# kv_cache, remote_address) # kv_cache, remote_address)
# else: # else:
# logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!") # logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!")
else: # else:
logger.error("Error: not support!!!!!!") # logger.error("Error: not support!!!!!!")
def wait_for_save(self): def wait_for_save(self):
pass pass
# if self.is_producer: # if self.is_producer:
......
...@@ -6,9 +6,10 @@ import os ...@@ -6,9 +6,10 @@ import os
import threading import threading
import time import time
import typing import typing
from collections import deque from collections import deque, defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
from dataclasses import dataclass, field
import msgpack import msgpack
import torch import torch
...@@ -72,6 +73,12 @@ def set_du_swift_context(num_channels: str): ...@@ -72,6 +73,12 @@ def set_du_swift_context(num_channels: str):
os.environ.pop(var, None) os.environ.pop(var, None)
@dataclass
class ReqKVDest:
dst_num: int = 0
pd_pair_id: str = ""
zmq_address_and_comm_rank: list[tuple[str, int]] = field(default_factory=list)
@dataclass @dataclass
class RemoteAddr: class RemoteAddr:
pd_pair_id: str = "" pd_pair_id: str = ""
...@@ -125,6 +132,10 @@ class DuSwiftEngine: ...@@ -125,6 +132,10 @@ class DuSwiftEngine:
self.multp = int(self.remote_tp_size / self.tp_size) self.multp = int(self.remote_tp_size / self.tp_size)
self.multiple_machines = self.config.get_from_extra_config( self.multiple_machines = self.config.get_from_extra_config(
"enable_multiple_machines", False) "enable_multiple_machines", False)
self.instance_ip = self.config.get_from_extra_config(
"instance_ip", None)
if self.instance_ip :
self.multiple_machines = False
port = int(self.config.kv_port) + port_offset port = int(self.config.kv_port) + port_offset
if port == 0: if port == 0:
raise ValueError("Port cannot be 0") raise ValueError("Port cannot be 0")
...@@ -135,9 +146,14 @@ class DuSwiftEngine: ...@@ -135,9 +146,14 @@ class DuSwiftEngine:
self.zmq_address = f"{self._hostname}:{self._port}" self.zmq_address = f"{self._hostname}:{self._port}"
# The `http_port` must be consistent with the port of OpenAI. # The `http_port` must be consistent with the port of OpenAI.
self.http_address = ( if self.instance_ip:
f"{self._hostname}:" self.http_address = (
f"{self.config.kv_connector_extra_config['http_port']}") f"{self.config.kv_connector_extra_config['instance_ip']}:"
f"{self.config.kv_connector_extra_config['http_port']}")
else:
self.http_address = (
f"{self._hostname}:"
f"{self.config.kv_connector_extra_config['http_port']}")
# If `proxy_ip` or `proxy_port` is `""`, # If `proxy_ip` or `proxy_port` is `""`,
# then the ping thread will not be enabled. # then the ping thread will not be enabled.
...@@ -148,17 +164,28 @@ class DuSwiftEngine: ...@@ -148,17 +164,28 @@ class DuSwiftEngine:
else: else:
self.proxy_address = proxy_ip + ":" + proxy_port self.proxy_address = proxy_ip + ":" + proxy_port
self.kv_cache_layer_num = 0
self.context = zmq.Context() self.context = zmq.Context()
self.router_socket = self.context.socket(zmq.ROUTER) self.router_socket = self.context.socket(zmq.ROUTER)
self.router_socket.setsockopt(zmq.RCVHWM, 10000)
self.router_socket.setsockopt(zmq.SNDHWM, 5000)
self.router_socket.setsockopt(zmq.LINGER, 0)
self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
self.router_socket.setsockopt(zmq.TCP_KEEPALIVE, 1)
self.router_socket.bind(f"tcp://{self.zmq_address}") self.router_socket.bind(f"tcp://{self.zmq_address}")
self.poller = zmq.Poller() self.poller = zmq.Poller()
self.poller.register(self.router_socket, zmq.POLLIN) self.poller.register(self.router_socket, zmq.POLLIN)
self.req_status: dict[str, ReqKVDest] = {}
self.req_status_cv = threading.Condition()
self.send_store_cv = threading.Condition() self.send_store_cv = threading.Condition()
self.send_queue_cv = threading.Condition() self.send_queue_cv = threading.Condition()
self.recv_store_cv = threading.Condition() self.recv_store_cv = threading.Condition()
self.pending_queue_cv = threading.Condition()
self.send_stream = torch.cuda.Stream() self.send_stream = torch.cuda.Stream()
self.recv_stream = torch.cuda.Stream() self.recv_stream = torch.cuda.Stream()
...@@ -181,11 +208,16 @@ class DuSwiftEngine: ...@@ -181,11 +208,16 @@ class DuSwiftEngine:
# PUT or PUT_ASYNC # PUT or PUT_ASYNC
# tensor_id: torch.Tensor # tensor_id: torch.Tensor
self.send_queue: deque[list[Any]] = deque() self.send_queue: deque[list[Any]] = deque()
self.pending_queue: dict[str, list[list[Any]]] = defaultdict(list)
self.requests_to_release: dict[str, bool] = {}
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {} self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
if self.send_type == "PUT_ASYNC": if self.send_type == "PUT_ASYNC":
self._send_thread = threading.Thread(target=self._send_async, self._send_thread = threading.Thread(target=self._send_async,
daemon=True) daemon=True)
self._send_thread.start() self._send_thread.start()
self._pending_check_thread = threading.Thread(target=self._pending_check,
daemon=True)
self._pending_check_thread.start()
# tensor_id: torch.Tensor/(addr, dtype, shape) # tensor_id: torch.Tensor/(addr, dtype, shape)
self.recv_store: dict[str, Any] = {} self.recv_store: dict[str, Any] = {}
...@@ -328,7 +360,7 @@ class DuSwiftEngine: ...@@ -328,7 +360,7 @@ class DuSwiftEngine:
self, self,
tensor_id: str, tensor_id: str,
tensor: torch.Tensor, tensor: torch.Tensor,
remote_address: typing.Optional[str] = None, remote_address: typing.Optional[RemoteAddr] = None,
tbo_evt = None, tbo_evt = None,
) -> bool: ) -> bool:
if remote_address is None: if remote_address is None:
...@@ -356,7 +388,7 @@ class DuSwiftEngine: ...@@ -356,7 +388,7 @@ class DuSwiftEngine:
logger.info( logger.info(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d," "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d", " buffer_size:%d, oldest_tenser_size:%d, rank:%d",
remote_address, tensor_id, tensor_size, remote_address.zmq_address, tensor_id, tensor_size,
self.buffer_size, oldest_tenser_size, self.rank) self.buffer_size, oldest_tenser_size, self.rank)
self.send_store[tensor_id] = tensor self.send_store[tensor_id] = tensor
...@@ -364,7 +396,7 @@ class DuSwiftEngine: ...@@ -364,7 +396,7 @@ class DuSwiftEngine:
logger.debug( logger.debug(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, " "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)", "shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
remote_address, tensor_id, tensor_size, tensor.shape, remote_address.zmq_address, tensor_id, tensor_size, tensor.shape,
self.rank, self.buffer_size, self.rank, self.buffer_size,
self.buffer_size / self.buffer_size_threshold * 100) self.buffer_size / self.buffer_size_threshold * 100)
...@@ -417,6 +449,55 @@ class DuSwiftEngine: ...@@ -417,6 +449,55 @@ class DuSwiftEngine:
return True return True
def pending_tensor(
self,
reuqest_id: str,
layer_name: str,
tensor: torch.Tensor,
tbo_evt = None,
) -> bool:
with self.pending_queue_cv:
self.pending_queue[reuqest_id].append([layer_name, tensor, tbo_evt])
self.pending_queue_cv.notify()
return True
def unpending_tensor(
self,
request_id: str,
req_data: ReqKVDest,
) -> bool:
with self.pending_queue_cv:
tensor_list = self.pending_queue.pop(request_id)
if request_id in self.requests_to_release:
self.requests_to_release[request_id] = True
logger.info("[%d] unpending request: %s", self.rank, request_id)
if req_data.dst_num <= 0:
return False
for layer_name, tensor, tbo_evt in tensor_list:
for i in range(req_data.dst_num) :
remote_addr = RemoteAddr(req_data.pd_pair_id, *(req_data.zmq_address_and_comm_rank[i]))
if (envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC) and tbo_evt is not None:
self.p2p_async_send_tensor(request_id + "#" + layer_name,
tensor, remote_addr, tbo_evt)
else :
self.send_tensor(request_id + "#" + layer_name,
tensor, remote_addr)
return True
def _pending_check(self) :
while True:
with self.pending_queue_cv:
while not self.pending_queue:
self.pending_queue_cv.wait()
pending_queue = self.pending_queue.copy()
for request_id in pending_queue:
with self.req_status_cv:
if request_id not in self.req_status:
continue
req_data = self.req_status[request_id]
assert(len(req_data.zmq_address_and_comm_rank) == req_data.dst_num)
self.unpending_tensor(request_id, req_data)
def recv_tensor( def recv_tensor(
self, self,
tensor_id: str, tensor_id: str,
...@@ -475,22 +556,12 @@ class DuSwiftEngine: ...@@ -475,22 +556,12 @@ class DuSwiftEngine:
def _listen_for_requests(self): def _listen_for_requests(self):
while True: while True:
socks = dict(self.poller.poll()) socks = dict(self.poller.poll(5000))
if self.router_socket in socks: if self.router_socket in socks:
remote_address, message = self.router_socket.recv_multipart() remote_address, message = self.router_socket.recv_multipart()
data = msgpack.loads(message) data = msgpack.loads(message)
if data["cmd"] == "NEW": if data["cmd"] == "NEW":
unique_id = self.nccl.unique_id_from_bytes( logger.info(f"unexpected message from {remote_address.decode()}")
bytes(data["unique_id"]))
with torch.cuda.device(self.device):
rank = 1
with set_du_swift_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(
2, unique_id, rank)
self.comms[remote_address.decode()] = (comm, rank)
logger.info(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, remote_address.decode(), rank)
elif data["cmd"] == "PUT": elif data["cmd"] == "PUT":
tensor_id = data["tensor_id"] tensor_id = data["tensor_id"]
if "tensor_split_num" in data: if "tensor_split_num" in data:
...@@ -577,6 +648,15 @@ class DuSwiftEngine: ...@@ -577,6 +648,15 @@ class DuSwiftEngine:
logger.info( logger.info(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s", "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, data["pd_pair_id"], rank) self.zmq_address, data["pd_pair_id"], rank)
elif data["cmd"] == "req_to_transfer":
with self.req_status_cv:
assert(data["request_id"] not in self.req_status)
self.req_status[data["request_id"]] = ReqKVDest(dst_num=int(data["dst_num"]), pd_pair_id=data["pd_pair_id"], zmq_address_and_comm_rank=list(zip(data["remote_address"], data["remote_rank"])))
self.req_status_cv.notify_all()
elif data["cmd"] == "req_not_to_transfer":
with self.req_status_cv:
self.req_status[data["request_id"]] = ReqKVDest(dst_num=0)
self.req_status_cv.notify_all()
elif data["cmd"] == "GET": elif data["cmd"] == "GET":
tensor_id = data["tensor_id"] tensor_id = data["tensor_id"]
with self.send_store_cv: with self.send_store_cv:
...@@ -814,27 +894,51 @@ class DuSwiftEngine: ...@@ -814,27 +894,51 @@ class DuSwiftEngine:
""" """
# Clear the buffer upon request completion. # Clear the buffer upon request completion.
requests_to_release : list[str] = []
with self.pending_queue_cv:
for request_id, release in self.requests_to_release.items():
if release :
requests_to_release.append(request_id)
self.requests_to_release.pop(request_id)
for request_id in finished_req_ids: for request_id in finished_req_ids:
for layer_name in forward_context.no_compile_layers: with self.pending_queue_cv:
tensor_id = request_id + "#" + layer_name if request_id in self.pending_queue:
if tensor_id in self.recv_store: self.requests_to_release[request_id] = False
with self.recv_store_cv: logger.info("[%d] pending request: %s", self.rank, request_id)
tensor = self.recv_store.pop(tensor_id, None) continue
self.send_request_id_to_tensor_ids.pop( requests_to_release.append(request_id)
request_id, None)
self.recv_request_id_to_tensor_ids.pop( for request_id in requests_to_release:
request_id, None) ids = self.recv_request_id_to_tensor_ids.pop(request_id, set())
addr = 0 with self.recv_store_cv:
for tensor_id in ids:
tensor = self.recv_store.pop(tensor_id, None)
if isinstance(tensor, tuple): if isinstance(tensor, tuple):
addr, _, _ = tensor addr, _, _ = tensor
self.pool.free(addr) self.pool.free(addr)
self.send_request_id_to_tensor_ids.pop(request_id, None)
# TODO:Retrieve requests that have already sent the KV cache. # TODO:Retrieve requests that have already sent the KV cache.
finished_sending: set[str] = set() finished_sending: set[str] = set()
# TODO:Retrieve requests that have already received the KV cache. # TODO:Retrieve requests that have already received the KV cache.
finished_recving: set[str] = set() finished_recving: set[str] = set()
if self.kv_cache_layer_num == 0 :
for layer_name in forward_context.no_compile_layers:
layer = forward_context.no_compile_layers[layer_name]
kv_cache = getattr(layer, 'kv_cache', None)
if kv_cache is None:
continue
self.kv_cache_layer_num += 1
with self.recv_store_cv:
for req in self.recv_request_id_to_tensor_ids:
if len(self.recv_request_id_to_tensor_ids[req]) == self.kv_cache_layer_num:
finished_recving.add(req)
return finished_sending or None, finished_recving or None return finished_sending or None, finished_recving or None
def _ping(self): def _ping(self):
...@@ -911,6 +1015,7 @@ class DuSwiftEngine: ...@@ -911,6 +1015,7 @@ class DuSwiftEngine:
self._listener_thread.join() self._listener_thread.join()
if self.send_type == "PUT_ASYNC": if self.send_type == "PUT_ASYNC":
self._send_thread.join() self._send_thread.join()
self._pending_check_thread.join()
if self._ping_thread is not None: if self._ping_thread is not None:
self._ping_thread.join() self._ping_thread.join()
......
...@@ -86,7 +86,8 @@ class Scheduler(SchedulerInterface): ...@@ -86,7 +86,8 @@ class Scheduler(SchedulerInterface):
"Multiple KV cache groups are not currently supported " "Multiple KV cache groups are not currently supported "
"with KV connectors") "with KV connectors")
self.connector = KVConnectorFactory.create_connector_v1( self.connector = KVConnectorFactory.create_connector_v1(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER) config=self.vllm_config, role=KVConnectorRole.SCHEDULER,
dp_rank=self.parallel_config.data_parallel_rank)
self.kv_event_publisher = EventPublisherFactory.create( self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config, self.kv_events_config,
...@@ -380,6 +381,10 @@ class Scheduler(SchedulerInterface): ...@@ -380,6 +381,10 @@ class Scheduler(SchedulerInterface):
if request.is_finished(): if request.is_finished():
self.waiting.pop_request() self.waiting.pop_request()
continue continue
if self.connector and not self.connector.is_producer and request.request_id not in self.finished_recving_kv_req_ids :
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# KVTransfer: skip request if still waiting for remote kvs. # KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request) is_ready = self._update_waiting_for_remote_kv(request)
...@@ -674,6 +679,11 @@ class Scheduler(SchedulerInterface): ...@@ -674,6 +679,11 @@ class Scheduler(SchedulerInterface):
break break
request = self.waiting.peek_request() request = self.waiting.peek_request()
if self.connector and not self.connector.is_producer and request.request_id not in self.finished_recving_kv_req_ids :
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# KVTransfer: skip request if still waiting for remote kvs. # KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request) is_ready = self._update_waiting_for_remote_kv(request)
...@@ -1326,7 +1336,7 @@ class Scheduler(SchedulerInterface): ...@@ -1326,7 +1336,7 @@ class Scheduler(SchedulerInterface):
request.num_nans_in_logits = num_nans_in_logits[req_id] request.num_nans_in_logits = num_nans_in_logits[req_id]
# Add newly generated spec token ids to the request. # Add newly generated spec token ids to the request.
if spec_token_ids is not None: if spec_token_ids is not None and (self.connector is None or not self.connector.is_producer):
if self.structured_output_manager.should_advance(request): if self.structured_output_manager.should_advance(request):
metadata = request.structured_output_request metadata = request.structured_output_request
# Needs to happen after new_token_ids are accepted. # Needs to happen after new_token_ids are accepted.
......
...@@ -763,6 +763,9 @@ class EngineCoreProc(EngineCore): ...@@ -763,6 +763,9 @@ class EngineCoreProc(EngineCore):
# Push to input queue for core busy loop. # Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request)) self.input_queue.put_nowait((request_type, request))
if isinstance(request, EngineCoreRequest) and self.scheduler.connector is not None:
if request_type == EngineCoreRequestType.ADD:
self.scheduler.connector.register_req(request.request_id)
def process_output_sockets(self, output_paths: list[str], def process_output_sockets(self, output_paths: list[str],
coord_output_path: Optional[str], coord_output_path: Optional[str],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment