Unverified commit d06a83fb, authored by Yuan Luo, committed by GitHub
Browse files

Support dynamic connection and TP 16 (#5351)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
parent 5d134401
...@@ -2,15 +2,18 @@ from __future__ import annotations ...@@ -2,15 +2,18 @@ from __future__ import annotations
import asyncio import asyncio
import dataclasses import dataclasses
import json
import logging import logging
import queue import queue
import random
import struct import struct
import threading import threading
from functools import cache from functools import cache
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
import requests
import zmq import zmq
from aiohttp import web from aiohttp import web
...@@ -24,9 +27,21 @@ from sglang.srt.disaggregation.base.conn import ( ...@@ -24,9 +27,21 @@ from sglang.srt.disaggregation.base.conn import (
) )
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.utils import is_port_available
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def find_available_ports(base_port: int, count: int) -> List[int]:
    """Find ``count`` available ports, probing upward from ``base_port``.

    The returned ports are NOT consecutive: after every probe the candidate
    advances by a random stride of 100-1000, whether or not the probed port
    was free.
    NOTE(review): the random stride presumably keeps several ranks that start
    from the same base port from probing the same candidates -- confirm.

    Availability is only checked here; the caller binds later, so another
    process can still take the port in between.

    Args:
        base_port: First port to probe.
        count: Number of available ports to collect; ``count <= 0`` yields
            an empty list.

    Returns:
        The available ports found, in increasing order.

    Raises:
        RuntimeError: If the search runs past the highest valid port (65535)
            before ``count`` ports were found. Without this guard the loop
            would never terminate once the candidate left the valid range.
    """
    max_port = 65535
    available_ports: List[int] = []
    current_port = base_port
    while len(available_ports) < count:
        if current_port > max_port:
            raise RuntimeError(
                f"Could not find {count} available port(s) starting from "
                f"{base_port}: search exceeded port {max_port}"
            )
        if is_port_available(current_port):
            available_ports.append(current_port)
        current_port += random.randint(100, 1000)
    return available_ports
def group_concurrent_contiguous( def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64] src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
...@@ -65,9 +80,10 @@ class TransferKVChunk: ...@@ -65,9 +80,10 @@ class TransferKVChunk:
@dataclasses.dataclass @dataclasses.dataclass
class TransferInfo: class TransferInfo:
room: int
endpoint: str endpoint: str
decode_port: int
mooncake_session_id: str mooncake_session_id: str
room: int
dst_kv_ptrs: list[int] dst_kv_ptrs: list[int]
dst_kv_indices: npt.NDArray[np.int64] dst_kv_indices: npt.NDArray[np.int64]
dst_aux_ptrs: list[int] dst_aux_ptrs: list[int]
...@@ -77,25 +93,24 @@ class TransferInfo: ...@@ -77,25 +93,24 @@ class TransferInfo:
def from_zmq(cls, msg: List[bytes]): def from_zmq(cls, msg: List[bytes]):
return cls( return cls(
endpoint=msg[0].decode("ascii"), endpoint=msg[0].decode("ascii"),
mooncake_session_id=msg[1].decode("ascii"), decode_port=int(msg[1].decode("ascii")),
room=int(msg[2].decode("ascii")), mooncake_session_id=msg[2].decode("ascii"),
dst_kv_ptrs=list(struct.unpack(f"{len(msg[3])//8}Q", msg[3])), room=int(msg[3].decode("ascii")),
dst_kv_indices=np.frombuffer(msg[4], dtype=np.int64), dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])), dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
dst_aux_index=int(msg[6].decode("ascii")), dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
dst_aux_index=int(msg[7].decode("ascii")),
) )
KVSENDER_POLLING_PORT = 17788
KVRECEIVER_POLLING_PORT = 27788
class MooncakeKVManager(BaseKVManager): class MooncakeKVManager(BaseKVManager):
def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode): def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode):
self.engine = MooncakeTransferEngine() self.engine = MooncakeTransferEngine()
self.kv_args = args self.kv_args = args
self.disaggregation_mode = disaggregation_mode self.disaggregation_mode = disaggregation_mode
self.request_status: Dict[int, KVPoll] = {} self.request_status: Dict[int, KVPoll] = {}
self.connection_pool: Dict[int, Dict[str, Union[str, int]]] = {}
self.rank_port = None
self.server_socket = zmq.Context().socket(zmq.PULL) self.server_socket = zmq.Context().socket(zmq.PULL)
self.register_buffer_to_engine() self.register_buffer_to_engine()
if self.disaggregation_mode == DisaggregationMode.PREFILL: if self.disaggregation_mode == DisaggregationMode.PREFILL:
...@@ -202,15 +217,10 @@ class MooncakeKVManager(BaseKVManager): ...@@ -202,15 +217,10 @@ class MooncakeKVManager(BaseKVManager):
) )
return status return status
def sync_status_to_decode_endpoint(self, remote: str, room: int): def sync_status_to_decode_endpoint(self, remote: str, dst_port: int, room: int):
if ":" in remote: if ":" in remote:
remote = remote.split(":")[0] remote = remote.split(":")[0]
self._connect( self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart(
"tcp://"
+ remote
+ ":"
+ str(KVRECEIVER_POLLING_PORT + self.kv_args.engine_rank)
).send_multipart(
[ [
str(room).encode("ascii"), str(room).encode("ascii"),
str(self.request_status[room]).encode("ascii"), str(self.request_status[room]).encode("ascii"),
...@@ -218,15 +228,16 @@ class MooncakeKVManager(BaseKVManager): ...@@ -218,15 +228,16 @@ class MooncakeKVManager(BaseKVManager):
) )
def start_prefill_thread(self): def start_prefill_thread(self):
sender_rank_port = KVSENDER_POLLING_PORT + self.kv_args.engine_rank # Find available port for prefill tp
self.server_socket.bind("tcp://*:" + str(sender_rank_port)) self.rank_port = find_available_ports(20000, 1)[0]
self.server_socket.bind("tcp://*:" + str(self.rank_port))
def bootstrap_thread(): def bootstrap_thread():
"""This thread recvs pre-alloc notification from the decode engine""" """This thread recvs pre-alloc notification from the decode engine"""
# KVPoll.Bootstrapping -> KVPoll.WaitingForInput # KVPoll.Bootstrapping -> KVPoll.WaitingForInput
while True: while True:
waiting_req_bytes = self.server_socket.recv_multipart() waiting_req_bytes = self.server_socket.recv_multipart()
room = waiting_req_bytes[2].decode("ascii") room = waiting_req_bytes[3].decode("ascii")
if room == "None": if room == "None":
continue continue
room = int(room) room = int(room)
...@@ -254,7 +265,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -254,7 +265,7 @@ class MooncakeKVManager(BaseKVManager):
) )
if ret != 0: if ret != 0:
self.request_status[kv_chunk.room] = KVPoll.Failed self.request_status[kv_chunk.room] = KVPoll.Failed
self.sync_status_to_decode_endpoint(req.endpoint, req.room) self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, req.room)
continue continue
if kv_chunk.is_last: if kv_chunk.is_last:
...@@ -268,7 +279,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -268,7 +279,7 @@ class MooncakeKVManager(BaseKVManager):
self.request_status[req.room] = ( self.request_status[req.room] = (
KVPoll.Success if ret == 0 else KVPoll.Failed KVPoll.Success if ret == 0 else KVPoll.Failed
) )
self.sync_status_to_decode_endpoint(req.endpoint, req.room) self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, req.room)
self.transfer_infos.pop(req.room) self.transfer_infos.pop(req.room)
except queue.Empty: except queue.Empty:
...@@ -278,8 +289,8 @@ class MooncakeKVManager(BaseKVManager): ...@@ -278,8 +289,8 @@ class MooncakeKVManager(BaseKVManager):
threading.Thread(target=transfer_thread).start() threading.Thread(target=transfer_thread).start()
def start_decode_thread(self): def start_decode_thread(self):
receiver_rank_port = KVRECEIVER_POLLING_PORT + self.kv_args.engine_rank self.rank_port = find_available_ports(25000, 1)[0]
self.server_socket.bind("tcp://*:" + str(receiver_rank_port)) self.server_socket.bind("tcp://*:" + str(self.rank_port))
def decode_thread(): def decode_thread():
while True: while True:
...@@ -342,6 +353,38 @@ class MooncakeKVSender(BaseKVSender): ...@@ -342,6 +353,38 @@ class MooncakeKVSender(BaseKVSender):
self.bootstrap_room = bootstrap_room self.bootstrap_room = bootstrap_room
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping) self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
self.aux_index = None self.aux_index = None
self.bootstrap_server_url = bootstrap_addr
self.session_id = self.kv_mgr.get_session_id()
# Register to bootstrap server
self._register_to_bootstrap()
def _register_to_bootstrap(self):
    """Register this prefill rank's endpoint with the bootstrap server.

    Sends an HTTP PUT to ``/kv_route`` on ``self.bootstrap_server_url``
    carrying the session id, the local IP, the port stored in
    ``kv_mgr.rank_port`` and the engine (TP) rank.

    Registration is best-effort: any failure is logged and swallowed, never
    raised to the caller.
    """
    url = f"http://{self.bootstrap_server_url}/kv_route"
    payload = {
        "identity": self.session_id,
        "role": "Prefill",
        "serve_ip": self.kv_mgr.get_localhost(),
        "serve_port": self.kv_mgr.rank_port,
        "tp_rank": self.kv_mgr.kv_args.engine_rank,
    }
    logger.info(
        f"Register prefill server port {self.kv_mgr.rank_port} for tp_rank {self.kv_mgr.kv_args.engine_rank}"
    )

    # Bound the request so an unreachable bootstrap server cannot block the
    # sender's constructor forever (requests has no default timeout).
    request_timeout_s = 10
    try:
        response = requests.put(url, json=payload, timeout=request_timeout_s)
        if response.status_code == 200:
            logger.info("Prefill successfully registered to bootstrap server.")
        else:
            logger.error(
                f"Prefill Failed to register to bootstrap server: {response.status_code}, {response.text}"
            )
    except Exception as e:
        # Deliberately broad: registration must never take the sender down.
        logger.error(f"Prefill Failed to register to bootstrap server: {e}")
def init(self, num_kv_indices: int, aux_index: Optional[int] = None): def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
self.num_kv_indices = num_kv_indices self.num_kv_indices = num_kv_indices
...@@ -384,14 +427,28 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -384,14 +427,28 @@ class MooncakeKVReceiver(BaseKVReceiver):
self.bootstrap_room = bootstrap_room self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr self.bootstrap_addr = bootstrap_addr
self.kv_mgr = mgr self.kv_mgr = mgr
self.prefill_server_url = (
bootstrap_addr.split(":")[0]
+ ":"
+ str(KVSENDER_POLLING_PORT + self.kv_mgr.kv_args.engine_rank)
)
self.decode_ip = self.kv_mgr.get_localhost() self.decode_ip = self.kv_mgr.get_localhost()
self.session_id = self.kv_mgr.get_session_id() self.session_id = self.kv_mgr.get_session_id()
self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput) self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)
self.prefill_engine_rank = None
self.decode_port = self.kv_mgr.rank_port
self.dealer_socket = None
def _get_prefill_info_from_bootstrap(self, tp_rank: int):
    """Fetch the prefill endpoint registered for ``tp_rank``.

    Issues ``GET /kv_route?tp_rank=<tp_rank>`` against ``self.bootstrap_addr``.

    Args:
        tp_rank: Rank whose prefill endpoint should be looked up.

    Returns:
        The decoded JSON body of a 200 response, or ``None`` on any non-200
        status or any exception (both are logged at error level).
    """
    # Bound the request so an unreachable bootstrap server cannot stall the
    # decode side indefinitely (requests has no default timeout).
    request_timeout_s = 10
    try:
        url = f"http://{self.bootstrap_addr}/kv_route?tp_rank={tp_rank}"
        response = requests.get(url, timeout=request_timeout_s)
        if response.status_code == 200:
            return response.json()
        logger.error(
            f"Failed to get prefill server info: {response.status_code}, {response.text}"
        )
        return None
    except Exception as e:
        logger.error(f"Error fetching prefill info from bootstrap: {e}")
        return None
@cache @cache
def _connect(self, endpoint: str): def _connect(self, endpoint: str):
...@@ -400,6 +457,31 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -400,6 +457,31 @@ class MooncakeKVReceiver(BaseKVReceiver):
return socket return socket
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
    """Resolve the prefill endpoint for this rank and start the handshake.

    The endpoint is taken from ``kv_mgr.connection_pool`` (keyed by engine
    rank) when present; otherwise it is fetched from the bootstrap server
    and, on success, stored in the pool for later receivers.

    If no endpoint can be obtained the error is logged and the method
    returns without handshaking, because ``prefill_server_url`` would be
    unset.

    Args:
        kv_indices: Forwarded unchanged to ``handshake_prefill_server``.
        aux_index: Forwarded unchanged to ``handshake_prefill_server``.
    """
    tp_rank = self.kv_mgr.kv_args.engine_rank
    pool = self.kv_mgr.connection_pool
    logger.info(f"Decode bootstrap addr {self.bootstrap_addr}.")

    if tp_rank in pool:
        prefill_info = pool[tp_rank]
    else:
        prefill_info = self._get_prefill_info_from_bootstrap(tp_rank)
        if prefill_info is None:
            # Single call: the original nested logger.error(logger.error(...))
            # also emitted a second, meaningless "None" error record.
            logger.error(
                f"Could not fetch prefill server info for tp_rank {tp_rank}"
            )
        else:
            pool[tp_rank] = prefill_info

    if not prefill_info:
        # Nothing to connect to; skip the handshake instead of failing on a
        # missing prefill_server_url attribute.
        return

    self.prefill_server_url = f"{prefill_info['serve_ip']}:{prefill_info['serve_port']}"
    logger.info(f"Fetched prefill server info: {prefill_info} for tp_rank {tp_rank}")
    self.handshake_prefill_server(kv_indices, aux_index)
def handshake_prefill_server(
self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None
):
packed_kv_data_ptrs = b"".join( packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
) )
...@@ -409,6 +491,7 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -409,6 +491,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
self._connect("tcp://" + self.prefill_server_url).send_multipart( self._connect("tcp://" + self.prefill_server_url).send_multipart(
[ [
self.decode_ip.encode("ascii"), self.decode_ip.encode("ascii"),
str(self.decode_port).encode("ascii"),
self.session_id.encode("ascii"), self.session_id.encode("ascii"),
str(self.bootstrap_room).encode("ascii"), str(self.bootstrap_room).encode("ascii"),
packed_kv_data_ptrs, packed_kv_data_ptrs,
...@@ -432,6 +515,12 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): ...@@ -432,6 +515,12 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
self.store = dict() self.store = dict()
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
self._setup_routes() self._setup_routes()
# prefill_engine_rank -> prefill_info
self.prefill_port_table: Dict[int, Dict[str, Union[str, int]]] = {}
self.context = zmq.Context()
self.prefill_engine_rank = None
# Start bootstrap server # Start bootstrap server
self.thread = threading.Thread(target=self._run_server, daemon=True) self.thread = threading.Thread(target=self._run_server, daemon=True)
...@@ -442,21 +531,22 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): ...@@ -442,21 +531,22 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
def _setup_routes(self): def _setup_routes(self):
self.app.router.add_route("*", "/metadata", self._handle_metadata) self.app.router.add_route("*", "/metadata", self._handle_metadata)
self.app.router.add_route("*", "/kv_route", self._handle_kv_route)
async def _handle_metadata(self, request: web.Request): async def _handle_metadata(self, request: web.Request):
key = request.query.get("key", "") key = request.query.get("key", "")
if request.method == "GET": if request.method == "GET":
return await self._handle_get(key) return await self._handle_metadata_get(key)
elif request.method == "PUT": elif request.method == "PUT":
return await self._handle_put(key, request) return await self._handle_metadata_put(key, request)
elif request.method == "DELETE": elif request.method == "DELETE":
return await self._handle_delete(key) return await self._handle_metadata_delete(key)
return web.Response( return web.Response(
text="Method not allowed", status=405, content_type="application/json" text="Method not allowed", status=405, content_type="application/json"
) )
async def _handle_get(self, key): async def _handle_metadata_get(self, key):
async with self.lock: async with self.lock:
value = self.store.get(key) value = self.store.get(key)
if value is None: if value is None:
...@@ -465,7 +555,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): ...@@ -465,7 +555,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
) )
return web.Response(body=value, status=200, content_type="application/json") return web.Response(body=value, status=200, content_type="application/json")
async def _handle_put(self, key, request): async def _handle_metadata_put(self, key, request):
data = await request.read() data = await request.read()
async with self.lock: async with self.lock:
self.store[key] = data self.store[key] = data
...@@ -473,7 +563,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): ...@@ -473,7 +563,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
text="metadata updated", status=200, content_type="application/json" text="metadata updated", status=200, content_type="application/json"
) )
async def _handle_delete(self, key): async def _handle_metadata_delete(self, key):
async with self.lock: async with self.lock:
if key not in self.store: if key not in self.store:
return web.Response( return web.Response(
...@@ -486,6 +576,52 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): ...@@ -486,6 +576,52 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
text="metadata deleted", status=200, content_type="application/json" text="metadata deleted", status=200, content_type="application/json"
) )
async def _handle_kv_route(self, request: web.Request):
    """Route a ``/kv_route`` request to its per-method handler.

    PUT goes to ``_handle_kv_route_put`` and GET to ``_handle_kv_route_get``;
    every other HTTP method is answered with a 405 response.
    """
    handlers = {
        "PUT": self._handle_kv_route_put,
        "GET": self._handle_kv_route_get,
    }
    handler = handlers.get(request.method)
    if handler is None:
        return web.Response(
            text="Method not allowed", status=405, content_type="application/json"
        )
    return await handler(request)
async def _handle_kv_route_put(self, request: web.Request):
    """Record a registration sent to ``/kv_route``.

    The JSON body must provide ``role``, ``serve_ip``, ``serve_port`` and
    ``tp_rank``. For role ``"Prefill"`` the ``serve_ip``/``serve_port`` pair
    is stored in ``self.prefill_port_table`` under ``tp_rank``, replacing any
    earlier entry for that rank. Other roles are accepted but not stored.
    An ``identity`` field may be sent but is not used by this handler.

    Returns:
        200 ``OK`` on success; 400 when the body is not valid JSON, a
        required field is missing, or ``serve_port``/``tp_rank`` is not an
        integer.
    """
    try:
        data = await request.json()
        role = data["role"]
        serve_ip = data["serve_ip"]
        serve_port = int(data["serve_port"])
        tp_rank = int(data["tp_rank"])
    except (KeyError, TypeError, ValueError) as e:
        # Malformed JSON raises a ValueError subclass; a missing key raises
        # KeyError; a non-mapping body or None field raises TypeError.
        # Answer 400 rather than letting these surface as HTTP 500.
        return web.Response(text=f"Invalid kv_route registration: {e!r}", status=400)

    if role == "Prefill":
        # Hold the lock so the table update is serialized with readers.
        async with self.lock:
            self.prefill_port_table[tp_rank] = {"serve_ip": serve_ip, "serve_port": serve_port}
        logger.info(f"Registered Prefill tp_rank: {tp_rank} with serve_ip: {serve_ip} and serve_port: {serve_port}")

    return web.Response(text="OK", status=200)
async def _handle_kv_route_get(self, request: web.Request):
    """Look up the prefill endpoint registered for the ``tp_rank`` query arg.

    Returns:
        200 with the stored entry as JSON; 400 when ``tp_rank`` is missing,
        empty, or not an integer; 404 when no entry exists for that rank.
    """
    raw_rank = request.query.get("tp_rank")
    if not raw_rank:
        return web.Response(text="Missing tp_rank", status=400)

    try:
        rank = int(raw_rank)
    except ValueError:
        return web.Response(text="tp_rank must be int", status=400)

    # Read the table under the same lock the PUT handler writes with.
    async with self.lock:
        entry = self.prefill_port_table.get(rank)

    if entry is None:
        return web.Response(text="Not Found", status=404)
    return web.json_response(entry, status=200)
def _run_server(self): def _run_server(self):
try: try:
# Event Loop # Event Loop
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment