Commit 236266a9 authored by xuxz's avatar xuxz
Browse files

[PD]支持dp的分支

parent fb597c49
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import socket
import threading
import uuid
import aiohttp
import msgpack
import zmq
from typing import Any
from quart import Quart, make_response, request
from dataclasses import dataclass, field
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
import time
import asyncio
from collections import deque, defaultdict
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
@dataclass
class Request:
request_id: str
p_http_address: str = ""
p_dp_rank: int = -1
d_http_address: str = ""
d_dp_rank: int = -1
@dataclass
class Instance:
ins_type: str = "P"
http_address: str = ""
zmq_address: str = ""
p_unique_id: bytes = b""
dp_size: int = 0
pp_size: int = 0
tp_size: int = 0
# [dp, pp, tp] : zmq_address
rank_table: dict[int, dict[int, dict[int, str]]] = field(
default_factory=lambda: defaultdict(lambda: defaultdict(dict))
)
# [dp, pp, tp] : global rank
comm_rank_table: dict[int, dict[int, dict[int, int]]] = field(
default_factory=lambda: defaultdict(lambda: defaultdict(dict))
)
def count_rank_table_elements(self):
count = 0
for first_dict in self.rank_table.values():
for second_dict in first_dict.values():
count += len(second_dict)
return count
def is_ready(self):
world_size = self.dp_size * self.pp_size * self.tp_size
inited_rank = self.count_rank_table_elements()
all_ranks_ready = world_size and inited_rank == world_size
if self.ins_type == "P" :
logger.info(f"""[Router] P is_ready? : {self.http_address} world_size = {world_size} inited_rank = {inited_rank}""")
return all_ranks_ready
else :
logger.info(f"""[Router] D is_ready? : {self.http_address} world_size = {world_size} inited_rank = {inited_rank}""")
return all_ranks_ready
count = 0
prefill_instances: dict[str, Instance] = {}
decode_instances: dict[str, Instance] = {}
running_requests: dict[str, Request] = {}
healthy_instances: dict[str, float] = {}
pending_prefill_ins: list[str] = []
pending_decode_ins: list[str] = []
ready_prefill_ins: list[str] = []
ready_decode_ins: list[str] = []
pd_pair : dict[str, bytes] = {}
router_nccl = NCCLLibrary()
instance_cv = threading.Condition()
request_cv = threading.Condition()
health_cv = threading.Condition()
request_queue_cv = threading.Condition()
request_queue: deque[list[Any]] = deque()
sock_cache : dict[str, Any] = {}
def _listen_for_register(poller, router_socket):
while True:
socks = dict(poller.poll())
if router_socket in socks:
remote_address, message = router_socket.recv_multipart()
# data: {"type": "P", "http_address": "ip:port",
# "zmq_address": "ip:port"}
data = msgpack.loads(message)
global prefill_instances
global instance_cv
global decode_instances
if data["type"] == "P":
with instance_cv:
if data["http_address"] not in prefill_instances:
prefill_instances[data["http_address"]] = Instance(http_address=data["http_address"])
p_instance = prefill_instances[data["http_address"]]
p_instance.rank_table[int(data["dp_rank"])][int(data["pp_rank"])][int(data["tp_rank"])] = data["zmq_address"]
if p_instance.is_ready():
pending_prefill_ins.append(p_instance.http_address)
logger.info(f"""[Router] pending_prefill_ins appended {p_instance.http_address} ZMQ:{p_instance.zmq_address}""")
instance_cv.notify()
logger.info(f"""[Router] add P rank [{data["dp_rank"]}, {data["pp_rank"]}, {data["tp_rank"]}] : {data["zmq_address"]}""")
elif data["type"] == "D":
with instance_cv:
if data["http_address"] not in decode_instances:
decode_instances[data["http_address"]] = Instance(ins_type="D", http_address=data["http_address"])
d_instance = decode_instances[data["http_address"]]
d_instance.rank_table[int(data["dp_rank"])][int(data["pp_rank"])][int(data["tp_rank"])] = data["zmq_address"]
if d_instance.is_ready():
pending_decode_ins.append(d_instance.http_address)
logger.info(f"""[Router] pending_decode_ins appended {d_instance.http_address} ZMQ:{d_instance.zmq_address}""")
instance_cv.notify()
logger.info(f"""[Router] add D rank [{data["dp_rank"]}, {data["pp_rank"]}, {data["tp_rank"]}] : {data["zmq_address"]}""")
elif data["type"] == "P_init":
with instance_cv:
if data["http_address"] not in prefill_instances:
prefill_instances[data["http_address"]] = Instance(http_address=data["http_address"], dp_size=int(data["dp_size"]), pp_size=int(data["pp_size"]), tp_size=int(data["tp_size"]))
prefill_instances[data["http_address"]].zmq_address = data["zmq_address"]
continue
p_instance = prefill_instances[data["http_address"]]
p_instance.dp_size=int(data["dp_size"])
p_instance.pp_size=int(data["pp_size"])
p_instance.tp_size=int(data["tp_size"])
p_instance.zmq_address=data["zmq_address"]
if p_instance.is_ready():
pending_prefill_ins.append(p_instance.http_address)
logger.info(f"""[Router] pending_prefill_ins appended {p_instance.http_address} ZMQ:{p_instance.zmq_address}""")
instance_cv.notify()
elif data["type"] == "D_init":
with instance_cv:
if data["http_address"] not in decode_instances:
decode_instances[data["http_address"]] = Instance(ins_type="D", http_address=data["http_address"], dp_size=int(data["dp_size"]), pp_size=int(data["pp_size"]), tp_size=int(data["tp_size"]))
decode_instances[data["http_address"]].zmq_address = data["zmq_address"]
continue
d_instance = decode_instances[data["http_address"]]
d_instance.dp_size=int(data["dp_size"])
d_instance.pp_size=int(data["pp_size"])
d_instance.tp_size=int(data["tp_size"])
d_instance.zmq_address=data["zmq_address"]
if d_instance.is_ready():
pending_decode_ins.append(d_instance.http_address)
logger.info(f"""[Router] pending_decode_ins appended {d_instance.http_address} ZMQ:{d_instance.zmq_address}""")
instance_cv.notify()
elif data["type"] == "heartbeat":
global healthy_instances
global health_cv
with health_cv:
healthy_instances[data["http_address"]] = time.time()
elif data["type"] == "Req":
# logger.info(f"""[Router] recv Request {data["request_id"]} : {data["instance_type"]}""")
global running_requests
global request_cv
with request_cv:
if data["request_id"] in running_requests:
request = running_requests[data["request_id"]]
if data["instance_type"] == "P":
request.p_http_address = data["http_address"]
request.p_dp_rank = int(data["dp_rank"])
elif data["instance_type"] == "D":
request.d_http_address = data["http_address"]
request.d_dp_rank = int(data["dp_rank"])
assert(request.p_dp_rank >= 0 and request.d_dp_rank >=0)
with request_queue_cv:
request_queue.append(request)
# logger.info(f"""[Router] add Request {data["request_id"]} [{request.p_http_address}:{request.p_dp_rank}, {request.d_http_address}:{request.d_dp_rank}]""")
request_queue_cv.notify()
else:
if data["instance_type"] == "P":
running_requests[data["request_id"]] = Request(request_id=data["request_id"], p_http_address=data["http_address"], p_dp_rank=int(data["dp_rank"]))
elif data["instance_type"] == "D":
running_requests[data["request_id"]] = Request(request_id=data["request_id"], d_http_address=data["http_address"], d_dp_rank=int(data["dp_rank"]))
else:
print(
"Unexpected, Received message from %s, data: %s",
remote_address,
data,
)
zmq_context = None
tp_mapping_of_pd_pair : dict[str, dict[int, list[str]]] = {}
tp_comm_mapping_of_pd_pair : dict[str, dict[int, list[int]]] = {}
active_p_tp_rank_of_pd_pair : dict[str, set[int]] = {}
def start_service_discovery(hostname, port):
if not hostname:
hostname = socket.gethostname()
if port == 0:
raise ValueError("Port cannot be 0")
# context = zmq.Context()
# router_socket = context.socket(zmq.ROUTER)
global zmq_context
zmq_context = zmq.Context()
router_socket = zmq_context.socket(zmq.ROUTER)
router_socket.bind(f"tcp://{hostname}:{port}")
poller = zmq.Poller()
poller.register(router_socket, zmq.POLLIN)
_listener_thread = threading.Thread(
target=_listen_for_register, args=[poller, router_socket], daemon=True
)
_listener_thread.start()
return _listener_thread
def dispatch_to_P(request : Request):
global prefill_instances
global decode_instances
p_ins = prefill_instances[request.p_http_address]
d_ins = decode_instances[request.d_http_address]
global zmq_context
global sock_cache
pd_pair_id = p_ins.http_address + "_" + d_ins.http_address
p_dp_rank = request.p_dp_rank
d_dp_rank = request.d_dp_rank
tp_dst_id = pd_pair_id + "_" + str(d_dp_rank)
assert(d_ins.pp_size == 1)
d_pp_rank = 0
global tp_mapping_of_pd_pair
global tp_comm_mapping_of_pd_pair
global active_p_tp_rank_of_pd_pair
if tp_dst_id not in active_p_tp_rank_of_pd_pair:
p_active_tp_rank = set()
p_tp_rank_to_dst : dict[int, list[str]] = defaultdict(list)
p_tp_rank_to_dst_comm : dict[int, list[int]] = defaultdict(list)
for d_tp_rank in range(d_ins.tp_size):
p_tp_rank = d_tp_rank % p_ins.tp_size
p_active_tp_rank.add(p_tp_rank)
p_tp_rank_to_dst[p_tp_rank].append(d_ins.rank_table[d_dp_rank][d_pp_rank][d_tp_rank])
p_tp_rank_to_dst_comm[p_tp_rank].append(d_ins.comm_rank_table[d_dp_rank][d_pp_rank][d_tp_rank])
tp_mapping_of_pd_pair[tp_dst_id] = p_tp_rank_to_dst
tp_comm_mapping_of_pd_pair[tp_dst_id] = p_tp_rank_to_dst_comm
active_p_tp_rank_of_pd_pair[tp_dst_id] = p_active_tp_rank
p_active_tp_rank = active_p_tp_rank_of_pd_pair[tp_dst_id]
p_tp_rank_to_dst = tp_mapping_of_pd_pair[tp_dst_id]
p_tp_rank_to_dst_comm = tp_comm_mapping_of_pd_pair[tp_dst_id]
for p_pp_rank in range(p_ins.pp_size):
for p_tp_rank in p_active_tp_rank:
if p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank] not in sock_cache:
sock = zmq_context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, "router")
sock.connect(f"tcp://{p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank]}")
sock_cache[p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank]] = sock
data = {
"cmd": "req_to_transfer",
"request_id": request.request_id,
"dst_num": len(p_tp_rank_to_dst[p_tp_rank]),
"pd_pair_id": pd_pair_id,
"remote_address": p_tp_rank_to_dst[p_tp_rank],
"remote_rank": p_tp_rank_to_dst_comm[p_tp_rank],
}
sock_cache[p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank]].send(msgpack.dumps(data))
logger.info(f"""[Router] dispatch Request {request.request_id} [{p_dp_rank}, {p_pp_rank}, {p_tp_rank}] -> [{d_dp_rank}, {d_pp_rank}]""")
for p_tp_rank in range(p_ins.tp_size):
if p_tp_rank not in p_active_tp_rank:
for p_pp_rank in range(p_ins.pp_size):
if p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank] not in sock_cache:
sock = zmq_context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, "router")
sock.connect(f"tcp://{p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank]}")
sock_cache[p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank]] = sock
data = {
"cmd": "req_not_to_transfer",
"request_id": request.request_id,
}
sock_cache[p_ins.rank_table[p_dp_rank][p_pp_rank][p_tp_rank]].send(msgpack.dumps(data))
def dp_dispatch():
global request_queue_cv
global request_queue
while True:
with request_queue_cv:
while not request_queue:
request_queue_cv.wait()
request = request_queue.pop()
dispatch_to_P(request)
def start_dp_dispatch():
_thread = threading.Thread(
target=dp_dispatch, daemon=True
)
_thread.start()
return _thread
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
app = Quart(__name__)
def random_uuid() -> str:
return str(uuid.uuid4().hex)
async def forward_request(url, data, request_id):
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id,
}
async with session.post(url=url, json=data, headers=headers) as response:
if response.status == 200:
if True:
async for chunk_bytes in response.content.iter_chunked(1024):
yield chunk_bytes
else:
content = await response.read()
yield content
def unique_id_dispatch(prefill_instance : Instance,
decode_instance : Instance) :
global zmq_context
global sock_cache
global router_nccl
global pd_pair
pd_pair_id = prefill_instance.http_address + "_" + decode_instance.http_address
if pd_pair_id in pd_pair:
logger.info(f"""[Router] pd pair {pd_pair_id} already exist""")
return
logger.info(f"""[Router] initing pd pair {pd_pair_id}""")
unique_id = router_nccl.ncclGetUniqueId()
unique_id = bytes(unique_id.internal)
rank = 0
p_rank_num = prefill_instance.dp_size * prefill_instance.pp_size * prefill_instance.tp_size
d_rank_num = decode_instance.dp_size * decode_instance.pp_size * decode_instance.tp_size
world_size = p_rank_num + d_rank_num
for dp_rank in range(prefill_instance.dp_size):
for pp_rank in range(prefill_instance.pp_size):
for tp_rank in range(prefill_instance.tp_size):
if prefill_instance.rank_table[dp_rank][pp_rank][tp_rank] not in sock_cache:
sock = zmq_context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, "router")
sock.connect(f"tcp://{prefill_instance.rank_table[dp_rank][pp_rank][tp_rank]}")
sock_cache[prefill_instance.rank_table[dp_rank][pp_rank][tp_rank]] = sock
data = {
"cmd": "comm_init",
"pd_pair_id": pd_pair_id,
"unique_id" : unique_id,
"world_size": world_size,
"rank": rank
}
sock_cache[prefill_instance.rank_table[dp_rank][pp_rank][tp_rank]].send(msgpack.dumps(data))
prefill_instance.comm_rank_table[dp_rank][pp_rank][tp_rank] = rank
rank += 1
logger.info(f"""[Router] dispatch unique_id of pd pair {pd_pair_id} to [P] [{dp_rank}, {pp_rank}, {tp_rank}]""")
for dp_rank in range(decode_instance.dp_size):
for pp_rank in range(decode_instance.pp_size):
for tp_rank in range(decode_instance.tp_size):
if decode_instance.rank_table[dp_rank][pp_rank][tp_rank] not in sock_cache:
sock = zmq_context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, "router")
sock.connect(f"tcp://{decode_instance.rank_table[dp_rank][pp_rank][tp_rank]}")
sock_cache[decode_instance.rank_table[dp_rank][pp_rank][tp_rank]] = sock
data = {
"cmd": "comm_init",
"pd_pair_id": pd_pair_id,
"unique_id" : unique_id,
"world_size": world_size,
"rank": rank
}
sock_cache[decode_instance.rank_table[dp_rank][pp_rank][tp_rank]].send(msgpack.dumps(data))
decode_instance.comm_rank_table[dp_rank][pp_rank][tp_rank] = rank
rank += 1
logger.info(f"""[Router] dispatch unique_id of pd pair {pd_pair_id} to [D] [{dp_rank}, {pp_rank}, {tp_rank}]""")
pd_pair[pd_pair_id] = unique_id
def pd_pair_init():
global prefill_instances
global decode_instances
global pending_prefill_ins
global pending_decode_ins
global ready_prefill_ins
global ready_decode_ins
global instance_cv
while True:
with instance_cv:
while len(pending_prefill_ins) == 0 and len(pending_decode_ins) == 0:
logger.info(f"""[Router] pd_pair_init: waiting for instance_cv""")
instance_cv.wait()
logger.info(f"""[Router] pd_pair_init: instance_cv finished waiting""")
while pending_prefill_ins:
p_ins = pending_prefill_ins[0]
logger.info(f"""[Router] pd_pair_init: processing {p_ins} from pending_prefill_ins""")
for d_ins in ready_decode_ins:
unique_id_dispatch(prefill_instances[p_ins], decode_instances[d_ins])
ready_prefill_ins.append(p_ins)
pending_prefill_ins.remove(p_ins)
while pending_decode_ins:
d_ins = pending_decode_ins[0]
logger.info(f"""[Router] pd_pair_init: processing {d_ins} from pending_decode_ins""")
for p_ins in ready_prefill_ins:
unique_id_dispatch(prefill_instances[p_ins], decode_instances[d_ins])
ready_decode_ins.append(d_ins)
pending_decode_ins.remove(d_ins)
def start_pd_pair_init():
_thread = threading.Thread(
target=pd_pair_init, daemon=True
)
_thread.start()
return _thread
@app.route("/v1/completions", methods=["POST"])
async def handle_request():
try:
original_request_data = await request.get_json()
prefill_request = original_request_data.copy()
# change max_tokens = 1 to let it only do prefill
prefill_request["max_tokens"] = 1
global count
global prefill_instances
global instance_cv
with instance_cv:
prefill_list = list(prefill_instances.items())
prefill_addr, prefill_instance = prefill_list[count % len(prefill_list)]
global decode_instances
with instance_cv:
decode_list = list(decode_instances.items())
decode_addr, decode_instance = decode_list[count % len(decode_list)]
global pd_pair
if prefill_instance.http_address + "_" + decode_instance.http_address not in pd_pair:
raise RuntimeError("Selected PD pair was not inited")
logger.info(
f"handle_request count: {count}, [HTTP:{prefill_addr}, 👉 HTTP:{decode_addr}]"
)
count += 1
request_id = f"{random_uuid()}"
async def run_prefill():
async for _ in forward_request(
f"http://{prefill_addr}/v1/completions", prefill_request, request_id
):
pass
prefill_task = asyncio.create_task(run_prefill())
# return decode
generator = forward_request(
f"http://{decode_addr}/v1/completions", original_request_data, request_id
)
response = await make_response(generator)
response.timeout = None
return response
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server")
print(e)
print("".join(traceback.format_exception(*exc_info)))
if __name__ == "__main__":
t = start_service_discovery("0.0.0.0", 30001)
t_1 = start_pd_pair_init()
t_2 = start_dp_dispatch()
app.run(host="0.0.0.0", port=10001)
t.join()
t_1.join()
t_2.join()
\ No newline at end of file
...@@ -54,6 +54,7 @@ class KVConnectorFactory: ...@@ -54,6 +54,7 @@ class KVConnectorFactory:
cls, cls,
config: "VllmConfig", config: "VllmConfig",
role: KVConnectorRole, role: KVConnectorRole,
dp_rank: int = -1,
) -> KVConnectorBase_V1: ) -> KVConnectorBase_V1:
if not envs.VLLM_USE_V1: if not envs.VLLM_USE_V1:
raise ValueError("Attempting to initialize a V1 Connector, " raise ValueError("Attempting to initialize a V1 Connector, "
...@@ -81,7 +82,7 @@ class KVConnectorFactory: ...@@ -81,7 +82,7 @@ class KVConnectorFactory:
# - Co-locate with worker process # - Co-locate with worker process
# - Should only be used inside the forward context & attention layer # - Should only be used inside the forward context & attention layer
# We build separately to enforce strict separation # We build separately to enforce strict separation
return connector_cls(config, role) return connector_cls(config, role, dp_rank)
# Register various connectors here. # Register various connectors here.
......
...@@ -20,6 +20,9 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata ...@@ -20,6 +20,9 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.distributed.parallel_state import get_pp_group, get_tp_group, get_dp_group from vllm.distributed.parallel_state import get_pp_group, get_tp_group, get_dp_group
import zmq
import msgpack
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
...@@ -78,7 +81,7 @@ class DuSwiftConnectorMetadata(KVConnectorMetadata): ...@@ -78,7 +81,7 @@ class DuSwiftConnectorMetadata(KVConnectorMetadata):
class DuSwiftConnector(KVConnectorBase_V1): class DuSwiftConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole, dp_rank : int = -1):
super().__init__(vllm_config=vllm_config, role=role) super().__init__(vllm_config=vllm_config, role=role)
self._block_size = vllm_config.cache_config.block_size self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Any] = {} self._requests_need_load: dict[str, Any] = {}
...@@ -158,9 +161,38 @@ class DuSwiftConnector(KVConnectorBase_V1): ...@@ -158,9 +161,38 @@ class DuSwiftConnector(KVConnectorBase_V1):
print(f"Error: Exception occurred while reading configuration file - {e}") print(f"Error: Exception occurred while reading configuration file - {e}")
if role == KVConnectorRole.SCHEDULER :
self.dp_rank = dp_rank
proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
proxy_port = self.config.get_from_extra_config("proxy_port", "")
if proxy_ip == "" or proxy_port == "":
self.proxy_address = ""
else:
self.proxy_address = proxy_ip + ":" + proxy_port
self.http_address = (
f"{self.config.kv_connector_extra_config['instance_ip']}:"
f"{self.config.kv_connector_extra_config['http_port']}")
self.context = zmq.Context()
req_sock = self.context.socket(zmq.DEALER)
req_sock.setsockopt_string(zmq.IDENTITY, f"{self.http_address}_rank{self.dp_rank}")
req_sock.connect(f"tcp://{self.proxy_address}")
self.req_sock = req_sock
def get_ip_value(self, key): def get_ip_value(self, key):
return self.ip_map.get(key) return self.ip_map.get(key)
def register_req(self, request_id: str) :
data = {
"type": "Req",
"instance_type": "P" if self.config.is_kv_producer else "D",
"http_address": self.http_address,
"request_id": request_id,
"dp_rank" : self.dp_rank
}
self.req_sock.send(msgpack.dumps(data))
# ============================== # ==============================
# Worker-side methods # Worker-side methods
...@@ -438,68 +470,83 @@ class DuSwiftConnector(KVConnectorBase_V1): ...@@ -438,68 +470,83 @@ class DuSwiftConnector(KVConnectorBase_V1):
else: else:
for request in connector_metadata.requests: for request in connector_metadata.requests:
request_id = request.request_id request_id = request.request_id
ip, port = self.parse_request_id(request_id, True) # ip, port = self.parse_request_id(request_id, True)
p_ip, p_port = self.parse_request_id(request_id, False) # p_ip, p_port = self.parse_request_id(request_id, False)
remote_address = ip + ":" + str(port + self._rank) # remote_address = ip + ":" + str(port + self._rank)
# pd_pair_id = p_ip + ":" + p_port + "_" + ip + ":" + port # pd_pair_id = p_ip + ":" + p_port + "_" + ip + ":" + port
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping) kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size # pp_rank = (self.parallel_config.rank // self.parallel_config.tensor_parallel_size
) % self.parallel_config.pipeline_parallel_size # ) % self.parallel_config.pipeline_parallel_size
if (self.multiple_machines_p and self.multiple_machines_d): # if (self.multiple_machines_p and self.multiple_machines_d):
ip_second = self.get_ip_value(ip) # ip_second = self.get_ip_value(ip)
if (self.pp_size == 1):
if self._rank < 8:
self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, str(ip_second) + ":" + str(port + self._rank + 8))
elif (self.pp_size == 2):
if (pp_rank == 0):
self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
else:
self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, str(ip_second) + ":" + str(port + self._rank))
else:
logger.error("Error: multiple machines only suppprt pp1tp16 and pp2tp8!!!!!!")
elif (self.multiple_machines_p and not self.multiple_machines_d):
if (self.pp_size == 2):
remote_address = ip + ":" + str(port + self._tp_rank)
self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
else:
logger.error("Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!")
elif (not self.multiple_machines_p and not self.multiple_machines_d):
# remote_addr = RemoteAddr(pd_pair_id, remote_address, self._rank + self.num_card)
self.du_swift_engine.send_tensor_new(request_id, layer_name, kv_cache,
is_mla)
# if (self.pp_size == 1): # if (self.pp_size == 1):
# if self._rank < 8:
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name, # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address) # kv_cache, remote_address)
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, str(ip_second) + ":" + str(port + self._rank + 8))
# elif (self.pp_size == 2): # elif (self.pp_size == 2):
# if (pp_rank == 0): # if (pp_rank == 0):
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name, # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address) # kv_cache, remote_address)
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, ip + ":" + str(port + self._rank + 4))
# else: # else:
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name, # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address) # kv_cache, str(ip_second) + ":" + str(port + self._rank))
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name, # else:
# kv_cache, ip + ":" + str(port + self._rank - 4)) # logger.error("Error: multiple machines only suppprt pp1tp16 and pp2tp8!!!!!!")
# elif (self.pp_size == 8): # elif (self.multiple_machines_p and not self.multiple_machines_d):
# for i in range(8): # if (self.pp_size == 2):
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name, # remote_address = ip + ":" + str(port + self._tp_rank)
# kv_cache, ip + ":" + str(port + i))
# elif (self.enable_asymmetric_p2p):
# self.du_swift_engine.send_tensor(request_id + "#" + layer_name, # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# kv_cache, remote_address) # kv_cache, remote_address)
# else: # else:
# logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!") # logger.error("Error: P multiple machines D machine only suppprt P:pp2tp8 D:tp8 !!!!!!")
else:
logger.error("Error: not support!!!!!!") # elif (not self.multiple_machines_p and not self.multiple_machines_d):
# # remote_addr = RemoteAddr(pd_pair_id, remote_address, self._rank + self.num_card)
# self.du_swift_engine.send_tensor_new(request_id, layer_name, kv_cache,
# is_mla)
# # if (self.pp_size == 1):
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, remote_address)
# # elif (self.pp_size == 2):
# # if (pp_rank == 0):
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, remote_address)
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, ip + ":" + str(port + self._rank + 4))
# # else:
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, remote_address)
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, ip + ":" + str(port + self._rank - 4))
# # elif (self.pp_size == 8):
# # for i in range(8):
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, ip + ":" + str(port + i))
# # elif (self.enable_asymmetric_p2p):
# # self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
# # kv_cache, remote_address)
# # else:
# # logger.error("Error: P/D single machine only suppprt multiple tp:: (P: pp2tp4 D:tp8 P:pp8tp1 D:tp8) !!!!!!")
# else:
# logger.error("Error: not support!!!!!!")
pending = False
with self.du_swift_engine.req_status_cv:
if request_id not in self.du_swift_engine.req_status:
pending = True
if pending:
self.du_swift_engine.pending_tensor(request_id, layer_name,
kv_cache)
logger.info("[%d] pending for request: %s layer: %s", self._rank, request_id, layer_name)
else :
req_data = self.du_swift_engine.req_status[request_id]
assert(req_data.dst_num == len(req_data.zmq_address_and_comm_rank))
for i in range(req_data.dst_num):
remote_addr = RemoteAddr(req_data.pd_pair_id, *(req_data.zmq_address_and_comm_rank[i]))
self.du_swift_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_addr)
def wait_for_save(self): def wait_for_save(self):
pass pass
# if self.is_producer: # if self.is_producer:
......
...@@ -6,9 +6,10 @@ import os ...@@ -6,9 +6,10 @@ import os
import threading import threading
import time import time
import typing import typing
from collections import deque from collections import deque, defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
from dataclasses import dataclass, field
import msgpack import msgpack
import torch import torch
...@@ -72,6 +73,12 @@ def set_du_swift_context(num_channels: str): ...@@ -72,6 +73,12 @@ def set_du_swift_context(num_channels: str):
os.environ.pop(var, None) os.environ.pop(var, None)
@dataclass
class ReqKVDest:
dst_num: int = 0
pd_pair_id: str = ""
zmq_address_and_comm_rank: list[tuple[str, int]] = field(default_factory=list)
@dataclass @dataclass
class RemoteAddr: class RemoteAddr:
pd_pair_id: str = "" pd_pair_id: str = ""
...@@ -125,6 +132,10 @@ class DuSwiftEngine: ...@@ -125,6 +132,10 @@ class DuSwiftEngine:
self.multp = int(self.remote_tp_size / self.tp_size) self.multp = int(self.remote_tp_size / self.tp_size)
self.multiple_machines = self.config.get_from_extra_config( self.multiple_machines = self.config.get_from_extra_config(
"enable_multiple_machines", False) "enable_multiple_machines", False)
self.instance_ip = self.config.get_from_extra_config(
"instance_ip", None)
if self.instance_ip :
self.multiple_machines = False
port = int(self.config.kv_port) + port_offset port = int(self.config.kv_port) + port_offset
if port == 0: if port == 0:
raise ValueError("Port cannot be 0") raise ValueError("Port cannot be 0")
...@@ -135,6 +146,11 @@ class DuSwiftEngine: ...@@ -135,6 +146,11 @@ class DuSwiftEngine:
self.zmq_address = f"{self._hostname}:{self._port}" self.zmq_address = f"{self._hostname}:{self._port}"
# The `http_port` must be consistent with the port of OpenAI. # The `http_port` must be consistent with the port of OpenAI.
if self.instance_ip:
self.http_address = (
f"{self.config.kv_connector_extra_config['instance_ip']}:"
f"{self.config.kv_connector_extra_config['http_port']}")
else:
self.http_address = ( self.http_address = (
f"{self._hostname}:" f"{self._hostname}:"
f"{self.config.kv_connector_extra_config['http_port']}") f"{self.config.kv_connector_extra_config['http_port']}")
...@@ -148,16 +164,27 @@ class DuSwiftEngine: ...@@ -148,16 +164,27 @@ class DuSwiftEngine:
else: else:
self.proxy_address = proxy_ip + ":" + proxy_port self.proxy_address = proxy_ip + ":" + proxy_port
self.kv_cache_layer_num = 0
self.context = zmq.Context() self.context = zmq.Context()
self.router_socket = self.context.socket(zmq.ROUTER) self.router_socket = self.context.socket(zmq.ROUTER)
self.router_socket.setsockopt(zmq.RCVHWM, 10000)
self.router_socket.setsockopt(zmq.SNDHWM, 5000)
self.router_socket.setsockopt(zmq.LINGER, 0)
self.router_socket.setsockopt(zmq.ROUTER_MANDATORY, 1)
self.router_socket.setsockopt(zmq.TCP_KEEPALIVE, 1)
self.router_socket.bind(f"tcp://{self.zmq_address}") self.router_socket.bind(f"tcp://{self.zmq_address}")
self.poller = zmq.Poller() self.poller = zmq.Poller()
self.poller.register(self.router_socket, zmq.POLLIN) self.poller.register(self.router_socket, zmq.POLLIN)
self.req_status: dict[str, ReqKVDest] = {}
self.req_status_cv = threading.Condition()
self.send_store_cv = threading.Condition() self.send_store_cv = threading.Condition()
self.send_queue_cv = threading.Condition() self.send_queue_cv = threading.Condition()
self.recv_store_cv = threading.Condition() self.recv_store_cv = threading.Condition()
self.pending_queue_cv = threading.Condition()
self.send_stream = torch.cuda.Stream() self.send_stream = torch.cuda.Stream()
self.recv_stream = torch.cuda.Stream() self.recv_stream = torch.cuda.Stream()
...@@ -181,11 +208,16 @@ class DuSwiftEngine: ...@@ -181,11 +208,16 @@ class DuSwiftEngine:
# PUT or PUT_ASYNC # PUT or PUT_ASYNC
# tensor_id: torch.Tensor # tensor_id: torch.Tensor
self.send_queue: deque[list[Any]] = deque() self.send_queue: deque[list[Any]] = deque()
self.pending_queue: dict[str, list[list[Any]]] = defaultdict(list)
self.requests_to_release: dict[str, bool] = {}
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {} self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
if self.send_type == "PUT_ASYNC": if self.send_type == "PUT_ASYNC":
self._send_thread = threading.Thread(target=self._send_async, self._send_thread = threading.Thread(target=self._send_async,
daemon=True) daemon=True)
self._send_thread.start() self._send_thread.start()
self._pending_check_thread = threading.Thread(target=self._pending_check,
daemon=True)
self._pending_check_thread.start()
# tensor_id: torch.Tensor/(addr, dtype, shape) # tensor_id: torch.Tensor/(addr, dtype, shape)
self.recv_store: dict[str, Any] = {} self.recv_store: dict[str, Any] = {}
...@@ -328,7 +360,7 @@ class DuSwiftEngine: ...@@ -328,7 +360,7 @@ class DuSwiftEngine:
self, self,
tensor_id: str, tensor_id: str,
tensor: torch.Tensor, tensor: torch.Tensor,
remote_address: typing.Optional[str] = None, remote_address: typing.Optional[RemoteAddr] = None,
tbo_evt = None, tbo_evt = None,
) -> bool: ) -> bool:
if remote_address is None: if remote_address is None:
...@@ -356,7 +388,7 @@ class DuSwiftEngine: ...@@ -356,7 +388,7 @@ class DuSwiftEngine:
logger.info( logger.info(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d," "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d", " buffer_size:%d, oldest_tenser_size:%d, rank:%d",
remote_address, tensor_id, tensor_size, remote_address.zmq_address, tensor_id, tensor_size,
self.buffer_size, oldest_tenser_size, self.rank) self.buffer_size, oldest_tenser_size, self.rank)
self.send_store[tensor_id] = tensor self.send_store[tensor_id] = tensor
...@@ -364,7 +396,7 @@ class DuSwiftEngine: ...@@ -364,7 +396,7 @@ class DuSwiftEngine:
logger.debug( logger.debug(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, " "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)", "shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
remote_address, tensor_id, tensor_size, tensor.shape, remote_address.zmq_address, tensor_id, tensor_size, tensor.shape,
self.rank, self.buffer_size, self.rank, self.buffer_size,
self.buffer_size / self.buffer_size_threshold * 100) self.buffer_size / self.buffer_size_threshold * 100)
...@@ -417,6 +449,55 @@ class DuSwiftEngine: ...@@ -417,6 +449,55 @@ class DuSwiftEngine:
return True return True
def pending_tensor(
self,
reuqest_id: str,
layer_name: str,
tensor: torch.Tensor,
tbo_evt = None,
) -> bool:
with self.pending_queue_cv:
self.pending_queue[reuqest_id].append([layer_name, tensor, tbo_evt])
self.pending_queue_cv.notify()
return True
def unpending_tensor(
self,
request_id: str,
req_data: ReqKVDest,
) -> bool:
with self.pending_queue_cv:
tensor_list = self.pending_queue.pop(request_id)
if request_id in self.requests_to_release:
self.requests_to_release[request_id] = True
logger.info("[%d] unpending request: %s", self.rank, request_id)
if req_data.dst_num <= 0:
return False
for layer_name, tensor, tbo_evt in tensor_list:
for i in range(req_data.dst_num) :
remote_addr = RemoteAddr(req_data.pd_pair_id, *(req_data.zmq_address_and_comm_rank[i]))
if (envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC) and tbo_evt is not None:
self.p2p_async_send_tensor(request_id + "#" + layer_name,
tensor, remote_addr, tbo_evt)
else :
self.send_tensor(request_id + "#" + layer_name,
tensor, remote_addr)
return True
def _pending_check(self) :
while True:
with self.pending_queue_cv:
while not self.pending_queue:
self.pending_queue_cv.wait()
pending_queue = self.pending_queue.copy()
for request_id in pending_queue:
with self.req_status_cv:
if request_id not in self.req_status:
continue
req_data = self.req_status[request_id]
assert(len(req_data.zmq_address_and_comm_rank) == req_data.dst_num)
self.unpending_tensor(request_id, req_data)
def recv_tensor( def recv_tensor(
self, self,
tensor_id: str, tensor_id: str,
...@@ -475,22 +556,12 @@ class DuSwiftEngine: ...@@ -475,22 +556,12 @@ class DuSwiftEngine:
def _listen_for_requests(self): def _listen_for_requests(self):
while True: while True:
socks = dict(self.poller.poll()) socks = dict(self.poller.poll(5000))
if self.router_socket in socks: if self.router_socket in socks:
remote_address, message = self.router_socket.recv_multipart() remote_address, message = self.router_socket.recv_multipart()
data = msgpack.loads(message) data = msgpack.loads(message)
if data["cmd"] == "NEW": if data["cmd"] == "NEW":
unique_id = self.nccl.unique_id_from_bytes( logger.info(f"unexpected message from {remote_address.decode()}")
bytes(data["unique_id"]))
with torch.cuda.device(self.device):
rank = 1
with set_du_swift_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(
2, unique_id, rank)
self.comms[remote_address.decode()] = (comm, rank)
logger.info(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, remote_address.decode(), rank)
elif data["cmd"] == "PUT": elif data["cmd"] == "PUT":
tensor_id = data["tensor_id"] tensor_id = data["tensor_id"]
if "tensor_split_num" in data: if "tensor_split_num" in data:
...@@ -577,6 +648,15 @@ class DuSwiftEngine: ...@@ -577,6 +648,15 @@ class DuSwiftEngine:
logger.info( logger.info(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s", "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, data["pd_pair_id"], rank) self.zmq_address, data["pd_pair_id"], rank)
elif data["cmd"] == "req_to_transfer":
with self.req_status_cv:
assert(data["request_id"] not in self.req_status)
self.req_status[data["request_id"]] = ReqKVDest(dst_num=int(data["dst_num"]), pd_pair_id=data["pd_pair_id"], zmq_address_and_comm_rank=list(zip(data["remote_address"], data["remote_rank"])))
self.req_status_cv.notify_all()
elif data["cmd"] == "req_not_to_transfer":
with self.req_status_cv:
self.req_status[data["request_id"]] = ReqKVDest(dst_num=0)
self.req_status_cv.notify_all()
elif data["cmd"] == "GET": elif data["cmd"] == "GET":
tensor_id = data["tensor_id"] tensor_id = data["tensor_id"]
with self.send_store_cv: with self.send_store_cv:
...@@ -814,20 +894,31 @@ class DuSwiftEngine: ...@@ -814,20 +894,31 @@ class DuSwiftEngine:
""" """
# Clear the buffer upon request completion. # Clear the buffer upon request completion.
requests_to_release : list[str] = []
with self.pending_queue_cv:
for request_id, release in self.requests_to_release.items():
if release :
requests_to_release.append(request_id)
self.requests_to_release.pop(request_id)
for request_id in finished_req_ids: for request_id in finished_req_ids:
for layer_name in forward_context.no_compile_layers: with self.pending_queue_cv:
tensor_id = request_id + "#" + layer_name if request_id in self.pending_queue:
if tensor_id in self.recv_store: self.requests_to_release[request_id] = False
logger.info("[%d] pending request: %s", self.rank, request_id)
continue
requests_to_release.append(request_id)
for request_id in requests_to_release:
ids = self.recv_request_id_to_tensor_ids.pop(request_id, set())
with self.recv_store_cv: with self.recv_store_cv:
for tensor_id in ids:
tensor = self.recv_store.pop(tensor_id, None) tensor = self.recv_store.pop(tensor_id, None)
self.send_request_id_to_tensor_ids.pop(
request_id, None)
self.recv_request_id_to_tensor_ids.pop(
request_id, None)
addr = 0
if isinstance(tensor, tuple): if isinstance(tensor, tuple):
addr, _, _ = tensor addr, _, _ = tensor
self.pool.free(addr) self.pool.free(addr)
self.send_request_id_to_tensor_ids.pop(request_id, None)
# TODO:Retrieve requests that have already sent the KV cache. # TODO:Retrieve requests that have already sent the KV cache.
finished_sending: set[str] = set() finished_sending: set[str] = set()
...@@ -835,6 +926,19 @@ class DuSwiftEngine: ...@@ -835,6 +926,19 @@ class DuSwiftEngine:
# TODO:Retrieve requests that have already received the KV cache. # TODO:Retrieve requests that have already received the KV cache.
finished_recving: set[str] = set() finished_recving: set[str] = set()
if self.kv_cache_layer_num == 0 :
for layer_name in forward_context.no_compile_layers:
layer = forward_context.no_compile_layers[layer_name]
kv_cache = getattr(layer, 'kv_cache', None)
if kv_cache is None:
continue
self.kv_cache_layer_num += 1
with self.recv_store_cv:
for req in self.recv_request_id_to_tensor_ids:
if len(self.recv_request_id_to_tensor_ids[req]) == self.kv_cache_layer_num:
finished_recving.add(req)
return finished_sending or None, finished_recving or None return finished_sending or None, finished_recving or None
def _ping(self): def _ping(self):
...@@ -911,6 +1015,7 @@ class DuSwiftEngine: ...@@ -911,6 +1015,7 @@ class DuSwiftEngine:
self._listener_thread.join() self._listener_thread.join()
if self.send_type == "PUT_ASYNC": if self.send_type == "PUT_ASYNC":
self._send_thread.join() self._send_thread.join()
self._pending_check_thread.join()
if self._ping_thread is not None: if self._ping_thread is not None:
self._ping_thread.join() self._ping_thread.join()
......
...@@ -86,7 +86,9 @@ class Scheduler(SchedulerInterface): ...@@ -86,7 +86,9 @@ class Scheduler(SchedulerInterface):
"Multiple KV cache groups are not currently supported " "Multiple KV cache groups are not currently supported "
"with KV connectors") "with KV connectors")
self.connector = KVConnectorFactory.create_connector_v1( self.connector = KVConnectorFactory.create_connector_v1(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER) config=self.vllm_config, role=KVConnectorRole.SCHEDULER,
dp_rank=self.parallel_config.data_parallel_rank)
self.kv_event_publisher = EventPublisherFactory.create( self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config, self.kv_events_config,
...@@ -380,6 +382,10 @@ class Scheduler(SchedulerInterface): ...@@ -380,6 +382,10 @@ class Scheduler(SchedulerInterface):
if request.is_finished(): if request.is_finished():
self.waiting.pop_request() self.waiting.pop_request()
continue continue
if self.connector and not self.connector.is_producer and request.request_id not in self.finished_recving_kv_req_ids :
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# KVTransfer: skip request if still waiting for remote kvs. # KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request) is_ready = self._update_waiting_for_remote_kv(request)
...@@ -673,7 +679,10 @@ class Scheduler(SchedulerInterface): ...@@ -673,7 +679,10 @@ class Scheduler(SchedulerInterface):
+ len(scheduled_running_reqs) >= max_batch_running): + len(scheduled_running_reqs) >= max_batch_running):
break break
request = self.waiting.peek_request() request = self.waiting.peek_request()
if self.connector and not self.connector.is_producer and request.request_id not in self.finished_recving_kv_req_ids :
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# KVTransfer: skip request if still waiting for remote kvs. # KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request) is_ready = self._update_waiting_for_remote_kv(request)
...@@ -1326,7 +1335,7 @@ class Scheduler(SchedulerInterface): ...@@ -1326,7 +1335,7 @@ class Scheduler(SchedulerInterface):
request.num_nans_in_logits = num_nans_in_logits[req_id] request.num_nans_in_logits = num_nans_in_logits[req_id]
# Add newly generated spec token ids to the request. # Add newly generated spec token ids to the request.
if spec_token_ids is not None: if spec_token_ids is not None and (self.connector is None or not self.connector.is_producer):
if self.structured_output_manager.should_advance(request): if self.structured_output_manager.should_advance(request):
metadata = request.structured_output_request metadata = request.structured_output_request
# Needs to happen after new_token_ids are accepted. # Needs to happen after new_token_ids are accepted.
......
...@@ -763,6 +763,10 @@ class EngineCoreProc(EngineCore): ...@@ -763,6 +763,10 @@ class EngineCoreProc(EngineCore):
# Push to input queue for core busy loop. # Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request)) self.input_queue.put_nowait((request_type, request))
if isinstance(request, EngineCoreRequest) and self.scheduler.connector is not None:
if request_type == EngineCoreRequestType.ADD:
self.scheduler.connector.register_req(request.request_id)
def process_output_sockets(self, output_paths: list[str], def process_output_sockets(self, output_paths: list[str],
coord_output_path: Optional[str], coord_output_path: Optional[str],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment