Unverified Commit e806f708 authored by Trevor Morris's avatar Trevor Morris Committed by GitHub
Browse files

[PD] Make bootstrap code common between NIXL and Mooncake (#6473)

parent fa6723f0
...@@ -47,3 +47,44 @@ $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --dis ...@@ -47,3 +47,44 @@ $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --dis
# decode 1 # decode 1
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128 $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
``` ```
## NIXL
### Requirements
Install via pip.
```bash
pip install nixl
```
Or build from source - may be required if you already have UCX installed.
```bash
git clone https://github.com/ai-dynamo/nixl.git
cd nixl
pip install . --config-settings=setup-args="-Ducx_path=/path/to/ucx"
```
### Usage
### Llama Single Node
```bash
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-transfer-backend nixl
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-transfer-backend nixl
$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
```
### DeepSeek Multi-Node
```bash
# prefill 0
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode normal --mem-fraction-static 0.8
# prefill 1
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode normal --mem-fraction-static 0.8
# decode 0
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
# decode 1
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
```
from .conn import CommonKVBootstrapServer, CommonKVManager, CommonKVReceiver
from __future__ import annotations
import asyncio
import logging
import socket
import threading
from functools import cache
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import numpy.typing as npt
import requests
import zmq
from aiohttp import web
from sglang.srt.disaggregation.base.conn import (
BaseKVBootstrapServer,
BaseKVManager,
BaseKVReceiver,
BaseKVSender,
KVArgs,
KVPoll,
)
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
logger = logging.getLogger(__name__)
class CommonKVManager(BaseKVManager):
def __init__(
self,
args: KVArgs,
disaggregation_mode: DisaggregationMode,
server_args: ServerArgs,
is_mla_backend: Optional[bool] = False,
):
self.kv_args = args
self.is_mla_backend = is_mla_backend
self.disaggregation_mode = disaggregation_mode
# for p/d multi node infer
self.bootstrap_port = server_args.disaggregation_bootstrap_port
self.dist_init_addr = server_args.dist_init_addr
self.tp_size = server_args.tp_size
self.dp_size = server_args.dp_size
self.enable_dp_attention = server_args.enable_dp_attention
if not server_args.enable_dp_attention and server_args.dp_size != 1:
raise ValueError(
"If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
)
self.rank_port = get_free_port()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self._register_to_bootstrap()
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
self.prefill_tp_size_table: Dict[str, int] = {}
self.prefill_dp_size_table: Dict[str, int] = {}
else:
raise ValueError(
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
)
def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST."""
if self.dist_init_addr:
ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
else:
ip_address = get_ip()
bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
url = f"http://{bootstrap_server_url}/route"
payload = {
"role": "Prefill",
"tp_size": self.tp_size,
"dp_size": self.dp_size,
"rank_ip": get_local_ip_by_remote(),
"rank_port": self.rank_port,
"engine_rank": self.kv_args.engine_rank,
}
try:
response = requests.put(url, json=payload)
if response.status_code == 200:
logger.debug("Prefill successfully registered to bootstrap server.")
else:
logger.error(
f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
)
except Exception as e:
logger.error(f"Prefill Failed to register to bootstrap server: {e}")
@cache
def _connect(self, endpoint: str):
socket = zmq.Context().socket(zmq.PUSH)
socket.connect(endpoint)
return socket
class CommonKVReceiver(BaseKVReceiver):
_ctx = zmq.Context()
_socket_cache = {}
_socket_locks = {}
_global_lock = threading.Lock()
def __init__(
self,
mgr: BaseKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
):
self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr
self.kv_mgr = mgr
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
self.prefill_tp_size, self.prefill_dp_size = (
self._get_prefill_dp_size_from_server()
)
if self.prefill_tp_size is None or self.prefill_dp_size is None:
logger.error(
f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
)
else:
self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
self.prefill_tp_size
)
self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
self.prefill_dp_size
)
else:
self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
self.bootstrap_addr
]
self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
self.bootstrap_addr
]
# Currently, we don't allow prefill instance and decode instance to
# have different TP sizes per DP rank, except for models using MLA.
local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
self.target_tp_rank = (
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
)
self.required_dst_info_num = 1
self.target_tp_ranks = [self.target_tp_rank]
elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
assert (
self.kv_mgr.is_mla_backend
), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
self.target_tp_rank = (
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
self.required_dst_info_num = (
local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
)
self.target_tp_ranks = [self.target_tp_rank]
else:
assert (
self.kv_mgr.is_mla_backend
), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
# For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
self.target_tp_ranks = [
rank
for rank in range(
(self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
* (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
(self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
* (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
)
]
# For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
# multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
# or the KVPoll will never be set correctly
self.target_tp_rank = self.target_tp_ranks[0]
self.required_dst_info_num = 1
self.target_dp_group = bootstrap_room % self.prefill_dp_size
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
bootstrap_key = (
f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
)
if bootstrap_key not in self.kv_mgr.connection_pool:
bootstrap_infos = []
for target_tp_rank in self.target_tp_ranks:
bootstrap_info = self._get_bootstrap_info_from_server(
target_tp_rank,
self.target_dp_group,
)
if bootstrap_info is not None:
# NOTE: only support MLA for now: select one prefill rank as real rank
bootstrap_info["is_dummy"] = not bool(
target_tp_rank == self.target_tp_rank
or self.target_tp_rank is None
)
bootstrap_infos.append(bootstrap_info)
else:
logger.error(
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
)
self.bootstrap_infos = bootstrap_infos
if len(self.bootstrap_infos) == 0:
logger.error(
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
)
else:
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
self._register_kv_args()
else:
self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
assert len(self.bootstrap_infos) > 0
def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
"""Fetch the bootstrap info from the bootstrap server."""
try:
url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
response = requests.get(url)
if response.status_code == 200:
bootstrap_info = response.json()
return bootstrap_info
else:
logger.error(
f"Failed to get prefill server info: {response.status_code}, {response.text}"
)
return None
except Exception as e:
logger.error(f"Error fetching prefill info from bootstrap: {e}")
return None
def _get_prefill_dp_size_from_server(self) -> int:
"""Fetch the prefill parallel info from the bootstrap server."""
try:
url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
response = requests.get(url)
if response.status_code == 200:
prefill_parallel_info = response.json()
return int(prefill_parallel_info["prefill_tp_size"]), int(
prefill_parallel_info["prefill_dp_size"]
)
else:
logger.error(
f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
)
return None
except Exception as e:
logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
return None
@classmethod
def _connect(cls, endpoint: str):
with cls._global_lock:
if endpoint not in cls._socket_cache:
sock = cls._ctx.socket(zmq.PUSH)
sock.connect(endpoint)
cls._socket_cache[endpoint] = sock
cls._socket_locks[endpoint] = threading.Lock()
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
def _register_kv_args(self):
pass
def failure_exception(self):
raise Exception("Fake KVReceiver Exception")
class CommonKVBootstrapServer(BaseKVBootstrapServer):
def __init__(self, port: int):
self.port = port
self.app = web.Application()
self.store = dict()
self.lock = asyncio.Lock()
self._setup_routes()
self.tp_size = None
self.dp_size = None
self.tp_size_per_dp_rank = None
self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
# Start bootstrap server
self.thread = threading.Thread(target=self._run_server, daemon=True)
self.run()
def run(self):
self.thread.start()
def _setup_routes(self):
self.app.router.add_route("*", "/route", self._handle_route)
async def _handle_route(self, request: web.Request):
method = request.method
if method == "PUT":
return await self._handle_route_put(request)
elif method == "GET":
return await self._handle_route_get(request)
else:
return web.Response(
text="Method not allowed", status=405, content_type="application/json"
)
async def _handle_route_put(self, request: web.Request):
data = await request.json()
role = data["role"]
tp_size = data["tp_size"]
dp_size = data["dp_size"]
rank_ip = data["rank_ip"]
rank_port = int(data["rank_port"])
engine_rank = int(data["engine_rank"])
if self.tp_size is None:
self.tp_size = tp_size
if self.dp_size is None:
self.dp_size = dp_size
tp_size_per_dp_rank = tp_size // dp_size
if self.tp_size_per_dp_rank == None:
self.tp_size_per_dp_rank = tp_size_per_dp_rank
# Add lock to make sure thread-safe
if role == "Prefill":
dp_group = engine_rank // tp_size_per_dp_rank
tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank
async with self.lock:
if dp_group not in self.prefill_port_table:
self.prefill_port_table[dp_group] = {}
self.prefill_port_table[dp_group][tp_rank_in_dp_group] = {
"rank_ip": rank_ip,
"rank_port": rank_port,
}
logger.debug(
f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
)
return web.Response(text="OK", status=200)
async def _handle_route_get(self, request: web.Request):
engine_rank = request.query.get("engine_rank")
target_dp_group = request.query.get("target_dp_group")
if not engine_rank or not target_dp_group:
return web.Response(text="Missing inputs for bootstrap server.", status=400)
# Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
if int(engine_rank) == -1 and int(target_dp_group) == -1:
prefill_parallel_info = {
"prefill_tp_size": self.tp_size,
"prefill_dp_size": self.dp_size,
}
return web.json_response(prefill_parallel_info, status=200)
# Find corresponding prefill info
async with self.lock:
bootstrap_info = self.prefill_port_table[int(target_dp_group)][
int(engine_rank)
]
if bootstrap_info is not None:
return web.json_response(bootstrap_info, status=200)
else:
return web.Response(text="Bootstrap info not Found", status=404)
def _run_server(self):
try:
# Event Loop
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._runner = web.AppRunner(self.app)
self._loop.run_until_complete(self._runner.setup())
site = web.TCPSite(self._runner, port=self.port)
self._loop.run_until_complete(site.start())
self._loop.run_forever()
except Exception as e:
logger.error(f"Server error: {str(e)}")
finally:
# Cleanup
self._loop.run_until_complete(self._runner.cleanup())
self._loop.close()
def close(self):
"""Shutdown"""
if self._loop is not None and self._loop.is_running():
self._loop.call_soon_threadsafe(self._loop.stop)
logger.info("Stopping server loop...")
if self.thread.is_alive():
self.thread.join(timeout=2)
logger.info("Server thread stopped")
def poll(self) -> KVPoll: ...
...@@ -29,7 +29,10 @@ from sglang.srt.disaggregation.base.conn import ( ...@@ -29,7 +29,10 @@ from sglang.srt.disaggregation.base.conn import (
KVPoll, KVPoll,
) )
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.disaggregation.utils import (
DisaggregationMode,
group_concurrent_contiguous,
)
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
get_free_port, get_free_port,
...@@ -41,23 +44,6 @@ from sglang.srt.utils import ( ...@@ -41,23 +44,6 @@ from sglang.srt.utils import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
"""Vectorised NumPy implementation."""
if src_indices.size == 0:
return [], []
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
src_groups = np.split(src_indices, brk)
dst_groups = np.split(dst_indices, brk)
src_groups = [g.tolist() for g in src_groups]
dst_groups = [g.tolist() for g in dst_groups]
return src_groups, dst_groups
class KVTransferError(Exception): class KVTransferError(Exception):
def __init__(self, bootstrap_room: int, failure_reason: str): def __init__(self, bootstrap_room: int, failure_reason: str):
super().__init__(failure_reason) super().__init__(failure_reason)
......
...@@ -18,40 +18,23 @@ import requests ...@@ -18,40 +18,23 @@ import requests
import zmq import zmq
from aiohttp import web from aiohttp import web
from sglang.srt.disaggregation.base.conn import ( from sglang.srt.disaggregation.base.conn import BaseKVSender, KVArgs, KVPoll
BaseKVBootstrapServer, from sglang.srt.disaggregation.common.conn import (
BaseKVManager, CommonKVBootstrapServer,
BaseKVReceiver, CommonKVManager,
BaseKVSender, CommonKVReceiver,
KVArgs, )
KVPoll, from sglang.srt.disaggregation.utils import (
DisaggregationMode,
group_concurrent_contiguous,
) )
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote from sglang.srt.utils import get_local_ip_by_remote
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]] NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
"""Vectorised NumPy implementation."""
if src_indices.size == 0:
return [], []
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
src_groups = np.split(src_indices, brk)
dst_groups = np.split(dst_indices, brk)
src_groups = [g.tolist() for g in src_groups]
dst_groups = [g.tolist() for g in dst_groups]
return src_groups, dst_groups
GUARD = "NixlMsgGuard".encode("ascii") GUARD = "NixlMsgGuard".encode("ascii")
...@@ -61,11 +44,13 @@ class TransferInfo: ...@@ -61,11 +44,13 @@ class TransferInfo:
endpoint: str endpoint: str
dst_port: int dst_port: int
agent_metadata: bytes agent_metadata: bytes
agent_name: str
dst_kv_ptrs: list[int] dst_kv_ptrs: list[int]
dst_kv_indices: npt.NDArray[np.int64] dst_kv_indices: npt.NDArray[np.int64]
dst_aux_ptrs: list[int] dst_aux_ptrs: list[int]
dst_aux_index: int dst_aux_index: int
dst_gpu_id: int dst_gpu_id: int
required_dst_info_num: int
def is_dummy(self): def is_dummy(self):
return self.endpoint == "" return self.endpoint == ""
...@@ -79,11 +64,13 @@ class TransferInfo: ...@@ -79,11 +64,13 @@ class TransferInfo:
endpoint="", endpoint="",
dst_port=0, dst_port=0,
agent_metadata=b"", agent_metadata=b"",
agent_name="",
dst_kv_ptrs=[], dst_kv_ptrs=[],
dst_kv_indices=np.array([], dtype=np.int64), dst_kv_indices=np.array([], dtype=np.int64),
dst_aux_ptrs=[], dst_aux_ptrs=[],
dst_aux_index=0, dst_aux_index=0,
dst_gpu_id=0, dst_gpu_id=0,
required_dst_info_num=0,
) )
else: else:
return cls( return cls(
...@@ -91,11 +78,13 @@ class TransferInfo: ...@@ -91,11 +78,13 @@ class TransferInfo:
endpoint=msg[1].decode("ascii"), endpoint=msg[1].decode("ascii"),
dst_port=int(msg[2].decode("ascii")), dst_port=int(msg[2].decode("ascii")),
agent_metadata=msg[3], agent_metadata=msg[3],
dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])), agent_name=msg[4].decode("ascii"),
dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64), dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])), dst_kv_indices=np.frombuffer(msg[6], dtype=np.int64),
dst_aux_index=int(msg[7].decode("ascii")), dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
dst_gpu_id=int(msg[8].decode("ascii")), dst_aux_index=int(msg[8].decode("ascii")),
dst_gpu_id=int(msg[9].decode("ascii")),
required_dst_info_num=int(msg[10].decode("ascii")),
) )
...@@ -116,7 +105,7 @@ class TransferStatus: ...@@ -116,7 +105,7 @@ class TransferStatus:
return self.num_kvs_expected == len(self.received_kvs) and self.received_aux return self.num_kvs_expected == len(self.received_kvs) and self.received_aux
class NixlKVManager(BaseKVManager): class NixlKVManager(CommonKVManager):
def __init__( def __init__(
self, self,
args: KVArgs, args: KVArgs,
...@@ -124,6 +113,7 @@ class NixlKVManager(BaseKVManager): ...@@ -124,6 +113,7 @@ class NixlKVManager(BaseKVManager):
server_args: ServerArgs, server_args: ServerArgs,
is_mla_backend: Optional[bool] = False, is_mla_backend: Optional[bool] = False,
): ):
super().__init__(args, disaggregation_mode, server_args, is_mla_backend)
try: try:
from nixl._api import nixl_agent from nixl._api import nixl_agent
except ImportError as e: except ImportError as e:
...@@ -133,38 +123,15 @@ class NixlKVManager(BaseKVManager): ...@@ -133,38 +123,15 @@ class NixlKVManager(BaseKVManager):
"to run SGLang with NixlTransferEngine." "to run SGLang with NixlTransferEngine."
) from e ) from e
self.agent = nixl_agent(str(uuid.uuid4())) self.agent = nixl_agent(str(uuid.uuid4()))
self.kv_args = args
self.disaggregation_mode = disaggregation_mode
# for p/d multi node infer
self.bootstrap_port = server_args.disaggregation_bootstrap_port
self.dist_init_addr = server_args.dist_init_addr
self.tp_size = server_args.tp_size
self.tp_rank = args.engine_rank
self.enable_dp_attention = server_args.enable_dp_attention
if self.enable_dp_attention:
assert (
server_args.dp_size > 1
), "If dp_attention is enabled, dp size must be greater than 1 in disaggregation mode."
self.dp_size = server_args.dp_size
self.tp_size_of_dp = server_args.tp_size // server_args.dp_size
self.attn_tp_rank = args.engine_rank % self.tp_size_of_dp
self.dp_rank = args.engine_rank // self.tp_size_of_dp
self.rank_port = None
self.server_socket = zmq.Context().socket(zmq.PULL) self.server_socket = zmq.Context().socket(zmq.PULL)
self.register_buffer_to_engine() self.register_buffer_to_engine()
self.rank_port = get_free_port()
if self.disaggregation_mode == DisaggregationMode.PREFILL: if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.request_status = {}
self.transfer_infos: Dict[int, TransferInfo] = {} self.transfer_infos: Dict[int, TransferInfo] = {}
self.condition = threading.Condition() self.peer_names: Dict[str, str] = {}
self.peer_names: Dict[int, str] = {}
self._start_bootstrap_thread() self._start_bootstrap_thread()
self._register_to_bootstrap()
elif self.disaggregation_mode == DisaggregationMode.DECODE: elif self.disaggregation_mode == DisaggregationMode.DECODE:
# bootstrap key -> (remote_engine_rank -> possible remote source info)
self.prefill_peer_infos: Dict[str, list[Dict[int, NixlEngineInfo]]] = {}
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict( self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
TransferStatus TransferStatus
) )
...@@ -173,6 +140,18 @@ class NixlKVManager(BaseKVManager): ...@@ -173,6 +140,18 @@ class NixlKVManager(BaseKVManager):
f"Unsupported DisaggregationMode: {self.disaggregation_mode}" f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
) )
def check_status(self, bootstrap_room: int):
return self.request_status[bootstrap_room]
def update_status(self, bootstrap_room: int, status: KVPoll):
if bootstrap_room not in self.request_status:
self.request_status[bootstrap_room] = status
else:
# NOTE: The prefill engine could recv bootstrapping first
self.request_status[bootstrap_room] = max(
self.request_status[bootstrap_room], status
)
def register_buffer_to_engine(self): def register_buffer_to_engine(self):
kv_addrs = [] kv_addrs = []
for kv_data_ptr, kv_data_len in zip( for kv_data_ptr, kv_data_len in zip(
...@@ -193,16 +172,10 @@ class NixlKVManager(BaseKVManager): ...@@ -193,16 +172,10 @@ class NixlKVManager(BaseKVManager):
if not self.aux_descs: if not self.aux_descs:
raise Exception("NIXL memory registration failed for aux tensors") raise Exception("NIXL memory registration failed for aux tensors")
@cache def _add_remote(self, agent_name: str, agent_metadata: bytes):
def _connect(self, endpoint: str): if agent_name not in self.peer_names:
socket = zmq.Context().socket(zmq.PUSH) self.peer_names[agent_name] = self.agent.add_remote_agent(agent_metadata)
socket.connect(endpoint) return self.peer_names[agent_name]
return socket
def _add_remote(self, room: int, agent_metadata: bytes):
if room not in self.peer_names:
self.peer_names[room] = self.agent.add_remote_agent(agent_metadata)
return self.peer_names[room]
def send_kvcache( def send_kvcache(
self, self,
...@@ -300,40 +273,38 @@ class NixlKVManager(BaseKVManager): ...@@ -300,40 +273,38 @@ class NixlKVManager(BaseKVManager):
assert self.disaggregation_mode == DisaggregationMode.PREFILL assert self.disaggregation_mode == DisaggregationMode.PREFILL
assert not is_last or (is_last and aux_index is not None) assert not is_last or (is_last and aux_index is not None)
# Wait for transfer info to be populated by bootstrap thread. reqs_to_be_processed = self.transfer_infos[bootstrap_room].values()
with self.condition: handles = []
self.condition.wait_for(lambda: bootstrap_room in self.transfer_infos) for req in reqs_to_be_processed:
req = self.transfer_infos[bootstrap_room] assert bootstrap_room == req.room
assert bootstrap_room == req.room if req.is_dummy():
return []
if req.is_dummy():
return []
peer_name = self._add_remote(bootstrap_room, req.agent_metadata) peer_name = self._add_remote(req.agent_name, req.agent_metadata)
chunked_dst_kv_indice = req.dst_kv_indices[index_slice] chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
assert len(chunked_dst_kv_indice) == len(kv_indices) assert len(chunked_dst_kv_indice) == len(kv_indices)
notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))]) notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
kv_xfer_handle = self.send_kvcache( kv_xfer_handle = self.send_kvcache(
peer_name,
kv_indices,
req.dst_kv_ptrs,
chunked_dst_kv_indice,
req.dst_gpu_id,
notif,
)
handles = [kv_xfer_handle]
# Only the last chunk we need to send the aux data.
if is_last:
assert aux_index is not None
aux_xfer_handle = self.send_aux(
peer_name, peer_name,
aux_index, kv_indices,
req.dst_aux_ptrs, req.dst_kv_ptrs,
req.dst_aux_index, chunked_dst_kv_indice,
str(req.room) + "_aux", req.dst_gpu_id,
notif,
) )
handles.append(aux_xfer_handle) handles.append(kv_xfer_handle)
# Only the last chunk we need to send the aux data.
if is_last:
assert aux_index is not None
aux_xfer_handle = self.send_aux(
peer_name,
aux_index,
req.dst_aux_ptrs,
req.dst_aux_index,
str(req.room) + "_aux",
)
handles.append(aux_xfer_handle)
return handles return handles
def update_transfer_status(self): def update_transfer_status(self):
...@@ -348,7 +319,7 @@ class NixlKVManager(BaseKVManager): ...@@ -348,7 +319,7 @@ class NixlKVManager(BaseKVManager):
room = int(components[0]) room = int(components[0])
if components[1] == "kv": if components[1] == "kv":
chunk_id = int(components[2]) chunk_id = int(components[2])
is_last = bool(components[3]) is_last = bool(int(components[3]))
self.transfer_statuses[room].received_kvs.add(chunk_id) self.transfer_statuses[room].received_kvs.add(chunk_id)
if is_last: if is_last:
self.transfer_statuses[room].num_kvs_expected = chunk_id + 1 self.transfer_statuses[room].num_kvs_expected = chunk_id + 1
...@@ -360,34 +331,6 @@ class NixlKVManager(BaseKVManager): ...@@ -360,34 +331,6 @@ class NixlKVManager(BaseKVManager):
return False return False
return self.transfer_statuses[room].is_done() return self.transfer_statuses[room].is_done()
def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST."""
if self.dist_init_addr:
ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
else:
ip_address = get_ip()
bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
url = f"http://{bootstrap_server_url}/route"
payload = {
"role": "Prefill",
"rank_ip": get_local_ip_by_remote(),
"rank_port": self.rank_port,
"engine_rank": self.kv_args.engine_rank,
"agent_name": self.agent.name,
}
try:
response = requests.put(url, json=payload)
if response.status_code == 200:
logger.debug("Prefill successfully registered to bootstrap server.")
else:
logger.error(
f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
)
except Exception as e:
logger.error(f"Prefill Failed to register to bootstrap server: {e}")
def _start_bootstrap_thread(self): def _start_bootstrap_thread(self):
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}") self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
...@@ -405,10 +348,19 @@ class NixlKVManager(BaseKVManager): ...@@ -405,10 +348,19 @@ class NixlKVManager(BaseKVManager):
room = waiting_req_bytes[0].decode("ascii") room = waiting_req_bytes[0].decode("ascii")
if room == "None": if room == "None":
continue continue
required_dst_info_num = int(waiting_req_bytes[10].decode("ascii"))
room = int(room) room = int(room)
with self.condition: agent_name = waiting_req_bytes[4].decode("ascii")
self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes) if room not in self.transfer_infos:
self.condition.notify_all() self.transfer_infos[room] = {}
self.transfer_infos[room][agent_name] = TransferInfo.from_zmq(
waiting_req_bytes
)
logger.debug(f"got info {room=} {agent_name=} {required_dst_info_num=}")
if len(self.transfer_infos[room]) == required_dst_info_num:
logger.debug(f"{room=} is bootstrapped")
self.update_status(room, KVPoll.WaitingForInput)
threading.Thread(target=bootstrap_thread).start() threading.Thread(target=bootstrap_thread).start()
...@@ -423,6 +375,9 @@ class NixlKVSender(BaseKVSender): ...@@ -423,6 +375,9 @@ class NixlKVSender(BaseKVSender):
self.xfer_handles = [] self.xfer_handles = []
self.has_sent = False self.has_sent = False
self.chunk_id = 0 self.chunk_id = 0
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
# inner state
self.curr_idx = 0
def init(self, num_kv_indices: int, aux_index: Optional[int] = None): def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
self.num_kv_indices = num_kv_indices self.num_kv_indices = num_kv_indices
...@@ -431,9 +386,11 @@ class NixlKVSender(BaseKVSender): ...@@ -431,9 +386,11 @@ class NixlKVSender(BaseKVSender):
def send( def send(
self, self,
kv_indices: npt.NDArray[np.int64], kv_indices: npt.NDArray[np.int64],
index_slice: slice,
is_last: bool,
): ):
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
self.curr_idx += len(kv_indices)
is_last = self.curr_idx == self.num_kv_indices
new_xfer_handles = self.kv_mgr.add_transfer_request( new_xfer_handles = self.kv_mgr.add_transfer_request(
self.bootstrap_room, self.bootstrap_room,
kv_indices, kv_indices,
...@@ -449,7 +406,7 @@ class NixlKVSender(BaseKVSender): ...@@ -449,7 +406,7 @@ class NixlKVSender(BaseKVSender):
def poll(self) -> KVPoll: def poll(self) -> KVPoll:
if not self.has_sent: if not self.has_sent:
return KVPoll.WaitingForInput # type: ignore return self.kv_mgr.check_status(self.bootstrap_room)
states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles] states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles]
if all([x == "DONE" for x in states]): if all([x == "DONE" for x in states]):
return KVPoll.Success # type: ignore return KVPoll.Success # type: ignore
...@@ -461,128 +418,40 @@ class NixlKVSender(BaseKVSender): ...@@ -461,128 +418,40 @@ class NixlKVSender(BaseKVSender):
raise Exception("Fake KVSender Exception") raise Exception("Fake KVSender Exception")
class NixlKVReceiver(BaseKVReceiver): class NixlKVReceiver(CommonKVReceiver):
def __init__( def __init__(
self, self,
mgr: NixlKVManager, mgr: NixlKVManager,
bootstrap_addr: str, bootstrap_addr: str,
bootstrap_room: Optional[int] = None, bootstrap_room: Optional[int] = None,
): ):
self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr
self.kv_mgr = mgr
self.started_transfer = False self.started_transfer = False
super().__init__(mgr, bootstrap_addr, bootstrap_room)
# NOTE: key distinguished by bootstrap_addr and engine_rank
bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
if bootstrap_key not in self.kv_mgr.prefill_peer_infos:
self.bootstrap_info = self._get_bootstrap_info_from_server(
self.kv_mgr.kv_args.engine_rank
)
if self.bootstrap_info is None:
logger.error(
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
)
else:
self.kv_mgr.prefill_peer_infos[bootstrap_key] = self.bootstrap_info
else:
self.bootstrap_info = self.kv_mgr.prefill_peer_infos[bootstrap_key]
assert self.bootstrap_info is not None
# return a list of remotes in a dict, [(remote_engine_rank -> NixlEngineInfo), ...]
# In each dict, there are multiple possible remotes named "equal sources".
# We only need to select one to split the traffic. i.e. we totally select len(list) remotes.
def _get_bootstrap_info_from_server(
self, engine_rank
) -> Optional[List[Dict[int, NixlEngineInfo]]]:
"""Fetch the bootstrap info from the bootstrap server."""
try:
if self.kv_mgr.enable_dp_attention:
url = f"http://{self.bootstrap_addr}/route"
response = requests.get(url)
if response.status_code != 200:
logger.error(
f"Failed to get prefill server info: {response.status_code}, {response.text}"
)
return None
bootstrap_info = response.json()
assert isinstance(bootstrap_info, dict)
bootstrap_info = {int(k): v for k, v in bootstrap_info.items()}
# split out who need to send to this rank.
# currently for dpsk mla model, those ranks share the same latent cache.
# pick one as the real source
prefill_tp_size = len(bootstrap_info.keys())
assert (
prefill_tp_size >= self.kv_mgr.tp_size_of_dp
), f"Only support Prefill TP size >= Decode TP size of DP, now we have {prefill_tp_size} vs {self.kv_mgr.tp_size_of_dp}"
num_remote_tp_rank_we_managed = (
prefill_tp_size // self.kv_mgr.tp_size_of_dp
)
# We handle [num * self.attn_tp_rank, num * self.attn_tp_rank + num)
remote_tp_ranks = list(range(0, prefill_tp_size))
# split it into tp_size_of_dp parts and get our part
remote_tp_ranks_grouped = [
remote_tp_ranks[i : i + num_remote_tp_rank_we_managed]
for i in range(0, prefill_tp_size, self.kv_mgr.tp_size_of_dp)
]
managed_ranks = remote_tp_ranks_grouped[self.kv_mgr.attn_tp_rank]
assert len(managed_ranks) == num_remote_tp_rank_we_managed
logger.debug(
f"Rank {self.kv_mgr.kv_args.engine_rank} source can be {managed_ranks}"
)
return [
{
rk: bootstrap_info[rk]
for rk in bootstrap_info.keys()
if rk in managed_ranks
}
]
else:
url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
response = requests.get(url)
if response.status_code == 200:
bootstrap_info = response.json()
return [{engine_rank: bootstrap_info}]
else:
logger.error(
f"Failed to get prefill server info: {response.status_code}, {response.text}"
)
return None
except Exception as e:
logger.error(f"Error fetching prefill info from bootstrap: {e}")
return None
@cache
def _connect(self, endpoint: str):
socket = zmq.Context().socket(zmq.PUSH)
socket.connect(endpoint)
return socket
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None): def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
for bootstrap_info in self.bootstrap_infos:
assert self.bootstrap_info is not None self.prefill_server_url = (
assert self.bootstrap_room is not None f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
)
for equal_sources in self.bootstrap_info:
remote_rank = list(equal_sources.keys())[
self.bootstrap_room % len(equal_sources)
]
self.prefill_server_url = f"{equal_sources[remote_rank]['rank_ip']}:{equal_sources[remote_rank]['rank_port']}"
logger.debug( logger.debug(
f"Fetched bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}, source: {remote_rank}, all: {list(equal_sources.keys())}" f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
) )
is_dummy = bootstrap_info["is_dummy"]
# TODO: just send "" for indices for dummy
if is_dummy:
# TODO: need to set success??
sock, lock = self._connect("tcp://" + self.prefill_server_url)
with lock:
sock.send_multipart(
[
GUARD,
str(self.bootstrap_room).encode("ascii"),
]
)
continue
# TODO: send_kv_args earlier
packed_kv_data_ptrs = b"".join( packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
) )
...@@ -593,30 +462,22 @@ class NixlKVReceiver(BaseKVReceiver): ...@@ -593,30 +462,22 @@ class NixlKVReceiver(BaseKVReceiver):
logger.debug( logger.debug(
f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room}" f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room}"
) )
self._connect("tcp://" + self.prefill_server_url).send_multipart( sock, lock = self._connect("tcp://" + self.prefill_server_url)
[ with lock:
GUARD, sock.send_multipart(
str(self.bootstrap_room).encode("ascii"),
get_local_ip_by_remote().encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"),
self.kv_mgr.agent.get_agent_metadata(),
packed_kv_data_ptrs,
kv_indices.tobytes(),
packed_aux_data_ptrs,
str(aux_index).encode("ascii"),
str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
]
)
for dummy_rank in equal_sources.keys():
if dummy_rank == remote_rank:
continue
dummy_info = equal_sources[dummy_rank]
dummy_url = f"{dummy_info['rank_ip']}:{dummy_info['rank_port']}"
self._connect("tcp://" + dummy_url).send_multipart(
[ [
GUARD, GUARD,
str(self.bootstrap_room).encode("ascii"), str(self.bootstrap_room).encode("ascii"),
get_local_ip_by_remote().encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"),
self.kv_mgr.agent.get_agent_metadata(),
self.kv_mgr.agent.name.encode("ascii"),
packed_kv_data_ptrs,
kv_indices.tobytes(),
packed_aux_data_ptrs,
str(aux_index).encode("ascii"),
str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
str(self.required_dst_info_num).encode("ascii"),
] ]
) )
...@@ -632,152 +493,12 @@ class NixlKVReceiver(BaseKVReceiver): ...@@ -632,152 +493,12 @@ class NixlKVReceiver(BaseKVReceiver):
return KVPoll.Success # type: ignore return KVPoll.Success # type: ignore
return KVPoll.WaitingForInput # type: ignore return KVPoll.WaitingForInput # type: ignore
def _register_kv_args(self):
pass
def failure_exception(self): def failure_exception(self):
raise Exception("Fake KVReceiver Exception") raise Exception("Fake KVReceiver Exception")
class NixlKVBootstrapServer(BaseKVBootstrapServer): class NixlKVBootstrapServer(CommonKVBootstrapServer):
def __init__(self, port: int): pass
logger.debug(f"NixlKVBootstrapServer started on port {port}")
self.port = port
self.app = web.Application()
self.store = dict()
self.lock = asyncio.Lock()
self._setup_routes()
self.prefill_port_table: Dict[int, Dict[str, Union[str, int]]] = {}
# Start bootstrap server
self.thread = threading.Thread(target=self._run_server, daemon=True)
self.run()
def run(self):
self.thread.start()
def _setup_routes(self):
self.app.router.add_route("*", "/metadata", self._handle_metadata)
self.app.router.add_route("*", "/route", self._handle_route)
async def _handle_metadata(self, request: web.Request):
key = request.query.get("key", "")
if request.method == "GET":
return await self._handle_metadata_get(key)
elif request.method == "PUT":
return await self._handle_metadata_put(key, request)
elif request.method == "DELETE":
return await self._handle_metadata_delete(key)
return web.Response(
text="Method not allowed", status=405, content_type="application/json"
)
async def _handle_metadata_get(self, key):
async with self.lock:
value = self.store.get(key)
if value is None:
return web.Response(
text="metadata not found", status=404, content_type="application/json"
)
return web.Response(body=value, status=200, content_type="application/json")
async def _handle_metadata_put(self, key, request):
data = await request.read()
async with self.lock:
self.store[key] = data
return web.Response(
text="metadata updated", status=200, content_type="application/json"
)
async def _handle_metadata_delete(self, key):
async with self.lock:
if key not in self.store:
return web.Response(
text="metadata not found",
status=404,
content_type="application/json",
)
del self.store[key]
return web.Response(
text="metadata deleted", status=200, content_type="application/json"
)
async def _handle_route(self, request: web.Request):
method = request.method
if method == "PUT":
return await self._handle_route_put(request)
elif method == "GET":
return await self._handle_route_get(request)
else:
return web.Response(
text="Method not allowed", status=405, content_type="application/json"
)
async def _handle_route_put(self, request: web.Request):
data = await request.json()
role = data["role"]
rank_ip = data["rank_ip"]
rank_port = int(data["rank_port"])
engine_rank = int(data["engine_rank"])
agent_name = data["agent_name"]
if role == "Prefill":
async with self.lock:
self.prefill_port_table[engine_rank] = {
"rank_ip": rank_ip,
"rank_port": rank_port,
"agent_name": agent_name,
}
logger.info(
f"Registered Prefill boostrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port} and name: {agent_name}"
)
return web.Response(text="OK", status=200)
async def _handle_route_get(self, request: web.Request):
engine_rank = request.query.get("engine_rank")
if not engine_rank:
logger.debug(
f"No engine_rank specified, return all {len(self.prefill_port_table)} engine infos as a dict"
)
# Return a dict of all engine_rank
async with self.lock:
bootstrap_info = self.prefill_port_table
return web.json_response(bootstrap_info, status=200)
# Find corresponding prefill info
async with self.lock:
bootstrap_info = self.prefill_port_table.get(int(engine_rank))
if bootstrap_info is not None:
return web.json_response(bootstrap_info, status=200)
else:
return web.Response(text="Not Found", status=404)
def _run_server(self):
try:
# Event Loop
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._runner = web.AppRunner(self.app)
self._loop.run_until_complete(self._runner.setup())
site = web.TCPSite(self._runner, port=self.port)
self._loop.run_until_complete(site.start())
self._loop.run_forever()
except Exception as e:
logger.error(f"Server error: {str(e)}")
finally:
# Cleanup
self._loop.run_until_complete(self._runner.cleanup())
self._loop.close()
def close(self):
"""Shutdown"""
if self._loop is not None and self._loop.is_running():
self._loop.call_soon_threadsafe(self._loop.stop)
logger.info("Stopping server loop...")
if self.thread.is_alive():
self.thread.join(timeout=2)
logger.info("Server thread stopped")
def poll(self) -> KVPoll: ...
...@@ -13,7 +13,7 @@ import requests ...@@ -13,7 +13,7 @@ import requests
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from sglang.srt.utils import get_ip from sglang.srt.utils import get_ip, get_local_ip_by_remote
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req from sglang.srt.managers.schedule_batch import Req
...@@ -279,3 +279,20 @@ class MetadataBuffers: ...@@ -279,3 +279,20 @@ class MetadataBuffers:
] = torch.tensor( ] = torch.tensor(
req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu" req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
) )
def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
"""Vectorised NumPy implementation."""
if src_indices.size == 0:
return [], []
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
src_groups = np.split(src_indices, brk)
dst_groups = np.split(dst_indices, brk)
src_groups = [g.tolist() for g in src_groups]
dst_groups = [g.tolist() for g in dst_groups]
return src_groups, dst_groups
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment