"docs/source/api/vscode:/vscode.git/clone" did not exist on "8dc6784b1249dad569042b36551ecf2fce2b821d"
Unverified Commit e806f708 authored by Trevor Morris's avatar Trevor Morris Committed by GitHub
Browse files

[PD] Make bootstrap code common between NIXL and Mooncake (#6473)

parent fa6723f0
......@@ -47,3 +47,44 @@ $ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --dis
# decode 1
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --disaggregation-ib-device ${device_name} --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
```
## NIXL
### Requirements
Install via pip.
```bash
pip install nixl
```
Or build from source - may be required if you already have UCX installed.
```bash
git clone https://github.com/ai-dynamo/nixl.git
cd nixl
pip install . --config-settings=setup-args="-Ducx_path=/path/to/ucx"
```
### Usage
### Llama Single Node
```bash
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode prefill --disaggregation-transfer-backend nixl
$ python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disaggregation-mode decode --port 30001 --base-gpu-id 1 --disaggregation-transfer-backend nixl
$ python -m sglang.srt.disaggregation.mini_lb --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --host 0.0.0.0 --port 8000
```
### DeepSeek Multi-Node
```bash
# prefill 0
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode normal --mem-fraction-static 0.8
# prefill 1
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode prefill --host ${local_ip} --port 30000 --trust-remote-code --dist-init-addr ${prefill_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode normal --mem-fraction-static 0.8
# decode 0
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 0 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
# decode 1
$ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 ---disaggregation-transfer-backend nixl --disaggregation-mode decode --host ${local_ip} --port 30001 --trust-remote-code --dist-init-addr ${decode_master_ip}:5000 --nnodes 2 --node-rank 1 --tp-size 16 --dp-size 8 --enable-dp-attention --enable-deepep-moe --deepep-mode low_latency --mem-fraction-static 0.8 --max-running-requests 128
```
from .conn import CommonKVBootstrapServer, CommonKVManager, CommonKVReceiver
from __future__ import annotations
import asyncio
import logging
import socket
import threading
from functools import cache
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import numpy.typing as npt
import requests
import zmq
from aiohttp import web
from sglang.srt.disaggregation.base.conn import (
BaseKVBootstrapServer,
BaseKVManager,
BaseKVReceiver,
BaseKVSender,
KVArgs,
KVPoll,
)
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
logger = logging.getLogger(__name__)
class CommonKVManager(BaseKVManager):
def __init__(
self,
args: KVArgs,
disaggregation_mode: DisaggregationMode,
server_args: ServerArgs,
is_mla_backend: Optional[bool] = False,
):
self.kv_args = args
self.is_mla_backend = is_mla_backend
self.disaggregation_mode = disaggregation_mode
# for p/d multi node infer
self.bootstrap_port = server_args.disaggregation_bootstrap_port
self.dist_init_addr = server_args.dist_init_addr
self.tp_size = server_args.tp_size
self.dp_size = server_args.dp_size
self.enable_dp_attention = server_args.enable_dp_attention
if not server_args.enable_dp_attention and server_args.dp_size != 1:
raise ValueError(
"If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
)
self.rank_port = get_free_port()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self._register_to_bootstrap()
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
self.prefill_tp_size_table: Dict[str, int] = {}
self.prefill_dp_size_table: Dict[str, int] = {}
else:
raise ValueError(
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
)
def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST."""
if self.dist_init_addr:
ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
else:
ip_address = get_ip()
bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
url = f"http://{bootstrap_server_url}/route"
payload = {
"role": "Prefill",
"tp_size": self.tp_size,
"dp_size": self.dp_size,
"rank_ip": get_local_ip_by_remote(),
"rank_port": self.rank_port,
"engine_rank": self.kv_args.engine_rank,
}
try:
response = requests.put(url, json=payload)
if response.status_code == 200:
logger.debug("Prefill successfully registered to bootstrap server.")
else:
logger.error(
f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
)
except Exception as e:
logger.error(f"Prefill Failed to register to bootstrap server: {e}")
@cache
def _connect(self, endpoint: str):
socket = zmq.Context().socket(zmq.PUSH)
socket.connect(endpoint)
return socket
class CommonKVReceiver(BaseKVReceiver):
_ctx = zmq.Context()
_socket_cache = {}
_socket_locks = {}
_global_lock = threading.Lock()
def __init__(
self,
mgr: BaseKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
):
self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr
self.kv_mgr = mgr
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
self.prefill_tp_size, self.prefill_dp_size = (
self._get_prefill_dp_size_from_server()
)
if self.prefill_tp_size is None or self.prefill_dp_size is None:
logger.error(
f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
)
else:
self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
self.prefill_tp_size
)
self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
self.prefill_dp_size
)
else:
self.prefill_tp_size = self.kv_mgr.prefill_tp_size_table[
self.bootstrap_addr
]
self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
self.bootstrap_addr
]
# Currently, we don't allow prefill instance and decode instance to
# have different TP sizes per DP rank, except for models using MLA.
local_tp_size_per_dp_rank = self.kv_mgr.tp_size // self.kv_mgr.dp_size
prefill_tp_size_per_dp_rank = self.prefill_tp_size // self.prefill_dp_size
if local_tp_size_per_dp_rank == prefill_tp_size_per_dp_rank:
self.target_tp_rank = (
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
)
self.required_dst_info_num = 1
self.target_tp_ranks = [self.target_tp_rank]
elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
assert (
self.kv_mgr.is_mla_backend
), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
self.target_tp_rank = (
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
self.required_dst_info_num = (
local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
)
self.target_tp_ranks = [self.target_tp_rank]
else:
assert (
self.kv_mgr.is_mla_backend
), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
# For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks for non MLA models;
self.target_tp_ranks = [
rank
for rank in range(
(self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank)
* (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
(self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank + 1)
* (prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank),
)
]
# For MLA models, we can retrieve KVCache from only one prefill rank, but we still need to maintain
# multiple connections in the connection pool and have to send dummy requests to other prefill ranks,
# or the KVPoll will never be set correctly
self.target_tp_rank = self.target_tp_ranks[0]
self.required_dst_info_num = 1
self.target_dp_group = bootstrap_room % self.prefill_dp_size
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
bootstrap_key = (
f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
)
if bootstrap_key not in self.kv_mgr.connection_pool:
bootstrap_infos = []
for target_tp_rank in self.target_tp_ranks:
bootstrap_info = self._get_bootstrap_info_from_server(
target_tp_rank,
self.target_dp_group,
)
if bootstrap_info is not None:
# NOTE: only support MLA for now: select one prefill rank as real rank
bootstrap_info["is_dummy"] = not bool(
target_tp_rank == self.target_tp_rank
or self.target_tp_rank is None
)
bootstrap_infos.append(bootstrap_info)
else:
logger.error(
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
)
self.bootstrap_infos = bootstrap_infos
if len(self.bootstrap_infos) == 0:
logger.error(
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
)
else:
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
self._register_kv_args()
else:
self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
assert len(self.bootstrap_infos) > 0
def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
"""Fetch the bootstrap info from the bootstrap server."""
try:
url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
response = requests.get(url)
if response.status_code == 200:
bootstrap_info = response.json()
return bootstrap_info
else:
logger.error(
f"Failed to get prefill server info: {response.status_code}, {response.text}"
)
return None
except Exception as e:
logger.error(f"Error fetching prefill info from bootstrap: {e}")
return None
def _get_prefill_dp_size_from_server(self) -> int:
"""Fetch the prefill parallel info from the bootstrap server."""
try:
url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
response = requests.get(url)
if response.status_code == 200:
prefill_parallel_info = response.json()
return int(prefill_parallel_info["prefill_tp_size"]), int(
prefill_parallel_info["prefill_dp_size"]
)
else:
logger.error(
f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
)
return None
except Exception as e:
logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
return None
@classmethod
def _connect(cls, endpoint: str):
with cls._global_lock:
if endpoint not in cls._socket_cache:
sock = cls._ctx.socket(zmq.PUSH)
sock.connect(endpoint)
cls._socket_cache[endpoint] = sock
cls._socket_locks[endpoint] = threading.Lock()
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
def _register_kv_args(self):
pass
def failure_exception(self):
raise Exception("Fake KVReceiver Exception")
class CommonKVBootstrapServer(BaseKVBootstrapServer):
def __init__(self, port: int):
self.port = port
self.app = web.Application()
self.store = dict()
self.lock = asyncio.Lock()
self._setup_routes()
self.tp_size = None
self.dp_size = None
self.tp_size_per_dp_rank = None
self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
# Start bootstrap server
self.thread = threading.Thread(target=self._run_server, daemon=True)
self.run()
def run(self):
self.thread.start()
def _setup_routes(self):
self.app.router.add_route("*", "/route", self._handle_route)
async def _handle_route(self, request: web.Request):
method = request.method
if method == "PUT":
return await self._handle_route_put(request)
elif method == "GET":
return await self._handle_route_get(request)
else:
return web.Response(
text="Method not allowed", status=405, content_type="application/json"
)
async def _handle_route_put(self, request: web.Request):
data = await request.json()
role = data["role"]
tp_size = data["tp_size"]
dp_size = data["dp_size"]
rank_ip = data["rank_ip"]
rank_port = int(data["rank_port"])
engine_rank = int(data["engine_rank"])
if self.tp_size is None:
self.tp_size = tp_size
if self.dp_size is None:
self.dp_size = dp_size
tp_size_per_dp_rank = tp_size // dp_size
if self.tp_size_per_dp_rank == None:
self.tp_size_per_dp_rank = tp_size_per_dp_rank
# Add lock to make sure thread-safe
if role == "Prefill":
dp_group = engine_rank // tp_size_per_dp_rank
tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank
async with self.lock:
if dp_group not in self.prefill_port_table:
self.prefill_port_table[dp_group] = {}
self.prefill_port_table[dp_group][tp_rank_in_dp_group] = {
"rank_ip": rank_ip,
"rank_port": rank_port,
}
logger.debug(
f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
)
return web.Response(text="OK", status=200)
async def _handle_route_get(self, request: web.Request):
engine_rank = request.query.get("engine_rank")
target_dp_group = request.query.get("target_dp_group")
if not engine_rank or not target_dp_group:
return web.Response(text="Missing inputs for bootstrap server.", status=400)
# Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
if int(engine_rank) == -1 and int(target_dp_group) == -1:
prefill_parallel_info = {
"prefill_tp_size": self.tp_size,
"prefill_dp_size": self.dp_size,
}
return web.json_response(prefill_parallel_info, status=200)
# Find corresponding prefill info
async with self.lock:
bootstrap_info = self.prefill_port_table[int(target_dp_group)][
int(engine_rank)
]
if bootstrap_info is not None:
return web.json_response(bootstrap_info, status=200)
else:
return web.Response(text="Bootstrap info not Found", status=404)
def _run_server(self):
try:
# Event Loop
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._runner = web.AppRunner(self.app)
self._loop.run_until_complete(self._runner.setup())
site = web.TCPSite(self._runner, port=self.port)
self._loop.run_until_complete(site.start())
self._loop.run_forever()
except Exception as e:
logger.error(f"Server error: {str(e)}")
finally:
# Cleanup
self._loop.run_until_complete(self._runner.cleanup())
self._loop.close()
def close(self):
"""Shutdown"""
if self._loop is not None and self._loop.is_running():
self._loop.call_soon_threadsafe(self._loop.stop)
logger.info("Stopping server loop...")
if self.thread.is_alive():
self.thread.join(timeout=2)
logger.info("Server thread stopped")
def poll(self) -> KVPoll: ...
......@@ -29,7 +29,10 @@ from sglang.srt.disaggregation.base.conn import (
KVPoll,
)
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
group_concurrent_contiguous,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
get_free_port,
......@@ -41,23 +44,6 @@ from sglang.srt.utils import (
logger = logging.getLogger(__name__)
def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
"""Vectorised NumPy implementation."""
if src_indices.size == 0:
return [], []
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
src_groups = np.split(src_indices, brk)
dst_groups = np.split(dst_indices, brk)
src_groups = [g.tolist() for g in src_groups]
dst_groups = [g.tolist() for g in dst_groups]
return src_groups, dst_groups
class KVTransferError(Exception):
def __init__(self, bootstrap_room: int, failure_reason: str):
super().__init__(failure_reason)
......
......@@ -13,7 +13,7 @@ import requests
import torch
import torch.distributed as dist
from sglang.srt.utils import get_ip
from sglang.srt.utils import get_ip, get_local_ip_by_remote
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
......@@ -279,3 +279,20 @@ class MetadataBuffers:
] = torch.tensor(
req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
)
def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
"""Vectorised NumPy implementation."""
if src_indices.size == 0:
return [], []
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
src_groups = np.split(src_indices, brk)
dst_groups = np.split(dst_indices, brk)
src_groups = [g.tolist() for g in src_groups]
dst_groups = [g.tolist() for g in dst_groups]
return src_groups, dst_groups
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment