Unverified Commit a13dd1e4 authored by Shangming Cai's avatar Shangming Cai Committed by GitHub
Browse files

[PD] Improve disaggregation common backend and refactor mooncake backend (#10273)


Signed-off-by: default avatarShangming Cai <csmthu@gmail.com>
parent d500eb91
......@@ -22,12 +22,18 @@ from sglang.srt.disaggregation.base.conn import (
KVPoll,
)
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.distributed import get_pp_group
from sglang.srt.layers.dp_attention import (
get_attention_dp_rank,
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
format_tcp_address,
get_free_port,
get_ip,
get_local_ip_by_remote,
get_local_ip_auto,
is_valid_ipv6_address,
maybe_wrap_ipv6_address,
)
......@@ -50,30 +56,57 @@ class CommonKVManager(BaseKVManager):
self.bootstrap_host = server_args.host
self.bootstrap_port = server_args.disaggregation_bootstrap_port
self.dist_init_addr = server_args.dist_init_addr
self.tp_size = server_args.tp_size
self.dp_size = server_args.dp_size
self.enable_dp_attention = server_args.enable_dp_attention
if not server_args.enable_dp_attention and server_args.dp_size != 1:
raise ValueError(
"If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
)
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
self.attn_dp_size = get_attention_dp_size()
self.attn_dp_rank = get_attention_dp_rank()
self.system_dp_size = (
1 if server_args.enable_dp_attention else server_args.dp_size
)
self.system_dp_rank = (
self.kv_args.system_dp_rank if self.kv_args.system_dp_rank else 0
)
self.pp_size = server_args.pp_size
self.pp_rank = self.kv_args.pp_rank
self.rank_port = get_free_port()
self.local_ip = get_local_ip_auto()
self.server_socket = zmq.Context().socket(zmq.PULL)
if is_valid_ipv6_address(self.local_ip):
self.server_socket.setsockopt(zmq.IPV6, 1)
self.request_status: Dict[int, KVPoll] = {}
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self._register_to_bootstrap()
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
self.pp_group = get_pp_group()
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
self.prefill_tp_size_table: Dict[str, int] = {}
self.connection_lock = threading.Lock()
self.required_prefill_response_num_table: Dict[int, int] = {}
self.prefill_attn_tp_size_table: Dict[str, int] = {}
self.prefill_dp_size_table: Dict[str, int] = {}
self.prefill_pp_size_table: Dict[str, int] = {}
else:
raise ValueError(
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
)
def _bind_server_socket(self):
    """Bind this rank's PULL socket to its local TCP address so peers can push to it."""
    self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
@cache
def _connect(self, endpoint: str, is_ipv6: bool = False):
    """Return a PUSH socket connected to *endpoint*, memoized per (self, endpoint, is_ipv6).

    A fresh zmq.Context is created per distinct call; the @cache ensures each
    endpoint is only connected once for this manager instance.

    NOTE(review): @cache on an instance method keys on ``self`` and keeps the
    manager (and its cached sockets) alive for the process lifetime —
    presumably acceptable since KV managers are long-lived; confirm.
    """
    socket = zmq.Context().socket(zmq.PUSH)
    if is_ipv6:
        # zmq sockets do not accept IPv6 endpoints unless IPV6 is enabled.
        socket.setsockopt(zmq.IPV6, 1)
    socket.connect(endpoint)
    return socket
def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST."""
if self.dist_init_addr:
# multi node: bootstrap server's host is dist_init_addr
# Multi-node case: bootstrap server's host is dist_init_addr
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
if self.dist_init_addr.endswith("]"):
host = self.dist_init_addr
......@@ -82,7 +115,7 @@ class CommonKVManager(BaseKVManager):
else:
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
else:
# single node: bootstrap server's host is same as http server's host
# Single-node case: bootstrap server's host is the same as http server's host
host = self.bootstrap_host
host = maybe_wrap_ipv6_address(host)
......@@ -90,23 +123,30 @@ class CommonKVManager(BaseKVManager):
url = f"http://{bootstrap_server_url}/route"
payload = {
"role": "Prefill",
"tp_size": self.tp_size,
"dp_size": self.dp_size,
"rank_ip": get_local_ip_by_remote(),
"attn_tp_size": self.attn_tp_size,
"attn_tp_rank": self.attn_tp_rank,
"attn_dp_size": self.attn_dp_size,
"attn_dp_rank": self.attn_dp_rank,
"pp_size": self.pp_size,
"pp_rank": self.pp_rank,
"system_dp_size": self.system_dp_size,
"system_dp_rank": self.system_dp_rank,
"rank_ip": self.local_ip,
"rank_port": self.rank_port,
"engine_rank": self.kv_args.engine_rank,
}
try:
response = requests.put(url, json=payload)
response = requests.put(url, json=payload, timeout=5)
if response.status_code == 200:
logger.debug("Prefill successfully registered to bootstrap server.")
else:
logger.error(
f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
f"Prefill instance failed to connect to bootstrap server: {response.status_code}, {response.text}"
)
except Exception as e:
logger.error(f"Prefill Failed to register to bootstrap server: {e}")
logger.error(
f"Prefill instance failed to register to bootstrap server: {e}"
)
@cache
def _connect(self, endpoint: str, is_ipv6: bool = False):
......@@ -117,6 +157,41 @@ class CommonKVManager(BaseKVManager):
return socket
class CommonKVSender(BaseKVSender):
    """Backend-agnostic prefill-side KV sender.

    Performs the bootstrap bookkeeping shared by all disaggregation backends
    (status registration, aux-index tracking); ``send`` and ``poll`` are
    placeholders that concrete backends override with real transfer logic.
    """

    def __init__(
        self,
        mgr: BaseKVManager,
        bootstrap_addr: str,
        bootstrap_room: int,
        dest_tp_ranks: List[int],
        pp_rank: int,
    ):
        self.kv_mgr = mgr
        self.bootstrap_room = bootstrap_room
        self.aux_index = None
        self.bootstrap_server_url = bootstrap_addr
        # inner state
        self.curr_idx = 0
        # Mark this room as bootstrapping so pollers see a consistent state.
        self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)

    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
        """Record the KV-index count and optional aux buffer index for this transfer."""
        self.num_kv_indices = num_kv_indices
        self.aux_index = aux_index

    def send(
        self,
        kv_indices: npt.NDArray[np.int32],
    ):
        # Placeholder: concrete backends implement the actual KV cache transfer.
        pass

    def poll(self) -> KVPoll:
        # Placeholder: concrete backends report the real transfer status.
        pass

    def failure_exception(self):
        # Fixed copy-paste from the receiver class: this is the KV *sender*.
        raise Exception("Fake KVSender Exception")
class CommonKVReceiver(BaseKVReceiver):
_ctx = zmq.Context()
_socket_cache = {}
......@@ -133,61 +208,88 @@ class CommonKVReceiver(BaseKVReceiver):
self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr
self.kv_mgr = mgr
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
self.prefill_tp_size, self.prefill_dp_size = (
self._get_prefill_dp_size_from_server()
)
if self.prefill_tp_size is None or self.prefill_dp_size is None:
logger.error(
f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
(
self.prefill_attn_tp_size,
self.prefill_dp_size,
self.prefill_pp_size,
) = self._get_prefill_parallel_info_from_server()
if (
self.prefill_attn_tp_size is None
or self.prefill_dp_size is None
or self.prefill_pp_size is None
):
self.kv_mgr.record_failure(
self.bootstrap_room,
f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
)
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
return
else:
self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
self.prefill_tp_size
logger.debug(
f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_attn_tp_size} PP size:{self.prefill_pp_size}"
)
self.kv_mgr.prefill_attn_tp_size_table[self.bootstrap_addr] = (
self.prefill_attn_tp_size
)
self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
self.prefill_dp_size
)
self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] = (
self.prefill_pp_size
)
else:
self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[
self.bootstrap_addr
]
self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
self.bootstrap_addr
]
self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[
self.bootstrap_addr
]
# Currently, we don't allow prefill instance and decode instance to
# have different TP sizes per DP rank, except for models using MLA.
local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size:
self.target_tp_rank = (
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
)
self.required_dst_info_num = 1
self.required_prefill_response_num = 1 * (
self.prefill_pp_size // self.kv_mgr.pp_size
)
self.target_tp_ranks = [self.target_tp_rank]
elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
if not self.kv_mgr.is_mla_backend:
logger.warning_once(
"Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
)
self.target_tp_rank = (
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size)
self.required_dst_info_num = (
local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
)
self.required_prefill_response_num = 1 * (
self.prefill_pp_size // self.kv_mgr.pp_size
)
self.target_tp_ranks = [self.target_tp_rank]
else:
assert (
self.kv_mgr.is_mla_backend
), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
if not self.kv_mgr.is_mla_backend:
logger.warning_once(
"Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
)
# For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
self.target_tp_ranks = [
rank
for rank in range(
(self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
* (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
(self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
* (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
(self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size)
* (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
(self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1)
* (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
)
]
......@@ -196,6 +298,14 @@ class CommonKVReceiver(BaseKVReceiver):
# or the KVPoll will never be set correctly
self.target_tp_rank = self.target_tp_ranks[0]
self.required_dst_info_num = 1
if self.kv_mgr.is_mla_backend:
self.required_prefill_response_num = (
self.prefill_pp_size // self.kv_mgr.pp_size
)
else:
self.required_prefill_response_num = (
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
) * (self.prefill_pp_size // self.kv_mgr.pp_size)
if prefill_dp_rank is not None:
logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
......@@ -206,6 +316,9 @@ class CommonKVReceiver(BaseKVReceiver):
# FIXME: alias here: target_dp_group -> prefill_dp_rank
self.target_dp_group = self.prefill_dp_rank
self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
self.required_prefill_response_num
)
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
bootstrap_key = (
f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
......@@ -214,41 +327,49 @@ class CommonKVReceiver(BaseKVReceiver):
if bootstrap_key not in self.kv_mgr.connection_pool:
bootstrap_infos = []
for target_tp_rank in self.target_tp_ranks:
bootstrap_info = self._get_bootstrap_info_from_server(
target_tp_rank,
self.target_dp_group,
)
if bootstrap_info is not None:
# NOTE: only support MLA for now: select one prefill rank as real rank
bootstrap_info["is_dummy"] = not bool(
target_tp_rank == self.target_tp_rank
or self.target_tp_rank is None
)
bootstrap_infos.append(bootstrap_info)
else:
logger.error(
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
for target_pp_rank in range(self.prefill_pp_size):
bootstrap_info = self._get_bootstrap_info_from_server(
target_tp_rank, self.target_dp_group, target_pp_rank
)
if bootstrap_info is not None:
if self.kv_mgr.is_mla_backend:
# For MLA: target_tp_rank is the selected real rank, others are dummy ranks
bootstrap_info["is_dummy"] = not bool(
target_tp_rank == self.target_tp_rank
or self.target_tp_rank is None
)
else:
# For non-MLA: all target_tp_ranks are selected real ranks
bootstrap_info["is_dummy"] = False
logger.debug(
f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}"
)
bootstrap_infos.append(bootstrap_info)
else:
self.kv_mgr.record_failure(
self.bootstrap_room,
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}",
)
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
return
self.bootstrap_infos = bootstrap_infos
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
if len(self.bootstrap_infos) == 0:
logger.error(
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
)
else:
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
self._register_kv_args()
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
self._register_kv_args()
else:
self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
assert len(self.bootstrap_infos) > 0
def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
def _get_bootstrap_info_from_server(
self, engine_rank, target_dp_group, target_pp_rank
):
"""Fetch the bootstrap info from the bootstrap server."""
try:
url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
response = requests.get(url)
url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}&target_pp_rank={target_pp_rank}"
response = requests.get(url, timeout=5)
if response.status_code == 200:
bootstrap_info = response.json()
return bootstrap_info
......@@ -261,24 +382,28 @@ class CommonKVReceiver(BaseKVReceiver):
logger.error(f"Error fetching prefill info from bootstrap: {e}")
return None
def _get_prefill_dp_size_from_server(self) -> int:
def _get_prefill_parallel_info_from_server(
self,
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
"""Fetch the prefill parallel info from the bootstrap server."""
try:
url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}"
response = requests.get(url)
if response.status_code == 200:
prefill_parallel_info = response.json()
return int(prefill_parallel_info["prefill_tp_size"]), int(
prefill_parallel_info["prefill_dp_size"]
return (
int(prefill_parallel_info["prefill_attn_tp_size"]),
int(prefill_parallel_info["prefill_dp_size"]),
int(prefill_parallel_info["prefill_pp_size"]),
)
else:
logger.error(
f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
)
return None
return None, None, None
except Exception as e:
logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
return None
return None, None, None
@classmethod
def _connect(cls, endpoint: str, is_ipv6: bool = False):
......@@ -317,10 +442,12 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
self.store = dict()
self.lock = asyncio.Lock()
self._setup_routes()
self.tp_size = None
self.pp_size = None
self.attn_tp_size = None
self.dp_size = None
self.tp_size_per_dp_rank = None
self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
self.prefill_port_table: Dict[
int, Dict[int, Dict[int, Dict[str, Union[str, int]]]]
] = {}
# Start bootstrap server
self.thread = threading.Thread(target=self._run_server, daemon=True)
......@@ -331,6 +458,10 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
def _setup_routes(self):
self.app.router.add_route("*", "/route", self._handle_route)
self.app.router.add_get("/health", self._handle_health_check)
async def _handle_health_check(self, request):
    """Liveness probe endpoint: always reply 200 OK."""
    return web.Response(text="OK", status=200)
async def _handle_route(self, request: web.Request):
method = request.method
......@@ -346,37 +477,45 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
async def _handle_route_put(self, request: web.Request):
data = await request.json()
role = data["role"]
tp_size = data["tp_size"]
dp_size = data["dp_size"]
attn_tp_size = data["attn_tp_size"]
attn_tp_rank = data["attn_tp_rank"]
attn_dp_size = data["attn_dp_size"]
attn_dp_rank = data["attn_dp_rank"]
pp_size = data["pp_size"]
pp_rank = data["pp_rank"]
system_dp_size = data["system_dp_size"]
system_dp_rank = data["system_dp_rank"]
rank_ip = data["rank_ip"]
rank_port = int(data["rank_port"])
engine_rank = int(data["engine_rank"])
if self.tp_size is None:
self.tp_size = tp_size
if self.attn_tp_size is None:
self.attn_tp_size = attn_tp_size
if self.dp_size is None:
self.dp_size = dp_size
self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size
tp_size_per_dp_rank = tp_size // dp_size
if self.tp_size_per_dp_rank == None:
self.tp_size_per_dp_rank = tp_size_per_dp_rank
if self.pp_size is None:
self.pp_size = pp_size
# Add lock to make sure thread-safe
if role == "Prefill":
dp_group = engine_rank // tp_size_per_dp_rank
tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank
if system_dp_size == 1:
dp_group = attn_dp_rank
else:
dp_group = system_dp_rank
# Add lock to make sure thread-safe
async with self.lock:
if dp_group not in self.prefill_port_table:
self.prefill_port_table[dp_group] = {}
if attn_tp_rank not in self.prefill_port_table[dp_group]:
self.prefill_port_table[dp_group][attn_tp_rank] = {}
self.prefill_port_table[dp_group][tp_rank_in_dp_group] = {
self.prefill_port_table[dp_group][attn_tp_rank][pp_rank] = {
"rank_ip": rank_ip,
"rank_port": rank_port,
}
logger.debug(
f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
)
return web.Response(text="OK", status=200)
......@@ -384,14 +523,20 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
async def _handle_route_get(self, request: web.Request):
engine_rank = request.query.get("engine_rank")
target_dp_group = request.query.get("target_dp_group")
if not engine_rank or not target_dp_group:
target_pp_rank = request.query.get("target_pp_rank")
if not engine_rank or not target_dp_group or not target_pp_rank:
return web.Response(text="Missing inputs for bootstrap server.", status=400)
# Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
if int(engine_rank) == -1 and int(target_dp_group) == -1:
if (
int(engine_rank) == -1
and int(target_dp_group) == -1
and int(target_pp_rank) == -1
):
prefill_parallel_info = {
"prefill_tp_size": self.tp_size,
"prefill_attn_tp_size": self.attn_tp_size,
"prefill_dp_size": self.dp_size,
"prefill_pp_size": self.pp_size,
}
return web.json_response(prefill_parallel_info, status=200)
......@@ -399,7 +544,7 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
async with self.lock:
bootstrap_info = self.prefill_port_table[int(target_dp_group)][
int(engine_rank)
]
][int(target_pp_rank)]
if bootstrap_info is not None:
return web.json_response(bootstrap_info, status=200)
......@@ -412,7 +557,11 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._runner = web.AppRunner(self.app)
access_log = None
if logging.getLogger(__name__).getEffectiveLevel() <= logging.DEBUG:
access_log = self.app.logger
self._runner = web.AppRunner(self.app, access_log=access_log)
self._loop.run_until_complete(self._runner.setup())
site = web.TCPSite(self._runner, host=self.host, port=self.port)
......
from __future__ import annotations
import asyncio
import concurrent.futures
import ctypes
import dataclasses
import logging
import os
import queue
import socket
import struct
import threading
import time
from collections import defaultdict
from functools import cache
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple
import numpy as np
import numpy.typing as npt
import requests
import zmq
from aiohttp import web
from sglang.srt.disaggregation.base.conn import (
BaseKVBootstrapServer,
BaseKVManager,
BaseKVReceiver,
BaseKVSender,
KVArgs,
KVPoll,
from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
from sglang.srt.disaggregation.common.conn import (
CommonKVBootstrapServer,
CommonKVManager,
CommonKVReceiver,
CommonKVSender,
)
from sglang.srt.disaggregation.common.utils import (
FastQueue,
......@@ -35,23 +29,12 @@ from sglang.srt.disaggregation.common.utils import (
)
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.distributed import get_pp_group
from sglang.srt.layers.dp_attention import (
get_attention_dp_rank,
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
format_tcp_address,
get_bool_env_var,
get_free_port,
get_int_env_var,
get_ip,
get_local_ip_auto,
is_valid_ipv6_address,
maybe_wrap_ipv6_address,
)
logger = logging.getLogger(__name__)
......@@ -159,7 +142,7 @@ class AuxDataCodec:
return
class MooncakeKVManager(BaseKVManager):
class MooncakeKVManager(CommonKVManager):
AUX_DATA_HEADER = b"AUX_DATA"
def __init__(
......@@ -169,43 +152,14 @@ class MooncakeKVManager(BaseKVManager):
server_args: ServerArgs,
is_mla_backend: Optional[bool] = False,
):
self.kv_args = args
self.local_ip = get_local_ip_auto()
self.is_mla_backend = is_mla_backend
self.disaggregation_mode = disaggregation_mode
super().__init__(args, disaggregation_mode, server_args, is_mla_backend)
self.init_engine()
# for p/d multi node infer
self.bootstrap_host = server_args.host
self.bootstrap_port = server_args.disaggregation_bootstrap_port
self.dist_init_addr = server_args.dist_init_addr
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
self.attn_dp_size = get_attention_dp_size()
self.attn_dp_rank = get_attention_dp_rank()
self.system_dp_size = (
1 if server_args.enable_dp_attention else server_args.dp_size
)
self.system_dp_rank = (
self.kv_args.system_dp_rank if self.kv_args.system_dp_rank else 0
)
self.pp_size = server_args.pp_size
self.pp_rank = self.kv_args.pp_rank
self.request_status: Dict[int, KVPoll] = {}
self.rank_port = None
self.server_socket = zmq.Context().socket(zmq.PULL)
if is_valid_ipv6_address(self.local_ip):
self.server_socket.setsockopt(zmq.IPV6, 1)
self.register_buffer_to_engine()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
self.start_prefill_thread()
self._register_to_bootstrap()
self.session_failures = defaultdict(int)
self.failed_sessions = set()
self.session_lock = threading.Lock()
self.pp_group = get_pp_group()
# Determine the number of threads to use for kv sender
cpu_count = os.cpu_count()
transfer_thread_pool_size = get_int_env_var(
......@@ -245,8 +199,6 @@ class MooncakeKVManager(BaseKVManager):
self.session_pool = defaultdict(requests.Session)
self.session_pool_lock = threading.Lock()
self.addr_to_rooms_tracker = defaultdict(set)
self.connection_lock = threading.Lock()
self.required_prefill_response_num_table: Dict[int, int] = {}
self.prefill_response_tracker: Dict[int, Set[int]] = defaultdict(set)
# Heartbeat interval should be at least 2 seconds
self.heartbeat_interval = max(
......@@ -257,20 +209,12 @@ class MooncakeKVManager(BaseKVManager):
get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
)
self.start_decode_thread()
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
self.prefill_attn_tp_size_table: Dict[str, int] = {}
self.prefill_dp_size_table: Dict[str, int] = {}
self.prefill_pp_size_table: Dict[str, int] = {}
# If a timeout happens on the decode side, it means decode instances
# fail to receive the KV Cache transfer done signal after bootstrapping.
# These timeout requests should be aborted to release the tree cache.
self.waiting_timeout = get_int_env_var(
"SGLANG_DISAGGREGATION_WAITING_TIMEOUT", 300
)
else:
raise ValueError(
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
)
self.failure_records: Dict[int, str] = {}
self.failure_lock = threading.Lock()
......@@ -295,14 +239,6 @@ class MooncakeKVManager(BaseKVManager):
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
)
@cache
def _connect(self, endpoint: str, is_ipv6: bool = False):
socket = zmq.Context().socket(zmq.PUSH)
if is_ipv6:
socket.setsockopt(zmq.IPV6, 1)
socket.connect(endpoint)
return socket
def _transfer_data(self, mooncake_session_id, transfer_blocks):
if not transfer_blocks:
return 0
......@@ -654,6 +590,26 @@ class MooncakeKVManager(BaseKVManager):
]
)
def _handle_aux_data(self, msg: List[bytes]):
    """Handle AUX_DATA messages received by the decode thread.

    Expected frame layout (msg[0] is the AUX_DATA header, already matched by
    the caller): msg[1]=room, msg[2]=buffer index, msg[3]=aux index — all
    ASCII-encoded ints — msg[4]=big-endian uint32 payload length, msg[5]=payload.
    """
    room = int(msg[1].decode("ascii"))
    buffer_index = int(msg[2].decode("ascii"))
    aux_index = int(msg[3].decode("ascii"))
    # ">I" — big-endian unsigned 32-bit length prefix for the payload.
    data_length = struct.unpack(">I", msg[4])[0]
    data = msg[5]
    # Drop truncated/corrupted frames instead of writing a short buffer.
    if len(data) != data_length:
        logger.error(f"AUX_DATA length mismatch for bootstrap_room {room}")
        return
    AuxDataCodec.deserialize_data_to_buffer(
        self.kv_args, buffer_index, aux_index, data
    )
    logger.debug(
        f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
    )
def sync_status_to_decode_endpoint(
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
):
......@@ -802,11 +758,7 @@ class MooncakeKVManager(BaseKVManager):
f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
)
def _bind_server_socket(self):
    """Bind this rank's PULL socket to its local TCP address so peers can push to it."""
    self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
def start_prefill_thread(self):
self.rank_port = get_free_port()
self._bind_server_socket()
def bootstrap_thread():
......@@ -844,28 +796,7 @@ class MooncakeKVManager(BaseKVManager):
threading.Thread(target=bootstrap_thread).start()
def _handle_aux_data(self, msg: List[bytes]):
    """Handle AUX_DATA messages received by the decode thread.

    Frames msg[1..3] are ASCII-encoded ints (room, buffer index, aux index);
    msg[4] is a big-endian uint32 payload length and msg[5] the payload bytes.
    """
    room = int(msg[1].decode("ascii"))
    buffer_index = int(msg[2].decode("ascii"))
    aux_index = int(msg[3].decode("ascii"))
    # ">I" — big-endian unsigned 32-bit length prefix.
    data_length = struct.unpack(">I", msg[4])[0]
    data = msg[5]
    # Reject truncated/corrupted frames rather than deserializing bad data.
    if len(data) != data_length:
        logger.error(f"AUX_DATA length mismatch for bootstrap_room {room}")
        return
    AuxDataCodec.deserialize_data_to_buffer(
        self.kv_args, buffer_index, aux_index, data
    )
    logger.debug(
        f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
    )
def start_decode_thread(self):
self.rank_port = get_free_port()
self._bind_server_socket()
def decode_thread():
......@@ -1020,51 +951,6 @@ class MooncakeKVManager(BaseKVManager):
def get_session_id(self):
    """Return the Mooncake transfer engine's session id for this rank."""
    return self.engine.get_session_id()
def _register_to_bootstrap(self):
    """Register this prefill rank with the bootstrap server via HTTP PUT.

    Best-effort: registration failures are logged, not raised, so a prefill
    instance keeps running even if the bootstrap server is unreachable.
    """
    if self.dist_init_addr:
        # multi node case: bootstrap server's host is dist_init_addr
        if self.dist_init_addr.startswith("["):  # [ipv6]:port or [ipv6]
            if self.dist_init_addr.endswith("]"):
                # Bare bracketed IPv6 address with no port suffix.
                host = self.dist_init_addr
            else:
                # Strip the trailing ":port" from "[ipv6]:port".
                host, _ = self.dist_init_addr.rsplit(":", 1)
        else:
            # Resolve a hostname (drop any ":port" suffix first).
            host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
    else:
        # single node case: bootstrap server's host is same as http server's host
        host = self.bootstrap_host
    host = maybe_wrap_ipv6_address(host)
    bootstrap_server_url = f"{host}:{self.bootstrap_port}"
    url = f"http://{bootstrap_server_url}/route"
    # Parallel-topology payload the bootstrap server uses to route decode
    # instances to the right prefill rank.
    payload = {
        "role": "Prefill",
        "attn_tp_size": self.attn_tp_size,
        "attn_tp_rank": self.attn_tp_rank,
        "attn_dp_size": self.attn_dp_size,
        "attn_dp_rank": self.attn_dp_rank,
        "pp_size": self.pp_size,
        "pp_rank": self.pp_rank,
        "system_dp_size": self.system_dp_size,
        "system_dp_rank": self.system_dp_rank,
        "rank_ip": self.local_ip,
        "rank_port": self.rank_port,
    }
    try:
        # Bounded timeout so a dead bootstrap server cannot hang startup.
        response = requests.put(url, json=payload, timeout=5)
        if response.status_code == 200:
            logger.debug("Prefill successfully registered to bootstrap server.")
        else:
            logger.error(
                f"Prefill instance failed to connect to bootstrap server: {response.status_code}, {response.text}"
            )
    except Exception as e:
        logger.error(
            f"Prefill instance failed to register to bootstrap server: {e}"
        )
def _handle_node_failure(self, failed_bootstrap_addr):
with self.connection_lock:
keys_to_remove = [
......@@ -1103,7 +989,7 @@ class MooncakeKVManager(BaseKVManager):
)
class MooncakeKVSender(BaseKVSender):
class MooncakeKVSender(CommonKVSender):
def __init__(
self,
......@@ -1113,19 +999,9 @@ class MooncakeKVSender(BaseKVSender):
dest_tp_ranks: List[int],
pp_rank: int,
):
self.kv_mgr = mgr
self.bootstrap_room = bootstrap_room
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
self.aux_index = None
self.bootstrap_server_url = bootstrap_addr
super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank)
self.conclude_state = None
self.init_time = time.time()
# inner state
self.curr_idx = 0
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
self.num_kv_indices = num_kv_indices
self.aux_index = aux_index
def send(
self,
......@@ -1203,7 +1079,7 @@ class MooncakeKVSender(BaseKVSender):
self.conclude_state = KVPoll.Failed
class MooncakeKVReceiver(BaseKVReceiver):
class MooncakeKVReceiver(CommonKVReceiver):
_ctx = zmq.Context()
_socket_cache = {}
_socket_locks = {}
......@@ -1216,166 +1092,11 @@ class MooncakeKVReceiver(BaseKVReceiver):
bootstrap_room: Optional[int] = None,
prefill_dp_rank: Optional[int] = None,
):
self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr
self.kv_mgr = mgr
self.session_id = self.kv_mgr.get_session_id()
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
self.session_id = mgr.get_session_id()
self.conclude_state = None
self.init_time = None
super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
(
self.prefill_attn_tp_size,
self.prefill_dp_size,
self.prefill_pp_size,
) = self._get_prefill_parallel_info_from_server()
if (
self.prefill_attn_tp_size is None
or self.prefill_dp_size is None
or self.prefill_pp_size is None
):
self.kv_mgr.record_failure(
self.bootstrap_room,
f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
)
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
return
else:
logger.debug(
f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_attn_tp_size} PP size:{self.prefill_pp_size}"
)
self.kv_mgr.prefill_attn_tp_size_table[self.bootstrap_addr] = (
self.prefill_attn_tp_size
)
self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
self.prefill_dp_size
)
self.kv_mgr.prefill_pp_size_table[self.bootstrap_addr] = (
self.prefill_pp_size
)
else:
self.prefill_attn_tp_size = self.kv_mgr.prefill_attn_tp_size_table[
self.bootstrap_addr
]
self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
self.bootstrap_addr
]
self.prefill_pp_size = self.kv_mgr.prefill_pp_size_table[
self.bootstrap_addr
]
# Currently, we don't allow prefill instance and decode instance to
# have different TP sizes per DP rank, except for models using MLA.
if self.kv_mgr.attn_tp_size == self.prefill_attn_tp_size:
self.target_tp_rank = (
self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
)
self.required_dst_info_num = 1
self.required_prefill_response_num = 1 * (
self.prefill_pp_size // self.kv_mgr.pp_size
)
self.target_tp_ranks = [self.target_tp_rank]
elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
if not self.kv_mgr.is_mla_backend:
logger.warning_once(
"Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
)
self.target_tp_rank = (
self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
) // (self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size)
self.required_dst_info_num = (
self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
)
self.required_prefill_response_num = 1 * (
self.prefill_pp_size // self.kv_mgr.pp_size
)
self.target_tp_ranks = [self.target_tp_rank]
else:
if not self.kv_mgr.is_mla_backend:
logger.warning_once(
"Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
)
# For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
self.target_tp_ranks = [
rank
for rank in range(
(self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size)
* (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
(self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size + 1)
* (self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size),
)
]
# For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
# multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
# or the KVPoll will never be set correctly
self.target_tp_rank = self.target_tp_ranks[0]
self.required_dst_info_num = 1
if self.kv_mgr.is_mla_backend:
self.required_prefill_response_num = (
self.prefill_pp_size // self.kv_mgr.pp_size
)
else:
self.required_prefill_response_num = (
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
) * (self.prefill_pp_size // self.kv_mgr.pp_size)
if prefill_dp_rank is not None:
logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
self.prefill_dp_rank = prefill_dp_rank
else:
self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size
# FIXME: alias here: target_dp_group -> prefill_dp_rank
self.target_dp_group = self.prefill_dp_rank
self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
self.required_prefill_response_num
)
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
bootstrap_key = (
f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
)
if bootstrap_key not in self.kv_mgr.connection_pool:
bootstrap_infos = []
for target_tp_rank in self.target_tp_ranks:
for target_pp_rank in range(self.prefill_pp_size):
bootstrap_info = self._get_bootstrap_info_from_server(
target_tp_rank, self.target_dp_group, target_pp_rank
)
if bootstrap_info is not None:
if self.kv_mgr.is_mla_backend:
# For MLA: target_tp_rank is the selected real rank, others are dummy ranks
bootstrap_info["is_dummy"] = not bool(
target_tp_rank == self.target_tp_rank
or self.target_tp_rank is None
)
else:
# For non-MLA: all target_tp_ranks are selected real ranks
bootstrap_info["is_dummy"] = False
logger.debug(
f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank} PP {target_pp_rank}"
)
bootstrap_infos.append(bootstrap_info)
else:
self.kv_mgr.record_failure(
self.bootstrap_room,
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group} and target_pp_rank {target_pp_rank}",
)
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
return
self.bootstrap_infos = bootstrap_infos
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
self._register_kv_args()
else:
self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
assert len(self.bootstrap_infos) > 0
self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(self.bootstrap_room)
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)
......@@ -1398,29 +1119,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
logger.error(f"Error fetching prefill info from bootstrap: {e}")
return None
def _get_prefill_parallel_info_from_server(
    self,
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
    """Fetch the prefill parallel info from the bootstrap server.

    Returns:
        A tuple ``(attn_tp_size, dp_size, pp_size)`` parsed from the server's
        JSON response, or ``(None, None, None)`` if the request fails or the
        server answers with a non-200 status.
    """
    try:
        # engine_rank/target_dp_group/target_pp_rank == -1 is the sentinel the
        # bootstrap server's GET handler recognizes as "return parallel sizes".
        url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}&target_pp_rank={-1}"
        # Bound the request with a timeout so an unreachable bootstrap server
        # cannot hang the receiver forever (the PUT registration path already
        # uses timeout=5; this keeps the two consistent).
        response = requests.get(url, timeout=5)
        if response.status_code == 200:
            prefill_parallel_info = response.json()
            return (
                int(prefill_parallel_info["prefill_attn_tp_size"]),
                int(prefill_parallel_info["prefill_dp_size"]),
                int(prefill_parallel_info["prefill_pp_size"]),
            )
        logger.error(
            f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
        )
        return None, None, None
    except Exception as e:
        logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
        return None, None, None
def _register_kv_args(self):
for bootstrap_info in self.bootstrap_infos:
packed_kv_data_ptrs = b"".join(
......@@ -1452,28 +1150,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
]
)
@classmethod
def _connect(cls, endpoint: str, is_ipv6: bool = False):
    """Return a cached PUSH socket and its lock for *endpoint*.

    The socket and its companion ``threading.Lock`` are created lazily on
    first use and shared by all later callers; creation is serialized by
    the class-wide ``_global_lock``.
    """
    with cls._global_lock:
        sock = cls._socket_cache.get(endpoint)
        if sock is None:
            sock = cls._ctx.socket(zmq.PUSH)
            if is_ipv6:
                sock.setsockopt(zmq.IPV6, 1)
            sock.connect(endpoint)
            cls._socket_cache[endpoint] = sock
            cls._socket_locks[endpoint] = threading.Lock()
        return sock, cls._socket_locks[endpoint]
@classmethod
def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
    """Open (or reuse) the PUSH socket toward the prefill rank described by *bootstrap_info*."""
    rank_ip = bootstrap_info["rank_ip"]
    rank_port = bootstrap_info["rank_port"]
    return cls._connect(
        format_tcp_address(rank_ip, rank_port),
        is_ipv6=is_valid_ipv6_address(rank_ip),
    )
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
for bootstrap_info in self.bootstrap_infos:
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
......@@ -1551,154 +1227,5 @@ class MooncakeKVReceiver(BaseKVReceiver):
self.conclude_state = KVPoll.Failed
class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
    """HTTP bootstrap server for disaggregated prefill/decode rendezvous.

    Prefill ranks register their (ip, port) via ``PUT /route``; decode ranks
    discover them via ``GET /route``.  The aiohttp application runs on a
    dedicated event loop inside a daemon thread started from ``__init__``.
    """

    def __init__(self, host: str, port: int):
        self.host = host
        self.port = port
        self.app = web.Application()
        self.store = dict()
        self.lock = asyncio.Lock()
        self._setup_routes()
        # Cluster parallel sizes are learned lazily from the first PUT.
        self.pp_size = None
        self.attn_tp_size = None
        self.dp_size = None
        # prefill_port_table[dp_group][attn_tp_rank][pp_rank] -> {"rank_ip", "rank_port"}
        self.prefill_port_table: Dict[
            int, Dict[int, Dict[int, Dict[str, Union[str, int]]]]
        ] = {}
        # Pre-initialize loop/runner handles so close() and the cleanup path
        # in _run_server() are safe even if the server thread failed before
        # (or while) creating them.
        self._loop = None
        self._runner = None

        # Start bootstrap server
        self.thread = threading.Thread(target=self._run_server, daemon=True)
        self.run()

    def run(self):
        """Start the background server thread."""
        self.thread.start()

    def _setup_routes(self):
        self.app.router.add_route("*", "/route", self._handle_route)
        self.app.router.add_get("/health", self._handle_health_check)

    async def _handle_health_check(self, request):
        return web.Response(text="OK", status=200)

    async def _handle_route(self, request: web.Request):
        """Dispatch /route by HTTP method; only PUT and GET are supported."""
        method = request.method
        if method == "PUT":
            return await self._handle_route_put(request)
        elif method == "GET":
            return await self._handle_route_get(request)
        else:
            return web.Response(
                text="Method not allowed", status=405, content_type="application/json"
            )

    async def _handle_route_put(self, request: web.Request):
        """Register a prefill rank's ip/port under (dp_group, attn_tp_rank, pp_rank)."""
        data = await request.json()
        role = data["role"]
        attn_tp_size = data["attn_tp_size"]
        attn_tp_rank = data["attn_tp_rank"]
        attn_dp_size = data["attn_dp_size"]
        attn_dp_rank = data["attn_dp_rank"]
        pp_size = data["pp_size"]
        pp_rank = data["pp_rank"]
        system_dp_size = data["system_dp_size"]
        system_dp_rank = data["system_dp_rank"]
        rank_ip = data["rank_ip"]
        rank_port = int(data["rank_port"])

        # Record cluster-wide parallel sizes from the first registration.
        if self.attn_tp_size is None:
            self.attn_tp_size = attn_tp_size

        if self.dp_size is None:
            self.dp_size = attn_dp_size if system_dp_size == 1 else system_dp_size

        if self.pp_size is None:
            self.pp_size = pp_size

        if role == "Prefill":
            # With attention-DP (system_dp_size == 1) the DP group is keyed by
            # the attention DP rank; otherwise by the system DP rank.
            if system_dp_size == 1:
                dp_group = attn_dp_rank
            else:
                dp_group = system_dp_rank

            # Add lock to make sure thread-safe
            async with self.lock:
                if dp_group not in self.prefill_port_table:
                    self.prefill_port_table[dp_group] = {}
                if attn_tp_rank not in self.prefill_port_table[dp_group]:
                    self.prefill_port_table[dp_group][attn_tp_rank] = {}
                self.prefill_port_table[dp_group][attn_tp_rank][pp_rank] = {
                    "rank_ip": rank_ip,
                    "rank_port": rank_port,
                }
            logger.debug(
                f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
            )

        return web.Response(text="OK", status=200)

    async def _handle_route_get(self, request: web.Request):
        """Return either the prefill parallel sizes (sentinel -1 query) or one rank's info."""
        engine_rank = request.query.get("engine_rank")
        target_dp_group = request.query.get("target_dp_group")
        target_pp_rank = request.query.get("target_pp_rank")
        if not engine_rank or not target_dp_group or not target_pp_rank:
            return web.Response(text="Missing inputs for bootstrap server.", status=400)

        # Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
        if (
            int(engine_rank) == -1
            and int(target_dp_group) == -1
            and int(target_pp_rank) == -1
        ):
            prefill_parallel_info = {
                "prefill_attn_tp_size": self.attn_tp_size,
                "prefill_dp_size": self.dp_size,
                "prefill_pp_size": self.pp_size,
            }
            return web.json_response(prefill_parallel_info, status=200)

        # Find corresponding prefill info.  Use .get() chains so an
        # unregistered rank yields a 404 instead of an unhandled KeyError
        # (which aiohttp would report as a 500 and the 404 branch below
        # would be unreachable).
        async with self.lock:
            bootstrap_info = (
                self.prefill_port_table.get(int(target_dp_group), {})
                .get(int(engine_rank), {})
                .get(int(target_pp_rank))
            )

        if bootstrap_info is not None:
            return web.json_response(bootstrap_info, status=200)
        else:
            return web.Response(text="Bootstrap info not Found", status=404)

    def _run_server(self):
        try:
            # Event loop owned by this thread; aiohttp runs on it.
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)

            access_log = None
            if logging.getLogger(__name__).getEffectiveLevel() <= logging.DEBUG:
                access_log = self.app.logger

            self._runner = web.AppRunner(self.app, access_log=access_log)
            self._loop.run_until_complete(self._runner.setup())

            site = web.TCPSite(self._runner, host=self.host, port=self.port)
            self._loop.run_until_complete(site.start())
            self._loop.run_forever()
        except Exception as e:
            logger.error(f"Server error: {str(e)}")
        finally:
            # Cleanup; guard against failures that happened before the runner
            # or loop were fully initialized.
            if (
                self._runner is not None
                and self._loop is not None
                and not self._loop.is_closed()
            ):
                self._loop.run_until_complete(self._runner.cleanup())
            if self._loop is not None:
                self._loop.close()

    def close(self):
        """Shutdown"""
        if self._loop is not None and self._loop.is_running():
            self._loop.call_soon_threadsafe(self._loop.stop)
            logger.info("Stopping server loop...")

        if self.thread.is_alive():
            self.thread.join(timeout=2)
            logger.info("Server thread stopped")

    def poll(self) -> KVPoll: ...
class MooncakeKVBootstrapServer(CommonKVBootstrapServer):
    """Mooncake bootstrap server; all behavior is inherited unchanged from CommonKVBootstrapServer."""
from __future__ import annotations
import asyncio
import dataclasses
import logging
import queue
import socket
import struct
import threading
import uuid
from collections import defaultdict
from functools import cache
from typing import Dict, List, Optional, Set, Tuple, TypeAlias, Union
from typing import Dict, List, Optional, Set
import numpy as np
import numpy.typing as npt
import requests
import zmq
from aiohttp import web
from sglang.srt.disaggregation.base.conn import BaseKVSender, KVArgs, KVPoll
from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
from sglang.srt.disaggregation.common.conn import (
CommonKVBootstrapServer,
CommonKVManager,
CommonKVReceiver,
CommonKVSender,
)
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
format_tcp_address,
get_local_ip_auto,
is_valid_ipv6_address,
)
logger = logging.getLogger(__name__)
......@@ -134,16 +123,9 @@ class NixlKVManager(CommonKVManager):
"to run SGLang with NixlTransferEngine."
) from e
self.agent = nixl_agent(str(uuid.uuid4()))
self.local_ip = get_local_ip_auto()
self.server_socket = zmq.Context().socket(zmq.PULL)
if is_valid_ipv6_address(self.local_ip):
self.server_socket.setsockopt(zmq.IPV6, 1)
self.register_buffer_to_engine()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.request_status: Dict[int, KVPoll] = {}
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
self._start_bootstrap_thread()
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
......@@ -166,6 +148,9 @@ class NixlKVManager(CommonKVManager):
self.request_status[bootstrap_room], status
)
def record_failure(self, bootstrap_room: int, failure_reason: str):
    # Intentionally a no-op override: the Nixl backend does not persist
    # per-room failure reasons here.  NOTE(review): presumably failures
    # surface via status polling/transfer state instead — confirm against
    # how the common backend consumes record_failure.
    pass
def register_buffer_to_engine(self):
kv_addrs = []
for kv_data_ptr, kv_data_len in zip(
......@@ -438,7 +423,7 @@ class NixlKVManager(CommonKVManager):
notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size
if decode_tp_size == self.tp_size:
if decode_tp_size == self.attn_tp_size:
kv_xfer_handle = self.send_kvcache(
req.agent_name,
kv_indices,
......@@ -455,7 +440,7 @@ class NixlKVManager(CommonKVManager):
chunked_dst_kv_indice,
self.decode_kv_args_table[req.agent_name].gpu_id,
notif,
prefill_tp_size=self.tp_size,
prefill_tp_size=self.attn_tp_size,
decode_tp_size=decode_tp_size,
decode_tp_rank=self.decode_kv_args_table[
req.agent_name
......@@ -505,9 +490,6 @@ class NixlKVManager(CommonKVManager):
return False
return self.transfer_statuses[room].is_done()
def _bind_server_socket(self):
    """Bind this manager's ZMQ PULL socket to the local ip/port of this rank."""
    endpoint = format_tcp_address(self.local_ip, self.rank_port)
    self.server_socket.bind(endpoint)
def _start_bootstrap_thread(self):
self._bind_server_socket()
......@@ -548,7 +530,7 @@ class NixlKVManager(CommonKVManager):
threading.Thread(target=bootstrap_thread).start()
class NixlKVSender(BaseKVSender):
class NixlKVSender(CommonKVSender):
def __init__(
self,
......@@ -558,20 +540,10 @@ class NixlKVSender(BaseKVSender):
dest_tp_ranks: List[int],
pp_rank: int,
):
self.kv_mgr = mgr
self.bootstrap_room = bootstrap_room
self.aux_index = None
self.bootstrap_server_url = bootstrap_addr
super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank)
self.xfer_handles = []
self.has_sent = False
self.chunk_id = 0
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
# inner state
self.curr_idx = 0
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
self.num_kv_indices = num_kv_indices
self.aux_index = aux_index
def send(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment