Unverified Commit 78689d33 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

PD Rust LB (PO2) (#6437)


Co-authored-by: default avatarfzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
parent 1dc6864f
import argparse
import dataclasses
@dataclasses.dataclass
class LBArgs:
rust_lb: bool = False
host: str = "0.0.0.0"
port: int = 8000
policy: str = "random"
prefill_infos: list = dataclasses.field(default_factory=list)
decode_infos: list = dataclasses.field(default_factory=list)
log_interval: int = 5
timeout: int = 600
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--rust-lb",
action="store_true",
help="Use Rust load balancer",
)
parser.add_argument(
"--host",
type=str,
default=LBArgs.host,
help=f"Host to bind the server (default: {LBArgs.host})",
)
parser.add_argument(
"--port",
type=int,
default=LBArgs.port,
help=f"Port to bind the server (default: {LBArgs.port})",
)
parser.add_argument(
"--policy",
type=str,
default=LBArgs.policy,
choices=["random", "po2"],
help=f"Policy to use for load balancing (default: {LBArgs.policy})",
)
parser.add_argument(
"--prefill",
type=str,
default=[],
nargs="+",
help="URLs for prefill servers",
)
parser.add_argument(
"--decode",
type=str,
default=[],
nargs="+",
help="URLs for decode servers",
)
parser.add_argument(
"--prefill-bootstrap-ports",
type=int,
nargs="+",
help="Bootstrap ports for prefill servers",
)
parser.add_argument(
"--log-interval",
type=int,
default=LBArgs.log_interval,
help=f"Log interval in seconds (default: {LBArgs.log_interval})",
)
parser.add_argument(
"--timeout",
type=int,
default=LBArgs.timeout,
help=f"Timeout in seconds (default: {LBArgs.timeout})",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> "LBArgs":
bootstrap_ports = args.prefill_bootstrap_ports
if bootstrap_ports is None:
bootstrap_ports = [None] * len(args.prefill)
elif len(bootstrap_ports) == 1:
bootstrap_ports = bootstrap_ports * len(args.prefill)
else:
if len(bootstrap_ports) != len(args.prefill):
raise ValueError(
"Number of prefill URLs must match number of bootstrap ports"
)
prefill_infos = [
(url, port) for url, port in zip(args.prefill, bootstrap_ports)
]
return cls(
rust_lb=args.rust_lb,
host=args.host,
port=args.port,
policy=args.policy,
prefill_infos=prefill_infos,
decode_infos=args.decode,
log_interval=args.log_interval,
timeout=args.timeout,
)
def __post_init__(self):
if not self.rust_lb:
assert (
self.policy == "random"
), "Only random policy is supported for Python load balancer"
def main():
parser = argparse.ArgumentParser(
description="PD Disaggregation Load Balancer Server"
)
LBArgs.add_cli_args(parser)
args = parser.parse_args()
lb_args = LBArgs.from_cli_args(args)
if lb_args.rust_lb:
from sgl_pdlb._rust import LoadBalancer as RustLB
RustLB(
host=lb_args.host,
port=lb_args.port,
policy=lb_args.policy,
prefill_infos=lb_args.prefill_infos,
decode_infos=lb_args.decode_infos,
log_interval=lb_args.log_interval,
timeout=lb_args.timeout,
).start()
else:
from sglang.srt.disaggregation.mini_lb import PrefillConfig, run
prefill_configs = [
PrefillConfig(url, port) for url, port in lb_args.prefill_infos
]
run(prefill_configs, lb_args.decode_infos, lb_args.host, lb_args.port)
if __name__ == "__main__":
main()
......@@ -377,42 +377,7 @@ def run(prefill_configs, decode_addrs, host, port):
if __name__ == "__main__":
import argparse
# FIXME: remove this, use the unified entry point: sglang.srt.disaggregation.launch_lb
from sglang.srt.disaggregation.launch_lb import main
parser = argparse.ArgumentParser(description="Mini Load Balancer Server")
parser.add_argument(
"--prefill", type=str, default=[], nargs="+", help="URLs for prefill servers"
)
parser.add_argument(
"--decode", type=str, default=[], nargs="+", help="URLs for decode servers"
)
parser.add_argument(
"--prefill-bootstrap-ports",
type=int,
nargs="+",
help="Bootstrap ports for prefill servers",
)
parser.add_argument(
"--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)"
)
parser.add_argument(
"--port", type=int, default=8000, help="Port to bind the server (default: 8000)"
)
args = parser.parse_args()
bootstrap_ports = args.prefill_bootstrap_ports
if bootstrap_ports is None:
bootstrap_ports = [None] * len(args.prefill)
elif len(bootstrap_ports) == 1:
bootstrap_ports = bootstrap_ports * len(args.prefill)
else:
if len(bootstrap_ports) != len(args.prefill):
raise ValueError(
"Number of prefill URLs must match number of bootstrap ports"
)
prefill_configs = [
PrefillConfig(url, port) for url, port in zip(args.prefill, bootstrap_ports)
]
run(prefill_configs, args.decode, args.host, args.port)
main()
......@@ -229,6 +229,11 @@ async def get_server_info():
}
@app.get("/get_load")
async def get_load():
return await _global_state.tokenizer_manager.get_load()
@app.api_route("/set_internal_state", methods=["POST", "PUT"])
async def set_internal_state(obj: SetInternalStateReq, request: Request):
res = await _global_state.tokenizer_manager.set_internal_state(obj)
......
......@@ -103,7 +103,7 @@ class GenerateReqInput:
# For disaggregated inference
bootstrap_host: Optional[Union[List[str], str]] = None
bootstrap_port: Optional[Union[List[int], int]] = None
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
bootstrap_room: Optional[Union[List[int], int]] = None
def contains_mm_input(self) -> bool:
......
......@@ -1911,6 +1911,27 @@ class Scheduler(
if_success = False
return if_success
def get_load(self):
# TODO(lsyin): use dynamically maintained num_waiting_tokens
load = (
self.max_total_num_tokens
- self.token_to_kv_pool_allocator.available_size()
- self.tree_cache.evictable_size()
)
load += sum(len(req.origin_input_ids) for req in self.waiting_queue)
if self.disaggregation_mode == DisaggregationMode.PREFILL:
load += sum(
len(req.origin_input_ids)
for req in self.disagg_prefill_bootstrap_queue.queue
)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
load += sum(
len(req.req.origin_input_ids)
for req in self.disagg_decode_prealloc_queue.queue
)
return load
def get_internal_state(self, recv_req: GetInternalStateReq):
ret = dict(global_server_args_dict)
ret["last_gen_throughput"] = self.last_gen_throughput
......@@ -1920,9 +1941,10 @@ class Scheduler(
)
if RECORD_STEP_TIME:
ret["step_time_dict"] = self.step_time_dict
return GetInternalStateReqOutput(
internal_state=ret,
)
ret["load"] = self.get_load()
return GetInternalStateReqOutput(internal_state=ret)
def set_internal_state(self, recv_req: SetInternalStateReq):
server_args_dict = recv_req.server_args
......
......@@ -395,6 +395,9 @@ class TokenizerManager:
self.server_args.disaggregation_bootstrap_port
)
self.current_load = 0
self.current_load_lock = asyncio.Lock()
async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
......@@ -983,6 +986,14 @@ class TokenizerManager:
# Many DP ranks
return [res.internal_state for res in responses]
async def get_load(self) -> dict:
# TODO(lsyin): fake load report server
if not self.current_load_lock.locked():
async with self.current_load_lock:
internal_state = await self.get_internal_state()
self.current_load = internal_state[0]["load"]
return {"load": self.current_load}
async def set_internal_state(
self, obj: SetInternalStateReq
) -> SetInternalStateReqOutput:
......
reorder_imports = true
reorder_modules = true
[package]
edition = "2024"
name = "sgl-pdlb"
version = "0.1.0"
[lib]
crate-type = ["cdylib", "rlib"]
name = "sgl_pdlb_rs"
[dependencies]
actix-web = "4.11"
bytes = "1.8.0"
chrono = "0.4.38"
clap = { version = "4.4", features = ["derive"] }
dashmap = "6.1.0"
env_logger = "0.11.5"
futures = "0.3"
futures-util = "0.3"
http = "1.3.1"
log = "0.4.22"
pyo3 = { version = "0.25.0", features = ["extension-module"] }
rand = "0.9.0"
reqwest = { version = "0.12.8", features = ["stream", "blocking", "json"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1.34", features = ["full"] }
anyhow = "1.0.98"
typetag = "0.2.20"
### Install dependencies
```bash
pip install "maturin[patchelf]"
```
### Build and install
```bash
maturin develop
pip install -e .
```
[build-system]
requires = ["maturin>=1.8.0"]
build-backend = "maturin"
[project]
name = "sgl_pdlb"
version = "0.0.1"
[tool.maturin]
python-source = "py_src"
module-name = "sgl_pdlb._rust"
[tool.maturin.build-backend]
features = ["pyo3/extension-module"]
use crate::strategy_lb::EngineInfo;
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum SingleOrBatch<T> {
Single(T),
Batch(Vec<T>),
}
pub type InputIds = SingleOrBatch<Vec<i32>>;
pub type InputText = SingleOrBatch<String>;
pub type BootstrapHost = SingleOrBatch<String>;
pub type BootstrapPort = SingleOrBatch<Option<u16>>;
pub type BootstrapRoom = SingleOrBatch<u64>;
#[typetag::serde(tag = "type")]
pub trait Bootstrap {
fn is_stream(&self) -> bool;
fn get_batch_size(&self) -> Result<Option<usize>, actix_web::Error>;
fn set_bootstrap_info(
&mut self,
bootstrap_host: BootstrapHost,
bootstrap_port: BootstrapPort,
bootstrap_room: BootstrapRoom,
);
fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), actix_web::Error> {
let batch_size = self.get_batch_size()?;
if let Some(batch_size) = batch_size {
self.set_bootstrap_info(
BootstrapHost::Batch(vec![prefill_info.get_hostname(); batch_size]),
BootstrapPort::Batch(vec![prefill_info.bootstrap_port; batch_size]),
BootstrapRoom::Batch((0..batch_size).map(|_| rand::random::<u64>()).collect()),
);
} else {
self.set_bootstrap_info(
BootstrapHost::Single(prefill_info.get_hostname()),
BootstrapPort::Single(prefill_info.bootstrap_port),
BootstrapRoom::Single(rand::random::<u64>()),
);
}
Ok(())
}
}
#[derive(Debug, Deserialize, Serialize)]
pub struct GenerateReqInput {
pub text: Option<InputText>,
pub input_ids: Option<InputIds>,
#[serde(default)]
pub stream: bool,
pub bootstrap_host: Option<BootstrapHost>,
pub bootstrap_port: Option<BootstrapPort>,
pub bootstrap_room: Option<BootstrapRoom>,
#[serde(flatten)]
pub other: Value,
}
impl GenerateReqInput {
pub fn get_batch_size(&self) -> Result<Option<usize>, actix_web::Error> {
if self.text.is_some() && self.input_ids.is_some() {
return Err(actix_web::error::ErrorBadRequest(
"Both text and input_ids are present in the request".to_string(),
));
}
if let Some(InputText::Batch(texts)) = &self.text {
return Ok(Some(texts.len()));
}
if let Some(InputIds::Batch(ids)) = &self.input_ids {
return Ok(Some(ids.len()));
}
Ok(None)
}
}
#[typetag::serde]
impl Bootstrap for GenerateReqInput {
fn is_stream(&self) -> bool {
self.stream
}
fn get_batch_size(&self) -> Result<Option<usize>, actix_web::Error> {
self.get_batch_size()
}
fn set_bootstrap_info(
&mut self,
bootstrap_host: BootstrapHost,
bootstrap_port: BootstrapPort,
bootstrap_room: BootstrapRoom,
) {
self.bootstrap_host = Some(bootstrap_host);
self.bootstrap_port = Some(bootstrap_port);
self.bootstrap_room = Some(bootstrap_room);
}
}
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatReqInput {
#[serde(default)]
pub stream: bool,
pub bootstrap_host: Option<BootstrapHost>,
pub bootstrap_port: Option<BootstrapPort>,
pub bootstrap_room: Option<BootstrapRoom>,
#[serde(flatten)]
pub other: Value,
}
#[typetag::serde]
impl Bootstrap for ChatReqInput {
fn is_stream(&self) -> bool {
self.stream
}
fn get_batch_size(&self) -> Result<Option<usize>, actix_web::Error> {
Ok(None)
}
fn set_bootstrap_info(
&mut self,
bootstrap_host: BootstrapHost,
bootstrap_port: BootstrapPort,
bootstrap_room: BootstrapRoom,
) {
self.bootstrap_host = Some(bootstrap_host);
self.bootstrap_port = Some(bootstrap_port);
self.bootstrap_room = Some(bootstrap_room);
}
}
use crate::io_struct::Bootstrap;
use crate::strategy_lb::{EngineInfo, EngineLoad, EngineType, LBPolicy, StrategyLB};
use actix_web::HttpResponse;
use bytes::Bytes;
use futures::{Stream, StreamExt, future::join_all};
use reqwest::{Method, StatusCode};
use std::pin::Pin;
pub enum ProxyResponseBody {
Full(Bytes),
Stream(Pin<Box<dyn Stream<Item = Result<Bytes, actix_web::Error>> + Send>>),
}
pub struct ProxyResponse {
pub status: StatusCode,
pub body: ProxyResponseBody,
}
impl ProxyResponse {
pub fn to_json(&self) -> Result<serde_json::Value, actix_web::Error> {
match &self.body {
ProxyResponseBody::Full(body) => Ok(serde_json::from_slice(&body)?),
ProxyResponseBody::Stream(_) => Err(actix_web::error::ErrorBadRequest(
"Stream response is not supported",
)),
}
}
}
impl Into<Result<HttpResponse, actix_web::Error>> for ProxyResponse {
fn into(self) -> Result<HttpResponse, actix_web::Error> {
let status = actix_web::http::StatusCode::from_u16(self.status.as_u16()).map_err(|e| {
actix_web::error::ErrorBadGateway(format!("Invalid status code: {}", e))
})?;
match self.body {
ProxyResponseBody::Full(body) => Ok(HttpResponse::Ok().status(status).body(body)),
ProxyResponseBody::Stream(body) => Ok(HttpResponse::Ok()
.status(status)
.content_type("application/octet-stream")
.streaming(body)),
}
}
}
#[derive(Debug, Clone)]
pub struct LBConfig {
pub host: String,
pub port: u16,
pub policy: String,
pub prefill_infos: Vec<(String, Option<u16>)>,
pub decode_infos: Vec<String>,
pub log_interval: u64,
pub timeout: u64,
}
#[derive(Debug, Clone)]
pub struct LBState {
pub strategy_lb: StrategyLB,
pub client: reqwest::Client,
pub log_interval: u64,
}
impl LBState {
pub fn new(lb_config: LBConfig) -> anyhow::Result<Self> {
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(lb_config.timeout))
.build()?;
let policy = match lb_config.policy.as_str() {
"random" => LBPolicy::Random,
"po2" => LBPolicy::PowerOfTwo,
_ => anyhow::bail!("Invalid policy"),
};
let prefill_servers = lb_config
.prefill_infos
.into_iter()
.map(|(url, port)| EngineInfo::new_prefill(url, port))
.collect();
let decode_servers = lb_config
.decode_infos
.into_iter()
.map(|url| EngineInfo::new_decode(url))
.collect();
let lb = StrategyLB::new(policy, prefill_servers, decode_servers);
Ok(Self {
strategy_lb: lb,
client,
log_interval: lb_config.log_interval,
})
}
pub async fn route_one(
&self,
engine_info: &EngineInfo,
method: Method,
api_path: &str,
request: Option<&serde_json::Value>,
stream: bool,
) -> Result<ProxyResponse, actix_web::Error> {
let url = engine_info.api_path(api_path);
let request = request.unwrap_or(&serde_json::Value::Null);
let task = self.client.request(method, url).json(request).send();
let resp = task.await.map_err(actix_web::error::ErrorBadGateway)?;
// FIXME: handle error status code (map status code to error)
let status = resp.status();
let body = if stream {
let resp_stream = resp.bytes_stream().map(|r| {
r.map_err(actix_web::error::ErrorBadGateway)
.map(Bytes::from)
});
ProxyResponseBody::Stream(Box::pin(resp_stream))
} else {
let body = resp
.bytes()
.await
.map_err(actix_web::error::ErrorBadGateway)?;
ProxyResponseBody::Full(body)
};
Ok(ProxyResponse { status, body })
}
pub async fn route_collect(
&self,
engines: &Vec<EngineInfo>,
method: Method,
api_path: &str,
request: Option<&serde_json::Value>,
) -> Result<Vec<ProxyResponse>, actix_web::Error> {
let tasks = engines
.iter()
.map(|engine| self.route_one(engine, method.clone(), api_path, request, false));
let responses = join_all(tasks).await;
responses
.into_iter()
.map(|r| r.map_err(actix_web::error::ErrorBadGateway))
.collect()
}
pub async fn generate(
&self,
api_path: &str,
mut req: Box<dyn Bootstrap>,
) -> Result<HttpResponse, actix_web::Error> {
let (prefill, decode) = self.strategy_lb.select_pair(&self.client).await;
let stream = req.is_stream();
req.add_bootstrap_info(&prefill)?;
let json = serde_json::to_value(req)?;
let prefill_task = self.route_one(&prefill, Method::POST, api_path, Some(&json), false);
let decode_task = self.route_one(&decode, Method::POST, api_path, Some(&json), stream);
let (_, decode_response) = tokio::join!(prefill_task, decode_task);
decode_response?.into()
}
pub async fn get_engine_loads(
&self,
) -> Result<(Vec<EngineLoad>, Vec<EngineLoad>), actix_web::Error> {
let servers = self.strategy_lb.get_all_servers();
let responses = self
.route_collect(&servers, Method::GET, "/get_load", None)
.await?;
let loads = responses
.into_iter()
.enumerate()
.map(|(i, r)| Ok(EngineLoad::from_json(&servers[i], &r.to_json()?)))
.collect::<Result<Vec<EngineLoad>, actix_web::Error>>()?;
let mut prefill_loads = Vec::new();
let mut decode_loads = Vec::new();
for load in loads {
match load.engine_info.engine_type {
EngineType::Prefill => prefill_loads.push(load),
EngineType::Decode => decode_loads.push(load),
}
}
Ok((prefill_loads, decode_loads))
}
}
pub mod io_struct;
pub mod lb_state;
pub mod server;
pub mod strategy_lb;
use pyo3::{exceptions::PyRuntimeError, prelude::*};
use lb_state::{LBConfig, LBState};
use server::{periodic_logging, startup};
use tokio::signal;
#[pyclass]
pub struct LoadBalancer {
lb_config: LBConfig,
}
#[pymethods]
impl LoadBalancer {
#[new]
pub fn new(
host: String,
port: u16,
policy: String,
prefill_infos: Vec<(String, Option<u16>)>,
decode_infos: Vec<String>,
log_interval: u64,
timeout: u64,
) -> PyResult<Self> {
let lb_config = LBConfig {
host,
port,
policy,
prefill_infos,
decode_infos,
log_interval,
timeout,
};
Ok(LoadBalancer { lb_config })
}
pub fn start(&self) -> PyResult<()> {
let lb_state = LBState::new(self.lb_config.clone()).map_err(|e| {
PyRuntimeError::new_err(format!("Failed to build load balancer: {}", e))
})?;
let ret: PyResult<()> = actix_web::rt::System::new().block_on(async move {
tokio::select! {
_ = periodic_logging(lb_state.clone()) => {
unreachable!()
}
res = startup(self.lb_config.clone(), lb_state) => {
res.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
unreachable!()
}
_ = signal::ctrl_c() => {
println!("Received Ctrl+C, shutting down");
std::process::exit(0);
}
}
});
ret
}
}
#[pymodule]
fn _rust(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
m.add_class::<LoadBalancer>()?;
Ok(())
}
mod io_struct;
mod lb_state;
mod server;
mod strategy_lb;
use lb_state::{LBConfig, LBState};
use server::{periodic_logging, startup};
use tokio::signal;
fn main() -> anyhow::Result<()> {
// FIXME: test code, move to test folder
let prefill_infos = (0..8)
.map(|i| (format!("123.123.123.123:{}", i), None))
.collect::<Vec<(String, Option<u16>)>>();
let decode_infos = (0..32)
.map(|i| format!("233.233.233.233:{}", i))
.collect::<Vec<String>>();
let lb_config = LBConfig {
host: "localhost".to_string(),
port: 8080,
policy: "random".to_string(),
prefill_infos,
decode_infos,
log_interval: 5,
timeout: 600,
};
let lb_state = LBState::new(lb_config.clone()).map_err(|e| anyhow::anyhow!(e))?;
let ret: anyhow::Result<()> = actix_web::rt::System::new().block_on(async move {
tokio::select! {
_ = periodic_logging(lb_state.clone()) => {
unreachable!()
}
res = startup(lb_config.clone(), lb_state) => {
res.map_err(|e| anyhow::anyhow!(e))?;
unreachable!()
}
_ = signal::ctrl_c() => {
println!("Received Ctrl+C, shutting down");
std::process::exit(0);
}
}
});
ret
}
use crate::io_struct::{ChatReqInput, GenerateReqInput};
use crate::lb_state::{LBConfig, LBState};
use crate::strategy_lb::EngineType;
use actix_web::{HttpRequest, HttpResponse, HttpServer, get, post, web};
use reqwest::Method;
use serde_json::json;
use std::io::Write;
#[get("/health")]
pub async fn health(_req: HttpRequest, _: web::Data<LBState>) -> HttpResponse {
HttpResponse::Ok().body("Ok")
}
#[get("/health_generate")]
pub async fn health_generate(
_req: HttpRequest,
app_state: web::Data<LBState>,
) -> Result<HttpResponse, actix_web::Error> {
let servers = app_state.strategy_lb.get_all_servers();
app_state
.route_collect(&servers, Method::GET, "/health_generate", None)
.await?;
// FIXME: log the response
Ok(HttpResponse::Ok().body("Health check passed on all servers"))
}
#[post("/flush_cache")]
pub async fn flush_cache(
_req: HttpRequest,
app_state: web::Data<LBState>,
) -> Result<HttpResponse, actix_web::Error> {
let servers = app_state.strategy_lb.get_all_servers();
app_state
.route_collect(&servers, Method::POST, "/flush_cache", None)
.await?;
Ok(HttpResponse::Ok().body("Cache flushed on all servers"))
}
#[get("/get_model_info")]
pub async fn get_model_info(
_req: HttpRequest,
app_state: web::Data<LBState>,
) -> Result<HttpResponse, actix_web::Error> {
// Return the first server's model info
let engine = app_state.strategy_lb.get_one_server();
app_state
.route_one(&engine, Method::GET, "/get_model_info", None, false)
.await?
.into()
}
#[post("/generate")]
pub async fn generate(
_req: HttpRequest,
req: web::Json<GenerateReqInput>,
app_state: web::Data<LBState>,
) -> Result<HttpResponse, actix_web::Error> {
app_state
.generate("/generate", Box::new(req.into_inner()))
.await
}
#[post("/v1/chat/completions")]
pub async fn chat_completions(
_req: HttpRequest,
req: web::Json<ChatReqInput>,
app_state: web::Data<LBState>,
) -> Result<HttpResponse, actix_web::Error> {
app_state
.generate("/v1/chat/completions", Box::new(req.into_inner()))
.await
}
#[get("/get_server_info")]
pub async fn get_server_info(
_req: HttpRequest,
app_state: web::Data<LBState>,
) -> Result<HttpResponse, actix_web::Error> {
let servers = app_state.strategy_lb.get_all_servers();
let responses = app_state
.route_collect(&servers, Method::GET, "/get_server_info", None)
.await?;
let mut prefill_infos = Vec::new();
let mut decode_infos = Vec::new();
for (i, resp) in responses.iter().enumerate() {
let json = resp.to_json()?;
match servers[i].engine_type {
EngineType::Prefill => prefill_infos.push(json),
EngineType::Decode => decode_infos.push(json),
}
}
Ok(HttpResponse::Ok().json(json!({
"prefill": prefill_infos,
"decode": decode_infos,
})))
}
#[get("/get_loads")]
pub async fn get_loads(
_req: HttpRequest,
app_state: web::Data<LBState>,
) -> Result<HttpResponse, actix_web::Error> {
let (prefill_loads, decode_loads) = app_state.get_engine_loads().await?;
Ok(HttpResponse::Ok().json(json!({
"prefill": prefill_loads.into_iter().map(|l| l.to_json()).collect::<Vec<_>>(),
"decode": decode_loads.into_iter().map(|l| l.to_json()).collect::<Vec<_>>()
})))
}
pub async fn periodic_logging(lb_state: LBState) {
// FIXME: currently we can just clone the lb_state to log as the lb is stateless
loop {
tokio::time::sleep(std::time::Duration::from_secs(lb_state.log_interval)).await;
let (prefill_loads, decode_loads) = match lb_state.get_engine_loads().await {
Ok((prefill_loads, decode_loads)) => (prefill_loads, decode_loads),
Err(e) => {
log::error!("Failed to get engine loads: {}", e);
continue;
}
};
let prefill_loads = prefill_loads
.into_iter()
.map(|l| l.to_string())
.collect::<Vec<_>>();
let decode_loads = decode_loads
.into_iter()
.map(|l| l.to_string())
.collect::<Vec<_>>();
log::info!("Prefill loads: {}", prefill_loads.join(", "));
log::info!("Decode loads: {}", decode_loads.join(", "));
}
}
pub async fn startup(lb_config: LBConfig, lb_state: LBState) -> std::io::Result<()> {
let app_state = web::Data::new(lb_state);
println!("Starting server at {}:{}", lb_config.host, lb_config.port);
// default level is info
env_logger::Builder::new()
.format(|buf, record| {
writeln!(
buf,
"{} - {} - {}",
chrono::Local::now().format("%Y-%m-%d %H:%M:%S"),
record.level(),
record.args()
)
})
.filter(None, log::LevelFilter::Info)
.init();
HttpServer::new(move || {
actix_web::App::new()
.wrap(actix_web::middleware::Logger::default())
.app_data(app_state.clone())
.service(health)
.service(health_generate)
.service(flush_cache)
.service(get_model_info)
.service(get_server_info)
.service(get_loads)
.service(generate)
.service(chat_completions)
})
.bind((lb_config.host, lb_config.port))?
.run()
.await?;
std::io::Result::Ok(())
}
use rand::Rng;
use serde_json::json;
#[derive(Debug, Clone)]
pub enum EngineType {
Prefill,
Decode,
}
#[derive(Debug, Clone)]
pub struct EngineInfo {
pub engine_type: EngineType,
pub url: String,
pub bootstrap_port: Option<u16>,
}
impl EngineInfo {
pub fn new_prefill(url: String, bootstrap_port: Option<u16>) -> Self {
EngineInfo {
engine_type: EngineType::Prefill,
url,
bootstrap_port,
}
}
pub fn new_decode(url: String) -> Self {
EngineInfo {
engine_type: EngineType::Decode,
url,
bootstrap_port: None,
}
}
pub fn api_path(&self, api_path: &str) -> String {
if api_path.starts_with("/") {
format!("{}{}", self.url, api_path)
} else {
format!("{}/{}", self.url, api_path)
}
}
pub fn to_string(&self) -> String {
format!("({:?}@{})", self.engine_type, self.url)
}
pub fn get_hostname(&self) -> String {
let url = self
.url
.trim_start_matches("http://")
.trim_start_matches("https://");
url.split(':').next().unwrap().to_string()
}
}
pub struct EngineLoad {
pub engine_info: EngineInfo,
pub load: isize,
}
impl EngineLoad {
pub fn from_json(engine_info: &EngineInfo, json: &serde_json::Value) -> Self {
let load = match json.get("load") {
Some(load) => load.as_i64().unwrap_or(-1) as isize,
None => -1,
};
EngineLoad {
engine_info: engine_info.clone(),
load,
}
}
pub fn to_json(&self) -> serde_json::Value {
json!({
"engine": self.engine_info.to_string(),
"load": self.load,
})
}
pub fn to_string(&self) -> String {
format!("{}: {}", self.engine_info.to_string(), self.load)
}
}
#[derive(Debug, Clone)]
pub enum LBPolicy {
Random,
PowerOfTwo,
}
#[derive(Debug, Clone)]
pub struct StrategyLB {
pub policy: LBPolicy,
pub prefill_servers: Vec<EngineInfo>,
pub decode_servers: Vec<EngineInfo>,
}
impl StrategyLB {
pub fn new(
policy: LBPolicy,
prefill_servers: Vec<EngineInfo>,
decode_servers: Vec<EngineInfo>,
) -> Self {
StrategyLB {
policy,
prefill_servers,
decode_servers,
}
}
pub fn get_one_server(&self) -> EngineInfo {
assert!(!self.prefill_servers.is_empty());
assert!(!self.decode_servers.is_empty());
self.prefill_servers[0].clone()
}
pub fn get_all_servers(&self) -> Vec<EngineInfo> {
let mut all_servers = Vec::new();
all_servers.extend(self.prefill_servers.clone());
all_servers.extend(self.decode_servers.clone());
all_servers
}
pub async fn select_pair(&self, client: &reqwest::Client) -> (EngineInfo, EngineInfo) {
match self.policy {
LBPolicy::Random => self.select_pd_pair_random(),
LBPolicy::PowerOfTwo => self.select_pd_pair_po2(client).await,
}
}
fn select_pd_pair_random(&self) -> (EngineInfo, EngineInfo) {
let mut rng = rand::rng();
let prefill_index = rng.random_range(0..self.prefill_servers.len());
let decode_index = rng.random_range(0..self.decode_servers.len());
(
self.prefill_servers[prefill_index].clone(),
self.decode_servers[decode_index].clone(),
)
}
async fn get_load_from_engine(
&self,
client: &reqwest::Client,
engine_info: &EngineInfo,
) -> Option<isize> {
let url = engine_info.api_path("/get_load");
let response = client.get(url).send().await.unwrap();
match response.status() {
reqwest::StatusCode::OK => {
let data = response.json::<serde_json::Value>().await.unwrap();
Some(data["load"].as_i64().unwrap() as isize)
}
_ => None,
}
}
async fn select_pd_pair_po2(&self, client: &reqwest::Client) -> (EngineInfo, EngineInfo) {
let mut rng = rand::rng();
let prefill1 =
self.prefill_servers[rng.random_range(0..self.prefill_servers.len())].clone();
let prefill2 =
self.prefill_servers[rng.random_range(0..self.prefill_servers.len())].clone();
let decode1 = self.decode_servers[rng.random_range(0..self.decode_servers.len())].clone();
let decode2 = self.decode_servers[rng.random_range(0..self.decode_servers.len())].clone();
let prefill1_load = self.get_load_from_engine(client, &prefill1).await;
let prefill2_load = self.get_load_from_engine(client, &prefill2).await;
let decode1_load = self.get_load_from_engine(client, &decode1).await;
let decode2_load = self.get_load_from_engine(client, &decode2).await;
(
if prefill1_load < prefill2_load {
prefill1
} else {
prefill2
},
if decode1_load < decode2_load {
decode1
} else {
decode2
},
)
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment