Unverified Commit ffde65a0 authored by shangmingc's avatar shangmingc Committed by GitHub
Browse files

[PD] Fix dynamic port support and MLA buffer for Mooncake (#5415)


Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Co-authored-by: ybyang <ybyang7@iflytek.com>
parent 471650de
......@@ -5,6 +5,7 @@ import numpy as np
import numpy.typing as npt
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs
class KVArgs:
......@@ -16,6 +17,7 @@ class KVArgs:
aux_data_lens: list[int]
aux_item_lens: list[int]
ib_device: str
gpu_id: int
class KVPoll:
......@@ -30,7 +32,12 @@ class BaseKVManager(ABC):
"""Base class for managing transfers states"""
@abstractmethod
def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode): ...
def __init__(
self,
args: KVArgs,
disaggregation_mode: DisaggregationMode,
server_args: ServerArgs,
): ...
class BaseKVSender(ABC):
......
......@@ -128,8 +128,11 @@ class DecodePreallocQueue:
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
]
kv_args.ib_device = "mock-ib-device"
kv_args.gpu_id = self.scheduler.gpu_id
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
kv_manager = kv_manager_class(kv_args, DisaggregationMode.DECODE)
kv_manager = kv_manager_class(
kv_args, DisaggregationMode.DECODE, self.scheduler.server_args
)
return kv_manager
def add(self, req: Req) -> None:
......
......@@ -2,10 +2,9 @@ from __future__ import annotations
import asyncio
import dataclasses
import json
import logging
import queue
import random
import socket
import struct
import threading
from functools import cache
......@@ -27,24 +26,12 @@ from sglang.srt.disaggregation.base.conn import (
)
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.utils import is_port_available
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
logger = logging.getLogger(__name__)
def find_available_ports(base_port: int, count: int) -> List[int]:
    """Collect ``count`` available ports, probing upward from ``base_port``.

    The returned ports are NOT consecutive: after every probe (successful or
    not) the candidate port advances by a random stride between 100 and 1000.

    Args:
        base_port: First port number to probe.
        count: Number of available ports to collect. A value <= 0 yields ``[]``.

    Returns:
        The ports for which ``is_port_available`` returned a truthy value, in
        ascending order.

    Raises:
        RuntimeError: If the search walks past the highest valid port (65535)
            before ``count`` ports were found.
    """
    max_port = 65535
    available_ports: List[int] = []
    current_port = base_port

    while len(available_ports) < count:
        # Without this bound the loop would keep probing invalid port numbers.
        if current_port > max_port:
            raise RuntimeError(
                f"Could not find {count} available ports starting from {base_port}"
            )
        if is_port_available(current_port):
            available_ports.append(current_port)
        # Advance unconditionally so an unavailable port is never re-probed.
        current_port += random.randint(100, 1000)

    return available_ports
def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
......@@ -82,10 +69,10 @@ class TransferKVChunk:
@dataclasses.dataclass
class TransferInfo:
room: int
endpoint: str
decode_port: int
dst_port: int
mooncake_session_id: str
room: int
dst_kv_ptrs: list[int]
dst_kv_indices: npt.NDArray[np.int64]
dst_aux_ptrs: list[int]
......@@ -94,10 +81,10 @@ class TransferInfo:
@classmethod
def from_zmq(cls, msg: List[bytes]):
return cls(
endpoint=msg[0].decode("ascii"),
decode_port=int(msg[1].decode("ascii")),
mooncake_session_id=msg[2].decode("ascii"),
room=int(msg[3].decode("ascii")),
room=int(msg[0].decode("ascii")),
endpoint=msg[1].decode("ascii"),
dst_port=int(msg[2].decode("ascii")),
mooncake_session_id=msg[3].decode("ascii"),
dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
......@@ -106,12 +93,20 @@ class TransferInfo:
class MooncakeKVManager(BaseKVManager):
def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode):
def __init__(
self,
args: KVArgs,
disaggregation_mode: DisaggregationMode,
server_args: ServerArgs,
):
self.engine = MooncakeTransferEngine()
self.kv_args = args
self.disaggregation_mode = disaggregation_mode
# for p/d multi node infer
self.bootstrap_port = server_args.disaggregation_bootstrap_port
self.dist_init_addr = server_args.dist_init_addr
self.request_status: Dict[int, KVPoll] = {}
self.connection_pool: Dict[int, Dict[str, Union[str, int]]] = {}
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
self.rank_port = None
self.server_socket = zmq.Context().socket(zmq.PULL)
self.register_buffer_to_engine()
......@@ -119,6 +114,7 @@ class MooncakeKVManager(BaseKVManager):
self.transfer_queue = queue.Queue()
self.transfer_infos: Dict[int, TransferInfo] = {}
self.start_prefill_thread()
self._register_to_bootstrap()
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.start_decode_thread()
else:
......@@ -150,54 +146,29 @@ class MooncakeKVManager(BaseKVManager):
dst_kv_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int64],
):
layer_num = int(len(self.kv_args.kv_data_ptrs) / 2)
# group by indices
prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
prefill_kv_indices, dst_kv_indices
)
for layer_id in range(layer_num):
prefill_key_layer_ptr = self.kv_args.kv_data_ptrs[layer_id]
key_item_len = self.kv_args.kv_item_lens[layer_id]
prefill_value_layer_ptr = self.kv_args.kv_data_ptrs[layer_num + layer_id]
value_item_len = self.kv_args.kv_item_lens[layer_num + layer_id]
decode_key_layer_ptr = dst_kv_ptrs[layer_id]
decode_value_layer_ptr = dst_kv_ptrs[layer_num + layer_id]
num_layers = len(self.kv_args.kv_data_ptrs)
for layer_id in range(num_layers):
src_ptr = self.kv_args.kv_data_ptrs[layer_id]
dst_ptr = dst_kv_ptrs[layer_id]
item_len = self.kv_args.kv_item_lens[layer_id]
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
prefill_key_addr = (
prefill_key_layer_ptr + int(prefill_index[0]) * key_item_len
)
decode_key_addr = (
decode_key_layer_ptr + int(decode_index[0]) * key_item_len
)
src_addr = src_ptr + int(prefill_index[0]) * item_len
dst_addr = dst_ptr + int(decode_index[0]) * item_len
length = item_len * len(prefill_index)
# TODO: mooncake transfer engine can do async transfer. Do async later
# TODO: make async later
status = self.engine.transfer_sync(
mooncake_session_id,
prefill_key_addr,
decode_key_addr,
key_item_len * len(prefill_index),
mooncake_session_id, src_addr, dst_addr, length
)
if status != 0:
return status
prefill_value_addr = (
prefill_value_layer_ptr + int(prefill_index[0]) * value_item_len
)
decode_value_addr = (
decode_value_layer_ptr + int(decode_index[0]) * value_item_len
)
# TODO: mooncake transfer engine can do async transfer. Do async later
status = self.engine.transfer_sync(
mooncake_session_id,
prefill_value_addr,
decode_value_addr,
value_item_len * len(prefill_index),
)
if status != 0:
return status
return 0
def send_aux(
......@@ -230,16 +201,15 @@ class MooncakeKVManager(BaseKVManager):
)
def start_prefill_thread(self):
# Find available port for prefill tp
self.rank_port = find_available_ports(20000, 1)[0]
self.server_socket.bind("tcp://*:" + str(self.rank_port))
self.rank_port = get_free_port()
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
def bootstrap_thread():
"""This thread recvs pre-alloc notification from the decode engine"""
# KVPoll.Bootstrapping -> KVPoll.WaitingForInput
while True:
waiting_req_bytes = self.server_socket.recv_multipart()
room = waiting_req_bytes[3].decode("ascii")
room = waiting_req_bytes[0].decode("ascii")
if room == "None":
continue
room = int(room)
......@@ -295,8 +265,8 @@ class MooncakeKVManager(BaseKVManager):
threading.Thread(target=transfer_thread).start()
def start_decode_thread(self):
self.rank_port = find_available_ports(25000, 1)[0]
self.server_socket.bind("tcp://*:" + str(self.rank_port))
self.rank_port = get_free_port()
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
def decode_thread():
while True:
......@@ -343,54 +313,48 @@ class MooncakeKVManager(BaseKVManager):
self.request_status[bootstrap_room], status
)
def get_localhost(self):
    """Return whatever ``self.engine.get_localhost()`` reports (pure delegation)."""
    return self.engine.get_localhost()
def get_session_id(self):
    """Return whatever ``self.engine.get_session_id()`` reports (pure delegation)."""
    return self.engine.get_session_id()
class MooncakeKVSender(BaseKVSender):
def __init__(
    self, mgr: MooncakeKVManager, bootstrap_addr: str, bootstrap_room: int
):
    """Bind this sender to a KV manager and a single bootstrap room.

    Args:
        mgr: Manager used for status updates and to obtain the session id.
        bootstrap_addr: Stored as ``bootstrap_server_url``.
        bootstrap_room: Room identifier this sender is responsible for.
    """
    self.kv_mgr = mgr
    self.bootstrap_room = bootstrap_room
    # Mark the room as Bootstrapping on the manager as soon as the sender exists.
    self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
    # NOTE(review): presumably filled in later by init() — confirm.
    self.aux_index = None
    self.bootstrap_server_url = bootstrap_addr
    # Read once from the manager and cached on the sender.
    self.session_id = self.kv_mgr.get_session_id()
    # Register to bootstrap server
    self._register_to_bootstrap()
def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST."""
url = f"http://{self.bootstrap_server_url}/kv_route"
if self.dist_init_addr:
ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
else:
ip_address = get_ip()
bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
url = f"http://{bootstrap_server_url}/route"
payload = {
"identity": self.session_id,
"role": "Prefill",
"serve_ip": self.kv_mgr.get_localhost(),
"serve_port": self.kv_mgr.rank_port,
"tp_rank": self.kv_mgr.kv_args.engine_rank,
"rank_ip": get_local_ip_by_remote(),
"rank_port": self.rank_port,
"bootstrap_key": f"{bootstrap_server_url}_{self.kv_args.engine_rank}",
}
logger.info(
f"Register prefill server port {self.kv_mgr.rank_port} for tp_rank {self.kv_mgr.kv_args.engine_rank}"
)
try:
response = requests.put(url, json=payload)
if response.status_code == 200:
logger.info(f"Prefill successfully registered to bootstrap server.")
logger.debug("Prefill successfully registered to bootstrap server.")
else:
logger.info(
f"Prefill Failed to register to bootstrap server: {response.status_code}, {response.text}"
logger.error(
f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
)
except Exception as e:
logger.info(f"Prefill Failed to register to bootstrap server: {e}")
logger.error(f"Prefill Failed to register to bootstrap server: {e}")
class MooncakeKVSender(BaseKVSender):
def __init__(
    self, mgr: MooncakeKVManager, bootstrap_addr: str, bootstrap_room: int
):
    """Bind this sender to a KV manager and a single bootstrap room.

    Args:
        mgr: Manager used for status updates and to obtain the session id.
        bootstrap_addr: Stored as ``bootstrap_server_url``.
        bootstrap_room: Room identifier this sender is responsible for.
    """
    self.kv_mgr = mgr
    self.bootstrap_room = bootstrap_room
    # Mark the room as Bootstrapping on the manager as soon as the sender exists.
    self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
    # NOTE(review): presumably filled in later by init() — confirm.
    self.aux_index = None
    self.bootstrap_server_url = bootstrap_addr
    # Read once from the manager and cached on the sender.
    self.session_id = self.kv_mgr.get_session_id()
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
self.num_kv_indices = num_kv_indices
......@@ -433,21 +397,35 @@ class MooncakeKVReceiver(BaseKVReceiver):
self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr
self.kv_mgr = mgr
self.decode_ip = self.kv_mgr.get_localhost()
self.session_id = self.kv_mgr.get_session_id()
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
self.bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
if self.bootstrap_key not in self.kv_mgr.connection_pool:
self.bootstrap_info = self._get_bootstrap_info_from_server(
self.bootstrap_key
)
if self.bootstrap_info is None:
logger.error(
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
)
else:
self.kv_mgr.connection_pool[self.bootstrap_key] = self.bootstrap_info
else:
self.bootstrap_info = self.kv_mgr.connection_pool[self.bootstrap_key]
assert self.bootstrap_info is not None
self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)
self.prefill_engine_rank = None
self.decode_port = self.kv_mgr.rank_port
self.dealer_socket = None
def _get_prefill_info_from_bootstrap(self, tp_rank: int):
"""Fetch the prefill server port corresponding to tp_rank from the bootstrap server."""
def _get_bootstrap_info_from_server(self, bootstrap_key: str):
"""Fetch the bootstrap info from the bootstrap server."""
try:
url = f"http://{self.bootstrap_addr}/kv_route?tp_rank={tp_rank}"
url = f"http://{self.bootstrap_addr}/route?bootstrap_key={bootstrap_key}"
response = requests.get(url)
if response.status_code == 200:
prefill_info = response.json()
return prefill_info
bootstrap_info = response.json()
return bootstrap_info
else:
logger.error(
f"Failed to get prefill server info: {response.status_code}, {response.text}"
......@@ -464,39 +442,13 @@ class MooncakeKVReceiver(BaseKVReceiver):
return socket
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
prefill_info = None
logger.info(f"Decode bootstrap addr {self.bootstrap_addr}.")
if self.kv_mgr.kv_args.engine_rank not in self.kv_mgr.connection_pool:
prefill_info = self._get_prefill_info_from_bootstrap(
self.kv_mgr.kv_args.engine_rank
)
if prefill_info is None:
logger.error(
logger.error(
f"Could not fetch prefill server info for tp_rank {self.kv_mgr.kv_args.engine_rank}"
)
)
else:
self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank] = (
prefill_info
)
else:
prefill_info = self.kv_mgr.connection_pool[self.kv_mgr.kv_args.engine_rank]
if prefill_info:
self.prefill_server_url = (
f"{prefill_info['serve_ip']}:{prefill_info['serve_port']}"
)
logger.info(
f"Fetched prefill server info: {prefill_info} for tp_rank {self.kv_mgr.kv_args.engine_rank}"
)
self.handshake_prefill_server(kv_indices, aux_index)
self.prefill_server_url = (
f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
)
logger.debug(
f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
)
def handshake_prefill_server(
self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None
):
packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
)
......@@ -505,10 +457,10 @@ class MooncakeKVReceiver(BaseKVReceiver):
)
self._connect("tcp://" + self.prefill_server_url).send_multipart(
[
self.decode_ip.encode("ascii"),
str(self.decode_port).encode("ascii"),
self.session_id.encode("ascii"),
str(self.bootstrap_room).encode("ascii"),
get_local_ip_by_remote().encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"),
self.session_id.encode("ascii"),
packed_kv_data_ptrs,
kv_indices.tobytes(),
packed_aux_data_ptrs,
......@@ -530,10 +482,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
self.store = dict()
self.lock = asyncio.Lock()
self._setup_routes()
# prefill_engine_rank -> prefill_info
self.prefill_port_table: Dict[int, Dict[str, Union[str, int]]] = {}
self.context = zmq.Context()
self.prefill_port_table: Dict[str, Dict[str, Union[str, int]]] = {}
self.prefill_engine_rank = None
......@@ -546,7 +495,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
def _setup_routes(self):
self.app.router.add_route("*", "/metadata", self._handle_metadata)
self.app.router.add_route("*", "/kv_route", self._handle_kv_route)
self.app.router.add_route("*", "/route", self._handle_route)
async def _handle_metadata(self, request: web.Request):
key = request.query.get("key", "")
......@@ -591,54 +540,47 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
text="metadata deleted", status=200, content_type="application/json"
)
async def _handle_kv_route(self, request: web.Request):
async def _handle_route(self, request: web.Request):
method = request.method
if method == "PUT":
return await self._handle_kv_route_put(request)
return await self._handle_route_put(request)
elif method == "GET":
return await self._handle_kv_route_get(request)
return await self._handle_route_get(request)
else:
return web.Response(
text="Method not allowed", status=405, content_type="application/json"
)
async def _handle_kv_route_put(self, request: web.Request):
async def _handle_route_put(self, request: web.Request):
data = await request.json()
identity = data["identity"]
role = data["role"]
serve_ip = data["serve_ip"]
serve_port = int(data["serve_port"]) # Assuming serve_port is an integer
tp_rank = int(data["tp_rank"])
rank_ip = data["rank_ip"]
rank_port = int(data["rank_port"])
bootstrap_key = data["bootstrap_key"]
# Add lock to make sure thread-safe
if role == "Prefill":
async with self.lock:
self.prefill_port_table[tp_rank] = {
"serve_ip": serve_ip,
"serve_port": serve_port,
}
logger.info(
f"Registered Prefill tp_rank: {tp_rank} with serve_ip: {serve_ip} and serve_port: {serve_port}"
self.prefill_port_table[bootstrap_key] = {
"rank_ip": rank_ip,
"rank_port": rank_port,
}
logger.debug(
f"Registered Prefill bootstrap_key: {bootstrap_key} with rank_ip: {rank_ip} and rank_port: {rank_port}"
)
return web.Response(text="OK", status=200)
async def _handle_kv_route_get(self, request: web.Request):
tp_rank = request.query.get("tp_rank")
if not tp_rank:
return web.Response(text="Missing tp_rank", status=400)
try:
tp_rank = int(tp_rank)
except ValueError:
return web.Response(text="tp_rank must be int", status=400)
async def _handle_route_get(self, request: web.Request):
bootstrap_key = request.query.get("bootstrap_key")
if not bootstrap_key:
return web.Response(text="Missing bootstrap_key", status=400)
# Find corresponding prefill info
async with self.lock:
prefill_info = self.prefill_port_table.get(tp_rank)
if prefill_info is not None:
return web.json_response(prefill_info, status=200)
bootstrap_info = self.prefill_port_table.get(bootstrap_key)
if bootstrap_info is not None:
return web.json_response(bootstrap_info, status=200)
else:
return web.Response(text="Not Found", status=404)
......
......@@ -67,6 +67,7 @@ class PrefillBootstrapQueue:
bootstrap_port: int,
gloo_group: ProcessGroup,
transfer_backend: TransferBackend,
scheduler: Scheduler,
):
self.token_to_kv_pool = token_to_kv_pool
self.aux_dtype = aux_dtype
......@@ -76,6 +77,7 @@ class PrefillBootstrapQueue:
self.tp_rank = tp_rank
self.tp_size = tp_size
self.transfer_backend = transfer_backend
self.scheduler = scheduler
self.kv_manager = self._init_kv_manager()
self.queue: List[Req] = []
self.gloo_group = gloo_group
......@@ -108,8 +110,11 @@ class PrefillBootstrapQueue:
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
]
kv_args.ib_device = "mock-ib-device"
kv_args.gpu_id = self.scheduler.gpu_id
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
kv_manager = kv_manager_class(kv_args, DisaggregationMode.PREFILL)
kv_manager = kv_manager_class(
kv_args, DisaggregationMode.PREFILL, self.scheduler.server_args
)
return kv_manager
def add(self, req: Req) -> None:
......
......@@ -599,6 +599,7 @@ class Scheduler(
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
transfer_backend=self.transfer_backend,
scheduler=self,
)
# The prefill requests that are in the middle of kv sending
self.disagg_prefill_inflight_queue: List[Req] = []
......
......@@ -1872,3 +1872,36 @@ def is_hopper_with_cuda_12_3():
cuda_version = torch.version.cuda.split(".")
is_cuda_compatible = int(cuda_version[0]) == 12 and int(cuda_version[1]) >= 3
return is_hopper and is_cuda_compatible
def get_free_port():
    """Ask the OS for a currently unused TCP port and return its number.

    A throwaway socket is bound to port 0 so the OS picks the port; the socket
    is closed before returning. IPv4 is tried first and IPv6 is used only when
    the IPv4 attempt raises ``OSError``.
    """

    def _probe(family):
        # Binding to port 0 lets the OS choose; getsockname() reveals the choice.
        with socket.socket(family, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]

    try:
        return _probe(socket.AF_INET)
    except OSError:
        return _probe(socket.AF_INET6)
def get_local_ip_by_remote() -> str:
    """Return the local address of a UDP socket connected to a public DNS host.

    An IPv4 probe is attempted first, then an IPv6 probe. Each probe connects a
    datagram socket to the remote address and reads the local side back with
    ``getsockname()``. Every probe socket is closed before the function
    returns or moves on to the next probe.

    Returns:
        The local IP address string reported by the first probe that succeeds.

    Raises:
        ValueError: If both the IPv4 and the IPv6 probe fail; the last
            underlying ``OSError`` is attached as the cause.
    """
    # Google's public DNS servers, see
    # https://developers.google.com/speed/public-dns/docs/using#addresses
    probes = (
        (socket.AF_INET, "8.8.8.8"),
        (socket.AF_INET6, "2001:4860:4860::8888"),
    )

    last_error = None
    for family, remote_host in probes:
        try:
            # The context manager closes the socket on success and on failure.
            with socket.socket(family, socket.SOCK_DGRAM) as s:
                s.connect((remote_host, 80))  # Doesn't need to be reachable
                return s.getsockname()[0]
        except OSError as e:
            last_error = e

    raise ValueError("Can not get local ip") from last_error
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment