".github/vscode:/vscode.git/clone" did not exist on "1f766c36fb61f7b1969664645bf38dae93f568a2"
Unverified Commit a3e4e9bf authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Better PD initialization (#5751)

parent 6d4d3bc8
...@@ -3,10 +3,12 @@ Minimal HTTP load balancer for prefill and decode servers for testing. ...@@ -3,10 +3,12 @@ Minimal HTTP load balancer for prefill and decode servers for testing.
""" """
import asyncio import asyncio
import dataclasses
import logging
import random import random
import urllib import urllib
from itertools import chain from itertools import chain
from typing import List from typing import List, Optional
import aiohttp import aiohttp
import orjson import orjson
...@@ -14,11 +16,32 @@ import uvicorn ...@@ -14,11 +16,32 @@ import uvicorn
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from fastapi.responses import ORJSONResponse, Response, StreamingResponse from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from sglang.srt.disaggregation.utils import PDRegistryRequest
def setup_logger():
logger = logging.getLogger("pdlb")
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
"[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
logger = setup_logger()
@dataclasses.dataclass
class PrefillConfig: class PrefillConfig:
def __init__(self, url: str, bootstrap_port: int): url: str
self.url = url bootstrap_port: Optional[int] = None
self.bootstrap_port = bootstrap_port
class MiniLoadBalancer: class MiniLoadBalancer:
...@@ -28,6 +51,10 @@ class MiniLoadBalancer: ...@@ -28,6 +51,10 @@ class MiniLoadBalancer:
self.decode_servers = decode_servers self.decode_servers = decode_servers
def select_pair(self): def select_pair(self):
# TODO: return some message instead of panic
assert len(self.prefill_configs) > 0, "No prefill servers available"
assert len(self.decode_servers) > 0, "No decode servers available"
prefill_config = random.choice(self.prefill_configs) prefill_config = random.choice(self.prefill_configs)
decode_server = random.choice(self.decode_servers) decode_server = random.choice(self.decode_servers)
return prefill_config.url, prefill_config.bootstrap_port, decode_server return prefill_config.url, prefill_config.bootstrap_port, decode_server
...@@ -47,7 +74,7 @@ class MiniLoadBalancer: ...@@ -47,7 +74,7 @@ class MiniLoadBalancer:
session.post(f"{decode_server}/{endpoint}", json=modified_request), session.post(f"{decode_server}/{endpoint}", json=modified_request),
] ]
# Wait for both responses to complete. Prefill should end first. # Wait for both responses to complete. Prefill should end first.
prefill_response, decode_response = await asyncio.gather(*tasks) _, decode_response = await asyncio.gather(*tasks)
return ORJSONResponse( return ORJSONResponse(
content=await decode_response.json(), content=await decode_response.json(),
...@@ -268,6 +295,32 @@ async def get_models(): ...@@ -268,6 +295,32 @@ async def get_models():
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.post("/register")
async def register(obj: PDRegistryRequest):
if obj.mode == "prefill":
load_balancer.prefill_configs.append(
PrefillConfig(obj.registry_url, obj.bootstrap_port)
)
logger.info(
f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
)
elif obj.mode == "decode":
load_balancer.decode_servers.append(obj.registry_url)
logger.info(f"Registered decode server: {obj.registry_url}")
else:
raise HTTPException(
status_code=400,
detail="Invalid mode. Must be either PREFILL or DECODE.",
)
logger.info(
f"#Prefill servers: {len(load_balancer.prefill_configs)}, "
f"#Decode servers: {len(load_balancer.decode_servers)}"
)
return Response(status_code=200)
def run(prefill_configs, decode_addrs, host, port): def run(prefill_configs, decode_addrs, host, port):
global load_balancer global load_balancer
load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs) load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
...@@ -279,15 +332,16 @@ if __name__ == "__main__": ...@@ -279,15 +332,16 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Mini Load Balancer Server") parser = argparse.ArgumentParser(description="Mini Load Balancer Server")
parser.add_argument( parser.add_argument(
"--prefill", required=True, help="Comma-separated URLs for prefill servers" "--prefill", type=str, default=[], nargs="+", help="URLs for prefill servers"
) )
parser.add_argument( parser.add_argument(
"--prefill-bootstrap-ports", "--decode", type=str, default=[], nargs="+", help="URLs for decode servers"
help="Comma-separated bootstrap ports for prefill servers",
default="8998",
) )
parser.add_argument( parser.add_argument(
"--decode", required=True, help="Comma-separated URLs for decode servers" "--prefill-bootstrap-ports",
type=int,
nargs="+",
help="Bootstrap ports for prefill servers",
) )
parser.add_argument( parser.add_argument(
"--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)" "--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)"
...@@ -297,22 +351,19 @@ if __name__ == "__main__": ...@@ -297,22 +351,19 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
prefill_urls = args.prefill.split(",") bootstrap_ports = args.prefill_bootstrap_ports
bootstrap_ports = [int(p) for p in args.prefill_bootstrap_ports.split(",")] if bootstrap_ports is None:
bootstrap_ports = [None] * len(args.prefill)
if len(bootstrap_ports) == 1: elif len(bootstrap_ports) == 1:
bootstrap_ports = bootstrap_ports * len(prefill_urls) bootstrap_ports = bootstrap_ports * len(args.prefill)
else: else:
if len(bootstrap_ports) != len(prefill_urls): if len(bootstrap_ports) != len(args.prefill):
raise ValueError( raise ValueError(
"Number of prefill URLs must match number of bootstrap ports" "Number of prefill URLs must match number of bootstrap ports"
) )
exit(1)
prefill_configs = []
for url, port in zip(prefill_urls, bootstrap_ports):
prefill_configs.append(PrefillConfig(url, port))
decode_addrs = args.decode.split(",") prefill_configs = [
PrefillConfig(url, port) for url, port in zip(args.prefill, bootstrap_ports)
]
run(prefill_configs, decode_addrs, args.host, args.port) run(prefill_configs, args.decode, args.host, args.port)
from __future__ import annotations from __future__ import annotations
import dataclasses
import warnings
from collections import deque from collections import deque
from enum import Enum from enum import Enum
from typing import List from typing import List, Optional
import numpy as np import numpy as np
import requests
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from sglang.srt.utils import get_ip
class DisaggregationMode(Enum): class DisaggregationMode(Enum):
NULL = "null" NULL = "null"
...@@ -119,3 +124,41 @@ def kv_to_page_indices(kv_indices: np.ndarray, page_size: int): ...@@ -119,3 +124,41 @@ def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
def kv_to_page_num(num_kv_indices: int, page_size: int): def kv_to_page_num(num_kv_indices: int, page_size: int):
# ceil(num_kv_indices / page_size) # ceil(num_kv_indices / page_size)
return (num_kv_indices + page_size - 1) // page_size return (num_kv_indices + page_size - 1) // page_size
@dataclasses.dataclass
class PDRegistryRequest:
"""A request to register a machine itself to the LB."""
mode: str
registry_url: str
bootstrap_port: Optional[int] = None
def __post_init__(self):
if self.mode == "prefill" and self.bootstrap_port is None:
raise ValueError("Bootstrap port must be set in PREFILL mode.")
elif self.mode == "decode" and self.bootstrap_port is not None:
raise ValueError("Bootstrap port must not be set in DECODE mode.")
elif self.mode not in ["prefill", "decode"]:
raise ValueError(
f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
)
def register_disaggregation_server(
mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
):
boostrap_port = bootstrap_port if mode == "prefill" else None
registry_request = PDRegistryRequest(
mode=mode,
registry_url=f"http://{get_ip()}:{server_port}",
bootstrap_port=boostrap_port,
)
res = requests.post(
f"{pdlb_url}/register",
json=dataclasses.asdict(registry_request),
)
if res.status_code != 200:
warnings.warn(
f"Failed to register disaggregation server: {res.status_code} {res.text}"
)
...@@ -42,7 +42,10 @@ from fastapi import FastAPI, File, Form, Request, UploadFile ...@@ -42,7 +42,10 @@ from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from sglang.srt.disaggregation.utils import FakeBootstrapHost from sglang.srt.disaggregation.utils import (
FakeBootstrapHost,
register_disaggregation_server,
)
from sglang.srt.entrypoints.engine import _launch_subprocesses from sglang.srt.entrypoints.engine import _launch_subprocesses
from sglang.srt.function_call_parser import FunctionCallParser from sglang.srt.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
...@@ -871,5 +874,13 @@ def _wait_and_warmup( ...@@ -871,5 +874,13 @@ def _wait_and_warmup(
if server_args.debug_tensor_dump_input_file: if server_args.debug_tensor_dump_input_file:
kill_process_tree(os.getpid()) kill_process_tree(os.getpid())
if server_args.pdlb_url is not None:
register_disaggregation_server(
server_args.disaggregation_mode,
server_args.port,
server_args.disaggregation_bootstrap_port,
server_args.pdlb_url,
)
if launch_callback is not None: if launch_callback is not None:
launch_callback() launch_callback()
...@@ -925,6 +925,10 @@ class Scheduler( ...@@ -925,6 +925,10 @@ class Scheduler(
) )
custom_logit_processor = None custom_logit_processor = None
if recv_req.bootstrap_port is None:
# Use default bootstrap port
recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port
req = Req( req = Req(
recv_req.rid, recv_req.rid,
recv_req.input_text, recv_req.input_text,
......
...@@ -198,6 +198,7 @@ class ServerArgs: ...@@ -198,6 +198,7 @@ class ServerArgs:
disaggregation_bootstrap_port: int = 8998 disaggregation_bootstrap_port: int = 8998
disaggregation_transfer_backend: str = "mooncake" disaggregation_transfer_backend: str = "mooncake"
disaggregation_ib_device: Optional[str] = None disaggregation_ib_device: Optional[str] = None
pdlb_url: Optional[str] = None
def __post_init__(self): def __post_init__(self):
# Expert parallelism # Expert parallelism
...@@ -1254,6 +1255,12 @@ class ServerArgs: ...@@ -1254,6 +1255,12 @@ class ServerArgs:
"or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). " "or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
"Default is None, which triggers automatic device detection when mooncake backend is enabled.", "Default is None, which triggers automatic device detection when mooncake backend is enabled.",
) )
parser.add_argument(
"--pdlb-url",
type=str,
default=None,
help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment