Unverified Commit d06a83fb authored by Yuan Luo's avatar Yuan Luo Committed by GitHub
Browse files

Support dynamic connection and TP 16 (#5351)


Co-authored-by: default avatarluoyuan.luo <luoyuan.luo@antgroup.com>
parent 5d134401
......@@ -2,15 +2,18 @@ from __future__ import annotations
import asyncio
import dataclasses
import json
import logging
import queue
import random
import struct
import threading
from functools import cache
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import numpy.typing as npt
import requests
import zmq
from aiohttp import web
......@@ -24,9 +27,21 @@ from sglang.srt.disaggregation.base.conn import (
)
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.utils import is_port_available
logger = logging.getLogger(__name__)
def find_available_ports(base_port: int, count: int) -> List[int]:
"""Find consecutive available ports starting from base_port."""
available_ports = []
current_port = base_port
while len(available_ports) < count:
if is_port_available(current_port):
available_ports.append(current_port)
current_port += random.randint(100, 1000)
return available_ports
def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
......@@ -65,9 +80,10 @@ class TransferKVChunk:
@dataclasses.dataclass
class TransferInfo:
room: int
endpoint: str
decode_port: int
mooncake_session_id: str
room: int
dst_kv_ptrs: list[int]
dst_kv_indices: npt.NDArray[np.int64]
dst_aux_ptrs: list[int]
......@@ -77,25 +93,24 @@ class TransferInfo:
def from_zmq(cls, msg: List[bytes]):
return cls(
endpoint=msg[0].decode("ascii"),
mooncake_session_id=msg[1].decode("ascii"),
room=int(msg[2].decode("ascii")),
dst_kv_ptrs=list(struct.unpack(f"{len(msg[3])//8}Q", msg[3])),
dst_kv_indices=np.frombuffer(msg[4], dtype=np.int64),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
dst_aux_index=int(msg[6].decode("ascii")),
decode_port=int(msg[1].decode("ascii")),
mooncake_session_id=msg[2].decode("ascii"),
room=int(msg[3].decode("ascii")),
dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
dst_aux_index=int(msg[7].decode("ascii")),
)
KVSENDER_POLLING_PORT = 17788
KVRECEIVER_POLLING_PORT = 27788
class MooncakeKVManager(BaseKVManager):
def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode):
self.engine = MooncakeTransferEngine()
self.kv_args = args
self.disaggregation_mode = disaggregation_mode
self.request_status: Dict[int, KVPoll] = {}
self.connection_pool: Dict[int, Dict[str, Union[str, int]]] = {}
self.rank_port = None
self.server_socket = zmq.Context().socket(zmq.PULL)
self.register_buffer_to_engine()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
......@@ -202,15 +217,10 @@ class MooncakeKVManager(BaseKVManager):
)
return status
def sync_status_to_decode_endpoint(self, remote: str, room: int):
def sync_status_to_decode_endpoint(self, remote: str, dst_port: int, room: int):
if ":" in remote:
remote = remote.split(":")[0]
self._connect(
"tcp://"
+ remote
+ ":"
+ str(KVRECEIVER_POLLING_PORT + self.kv_args.engine_rank)
).send_multipart(
self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart(
[
str(room).encode("ascii"),
str(self.request_status[room]).encode("ascii"),
......@@ -218,15 +228,16 @@ class MooncakeKVManager(BaseKVManager):
)
def start_prefill_thread(self):
sender_rank_port = KVSENDER_POLLING_PORT + self.kv_args.engine_rank
self.server_socket.bind("tcp://*:" + str(sender_rank_port))
# Find available port for prefill tp
self.rank_port = find_available_ports(20000, 1)[0]
self.server_socket.bind("tcp://*:" + str(self.rank_port))
def bootstrap_thread():
"""This thread recvs pre-alloc notification from the decode engine"""
# KVPoll.Bootstrapping -> KVPoll.WaitingForInput
while True:
waiting_req_bytes = self.server_socket.recv_multipart()
room = waiting_req_bytes[2].decode("ascii")
room = waiting_req_bytes[3].decode("ascii")
if room == "None":
continue
room = int(room)
......@@ -254,7 +265,7 @@ class MooncakeKVManager(BaseKVManager):
)
if ret != 0:
self.request_status[kv_chunk.room] = KVPoll.Failed
self.sync_status_to_decode_endpoint(req.endpoint, req.room)
self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, req.room)
continue
if kv_chunk.is_last:
......@@ -268,7 +279,7 @@ class MooncakeKVManager(BaseKVManager):
self.request_status[req.room] = (
KVPoll.Success if ret == 0 else KVPoll.Failed
)
self.sync_status_to_decode_endpoint(req.endpoint, req.room)
self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, req.room)
self.transfer_infos.pop(req.room)
except queue.Empty:
......@@ -278,8 +289,8 @@ class MooncakeKVManager(BaseKVManager):
threading.Thread(target=transfer_thread).start()
def start_decode_thread(self):
receiver_rank_port = KVRECEIVER_POLLING_PORT + self.kv_args.engine_rank
self.server_socket.bind("tcp://*:" + str(receiver_rank_port))
self.rank_port = find_available_ports(25000, 1)[0]
self.server_socket.bind("tcp://*:" + str(self.rank_port))
def decode_thread():
while True:
......@@ -342,6 +353,38 @@ class MooncakeKVSender(BaseKVSender):
self.bootstrap_room = bootstrap_room
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
self.aux_index = None
self.bootstrap_server_url = bootstrap_addr
self.session_id = self.kv_mgr.get_session_id()
# Register to bootstrap server
self._register_to_bootstrap()
def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST."""
url = f"http://{self.bootstrap_server_url}/kv_route"
payload = {
"identity": self.session_id,
"role": "Prefill",
"serve_ip": self.kv_mgr.get_localhost(),
"serve_port": self.kv_mgr.rank_port,
"tp_rank": self.kv_mgr.kv_args.engine_rank,
}
logger.info(
f"Register prefill server port {self.kv_mgr.rank_port} for tp_rank {self.kv_mgr.kv_args.engine_rank}"
)
try:
response = requests.put(url, json=payload)
if response.status_code == 200:
logger.info(f"Prefill successfully registered to bootstrap server.")
else:
logger.info(
f"Prefill Failed to register to bootstrap server: {response.status_code}, {response.text}"
)
except Exception as e:
logger.info(f"Prefill Failed to register to bootstrap server: {e}")
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
self.num_kv_indices = num_kv_indices
......@@ -384,14 +427,28 @@ class MooncakeKVReceiver(BaseKVReceiver):
self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr
self.kv_mgr = mgr
self.prefill_server_url = (
bootstrap_addr.split(":")[0]
+ ":"
+ str(KVSENDER_POLLING_PORT + self.kv_mgr.kv_args.engine_rank)
)
self.decode_ip = self.kv_mgr.get_localhost()
self.session_id = self.kv_mgr.get_session_id()
self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)
self.prefill_engine_rank = None
self.decode_port = self.kv_mgr.rank_port
self.dealer_socket = None
def _get_prefill_info_from_bootstrap(self, tp_rank: int):
"""Fetch the prefill server port corresponding to tp_rank from the bootstrap server."""
try:
url = f"http://{self.bootstrap_addr}/kv_route?tp_rank={tp_rank}"
response = requests.get(url)
if response.status_code == 200:
prefill_info = response.json()
return prefill_info
else:
logger.error(f"Failed to get prefill server info: {response.status_code}, {response.text}")
return None
except Exception as e:
logger.error(f"Error fetching prefill info from bootstrap: {e}")
return None
@cache
def _connect(self, endpoint: str):
......@@ -400,6 +457,31 @@ class MooncakeKVReceiver(BaseKVReceiver):
return socket
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
prefill_info = None
logger.info(f"Decode bootstrap addr {self.bootstrap_addr}.")
if self.kv_mgr.kv_args.engine_rank not in self.kv_mgr.connection_pool:
prefill_info = self._get_prefill_info_from_bootstrap(
self.kv_mgr.kv_args.engine_rank
)
if prefill_info is None:
logger.error(
logger.error(f"Could not fetch prefill server info for tp_rank {self.kv_mgr.kv_args.engine_rank}")
)
else:
self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank] = prefill_info
else:
prefill_info = self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank]
if prefill_info:
self.prefill_server_url = f"{prefill_info['serve_ip']}:{prefill_info['serve_port']}"
logger.info(f"Fetched prefill server info: {prefill_info} for tp_rank {self.kv_mgr.kv_args.engine_rank}")
self.handshake_prefill_server(kv_indices, aux_index)
def handshake_prefill_server(
self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None
):
packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
)
......@@ -409,6 +491,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
self._connect("tcp://" + self.prefill_server_url).send_multipart(
[
self.decode_ip.encode("ascii"),
str(self.decode_port).encode("ascii"),
self.session_id.encode("ascii"),
str(self.bootstrap_room).encode("ascii"),
packed_kv_data_ptrs,
......@@ -432,6 +515,12 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
self.store = dict()
self.lock = asyncio.Lock()
self._setup_routes()
# prefill_engine_rank -> prefill_info
self.prefill_port_table: Dict[int, Dict[str, Union[str, int]]] = {}
self.context = zmq.Context()
self.prefill_engine_rank = None
# Start bootstrap server
self.thread = threading.Thread(target=self._run_server, daemon=True)
......@@ -442,21 +531,22 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
def _setup_routes(self):
self.app.router.add_route("*", "/metadata", self._handle_metadata)
self.app.router.add_route("*", "/kv_route", self._handle_kv_route)
async def _handle_metadata(self, request: web.Request):
key = request.query.get("key", "")
if request.method == "GET":
return await self._handle_get(key)
return await self._handle_metadata_get(key)
elif request.method == "PUT":
return await self._handle_put(key, request)
return await self._handle_metadata_put(key, request)
elif request.method == "DELETE":
return await self._handle_delete(key)
return await self._handle_metadata_delete(key)
return web.Response(
text="Method not allowed", status=405, content_type="application/json"
)
async def _handle_get(self, key):
async def _handle_metadata_get(self, key):
async with self.lock:
value = self.store.get(key)
if value is None:
......@@ -465,7 +555,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
)
return web.Response(body=value, status=200, content_type="application/json")
async def _handle_put(self, key, request):
async def _handle_metadata_put(self, key, request):
data = await request.read()
async with self.lock:
self.store[key] = data
......@@ -473,7 +563,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
text="metadata updated", status=200, content_type="application/json"
)
async def _handle_delete(self, key):
async def _handle_metadata_delete(self, key):
async with self.lock:
if key not in self.store:
return web.Response(
......@@ -486,6 +576,52 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
text="metadata deleted", status=200, content_type="application/json"
)
async def _handle_kv_route(self, request: web.Request):
method = request.method
if method == "PUT":
return await self._handle_kv_route_put(request)
elif method == "GET":
return await self._handle_kv_route_get(request)
else:
return web.Response(
text="Method not allowed", status=405, content_type="application/json"
)
async def _handle_kv_route_put(self, request: web.Request):
data = await request.json()
identity = data["identity"]
role = data["role"]
serve_ip = data["serve_ip"]
serve_port = int(data["serve_port"]) # Assuming serve_port is an integer
tp_rank = int(data["tp_rank"])
# Add lock to make sure thread-safe
if role == "Prefill":
async with self.lock:
self.prefill_port_table[tp_rank] = {"serve_ip": serve_ip, "serve_port": serve_port}
logger.info(f"Registered Prefill tp_rank: {tp_rank} with serve_ip: {serve_ip} and serve_port: {serve_port}")
return web.Response(text="OK", status=200)
async def _handle_kv_route_get(self, request: web.Request):
tp_rank = request.query.get("tp_rank")
if not tp_rank:
return web.Response(text="Missing tp_rank", status=400)
try:
tp_rank = int(tp_rank)
except ValueError:
return web.Response(text="tp_rank must be int", status=400)
# Find corresponding prefill info
async with self.lock:
prefill_info = self.prefill_port_table.get(tp_rank)
if prefill_info is not None:
return web.json_response(prefill_info, status=200)
else:
return web.Response(text="Not Found", status=404)
def _run_server(self):
try:
# Event Loop
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment