Unverified Commit 4d62bca5 authored by Byron Hsu's avatar Byron Hsu Committed by GitHub
Browse files

[router] Replace print with logger (#2183)

parent e1e595d7
...@@ -237,6 +237,21 @@ dependencies = [ ...@@ -237,6 +237,21 @@ dependencies = [
"alloc-no-stdlib", "alloc-no-stdlib",
] ]
[[package]]
name = "android-tzdata"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
[[package]]
name = "android_system_properties"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "anstream" name = "anstream"
version = "0.6.18" version = "0.6.18"
...@@ -411,6 +426,20 @@ version = "1.0.0" ...@@ -411,6 +426,20 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chrono"
version = "0.4.38"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401"
dependencies = [
"android-tzdata",
"iana-time-zone",
"js-sys",
"num-traits",
"wasm-bindgen",
"windows-targets 0.52.6",
]
[[package]] [[package]]
name = "clap" name = "clap"
version = "4.5.20" version = "4.5.20"
...@@ -721,6 +750,29 @@ dependencies = [ ...@@ -721,6 +750,29 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "env_filter"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f2c92ceda6ceec50f43169f9ee8424fe2db276791afde7b2cd8bc084cb376ab"
dependencies = [
"log",
"regex",
]
[[package]]
name = "env_logger"
version = "0.11.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e13fa619b91fb2381732789fc5de83b45675e882f66623b7d8cb4f643017018d"
dependencies = [
"anstream",
"anstyle",
"env_filter",
"humantime",
"log",
]
[[package]] [[package]]
name = "equivalent" name = "equivalent"
version = "1.0.1" version = "1.0.1"
...@@ -1016,6 +1068,12 @@ version = "1.0.3" ...@@ -1016,6 +1068,12 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "humantime"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]] [[package]]
name = "hyper" name = "hyper"
version = "1.5.0" version = "1.5.0"
...@@ -1088,6 +1146,29 @@ dependencies = [ ...@@ -1088,6 +1146,29 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "iana-time-zone"
version = "0.1.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220"
dependencies = [
"android_system_properties",
"core-foundation-sys",
"iana-time-zone-haiku",
"js-sys",
"wasm-bindgen",
"windows-core",
]
[[package]]
name = "iana-time-zone-haiku"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
dependencies = [
"cc",
]
[[package]] [[package]]
name = "icu_collections" name = "icu_collections"
version = "1.5.0" version = "1.5.0"
...@@ -1523,6 +1604,15 @@ version = "0.1.0" ...@@ -1523,6 +1604,15 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
]
[[package]] [[package]]
name = "number_prefix" name = "number_prefix"
version = "0.4.0" version = "0.4.0"
...@@ -2116,10 +2206,13 @@ version = "0.0.0" ...@@ -2116,10 +2206,13 @@ version = "0.0.0"
dependencies = [ dependencies = [
"actix-web", "actix-web",
"bytes", "bytes",
"chrono",
"clap", "clap",
"dashmap", "dashmap",
"env_logger",
"futures-util", "futures-util",
"http 1.1.0", "http 1.1.0",
"log",
"pyo3", "pyo3",
"rand", "rand",
"reqwest", "reqwest",
...@@ -2688,6 +2781,15 @@ dependencies = [ ...@@ -2688,6 +2781,15 @@ dependencies = [
"rustls-pki-types", "rustls-pki-types",
] ]
[[package]]
name = "windows-core"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9"
dependencies = [
"windows-targets 0.52.6",
]
[[package]] [[package]]
name = "windows-registry" name = "windows-registry"
version = "0.2.0" version = "0.2.0"
......
...@@ -26,6 +26,9 @@ pyo3 = { version = "0.22.5", features = ["extension-module"] } ...@@ -26,6 +26,9 @@ pyo3 = { version = "0.22.5", features = ["extension-module"] }
tokenizers = { version = "0.20.3", features = ["http"] } tokenizers = { version = "0.20.3", features = ["http"] }
dashmap = "6.1.0" dashmap = "6.1.0"
http = "1.1.0" http = "1.1.0"
env_logger = "0.11.5"
log = "0.4.22"
chrono = "0.4.38"
[profile.release] [profile.release]
lto = "thin" lto = "thin"
......
import argparse import argparse
import dataclasses import dataclasses
import logging
import sys import sys
from typing import List, Optional from typing import List, Optional
...@@ -7,6 +8,22 @@ from sglang_router import Router ...@@ -7,6 +8,22 @@ from sglang_router import Router
from sglang_router_rs import PolicyType from sglang_router_rs import PolicyType
def setup_logger():
logger = logging.getLogger("router")
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
"[Router (Python)] %(asctime)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
@dataclasses.dataclass @dataclasses.dataclass
class RouterArgs: class RouterArgs:
# Worker configuration # Worker configuration
...@@ -21,6 +38,7 @@ class RouterArgs: ...@@ -21,6 +38,7 @@ class RouterArgs:
balance_rel_threshold: float = 1.0001 balance_rel_threshold: float = 1.0001
eviction_interval: int = 60 eviction_interval: int = 60
max_tree_size: int = 2**24 max_tree_size: int = 2**24
verbose: bool = False
@staticmethod @staticmethod
def add_cli_args( def add_cli_args(
...@@ -98,6 +116,11 @@ class RouterArgs: ...@@ -98,6 +116,11 @@ class RouterArgs:
default=RouterArgs.max_tree_size, default=RouterArgs.max_tree_size,
help="Maximum size of the approximation tree for cache-aware routing", help="Maximum size of the approximation tree for cache-aware routing",
) )
parser.add_argument(
f"--{prefix}verbose",
action="store_true",
help="Enable verbose logging",
)
@classmethod @classmethod
def from_cli_args( def from_cli_args(
...@@ -121,6 +144,7 @@ class RouterArgs: ...@@ -121,6 +144,7 @@ class RouterArgs:
balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"), balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"),
eviction_interval=getattr(args, f"{prefix}eviction_interval"), eviction_interval=getattr(args, f"{prefix}eviction_interval"),
max_tree_size=getattr(args, f"{prefix}max_tree_size"), max_tree_size=getattr(args, f"{prefix}max_tree_size"),
verbose=getattr(args, f"{prefix}verbose", False),
) )
...@@ -145,6 +169,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: ...@@ -145,6 +169,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
Returns: Returns:
Router instance if successful, None if failed Router instance if successful, None if failed
""" """
logger = logging.getLogger("router")
try: try:
# Convert to RouterArgs if needed # Convert to RouterArgs if needed
if not isinstance(args, RouterArgs): if not isinstance(args, RouterArgs):
...@@ -162,13 +187,14 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: ...@@ -162,13 +187,14 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
balance_rel_threshold=router_args.balance_rel_threshold, balance_rel_threshold=router_args.balance_rel_threshold,
eviction_interval_secs=router_args.eviction_interval, eviction_interval_secs=router_args.eviction_interval,
max_tree_size=router_args.max_tree_size, max_tree_size=router_args.max_tree_size,
verbose=router_args.verbose,
) )
router.start() router.start()
return router return router
except Exception as e: except Exception as e:
print(f"Error starting router: {e}", file=sys.stderr) logger.error(f"Error starting router: {e}", file=sys.stderr)
return None return None
...@@ -202,6 +228,7 @@ Examples: ...@@ -202,6 +228,7 @@ Examples:
def main() -> None: def main() -> None:
logger = setup_logger()
router_args = parse_router_args(sys.argv[1:]) router_args = parse_router_args(sys.argv[1:])
router = launch_router(router_args) router = launch_router(router_args)
......
import argparse import argparse
import copy import copy
import logging
import multiprocessing as mp import multiprocessing as mp
import os import os
import random import random
...@@ -17,6 +18,22 @@ from sglang.srt.utils import is_port_available ...@@ -17,6 +18,22 @@ from sglang.srt.utils import is_port_available
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback
def setup_logger():
logger = logging.getLogger("router")
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
"[Router (Python)] %(asctime)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
# Create new process group # Create new process group
def run_server(server_args, dp_rank): def run_server(server_args, dp_rank):
os.setpgrp() # Create new process group os.setpgrp() # Create new process group
...@@ -42,20 +59,20 @@ def launch_server_process( ...@@ -42,20 +59,20 @@ def launch_server_process(
def cleanup_processes(processes: List[mp.Process]): def cleanup_processes(processes: List[mp.Process]):
"""Clean up all processes using process groups.""" logger = logging.getLogger("router")
print("\nCleaning up processes...") logger.info("Cleaning up processes...")
for proc in processes: for proc in processes:
if proc.is_alive(): if proc.is_alive():
try: try:
# Kill the entire process group
os.killpg(os.getpgid(proc.pid), signal.SIGTERM) os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
# Give processes some time to terminate gracefully
proc.join(timeout=3) proc.join(timeout=3)
# If process is still alive, force kill
if proc.is_alive(): if proc.is_alive():
logger.warning(
f"Process {proc.pid} did not terminate gracefully, force killing..."
)
os.killpg(os.getpgid(proc.pid), signal.SIGKILL) os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
except ProcessLookupError: except ProcessLookupError:
pass # Process already terminated pass
def setup_signal_handlers(cleanup_func): def setup_signal_handlers(cleanup_func):
...@@ -101,6 +118,8 @@ def find_available_ports(base_port: int, count: int) -> List[int]: ...@@ -101,6 +118,8 @@ def find_available_ports(base_port: int, count: int) -> List[int]:
def main(): def main():
logger = setup_logger()
# CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes # CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
mp.set_start_method("spawn") mp.set_start_method("spawn")
...@@ -130,8 +149,8 @@ def main(): ...@@ -130,8 +149,8 @@ def main():
server_processes = [] server_processes = []
try: try:
# Launch server processes
for i, worker_port in enumerate(worker_ports): for i, worker_port in enumerate(worker_ports):
logger.info(f"Launching DP server process {i} on port {worker_port}")
proc = launch_server_process(server_args, worker_port, i) proc = launch_server_process(server_args, worker_port, i)
server_processes.append(proc) server_processes.append(proc)
...@@ -140,18 +159,19 @@ def main(): ...@@ -140,18 +159,19 @@ def main():
# Wait for all servers to be healthy # Wait for all servers to be healthy
all_healthy = True all_healthy = True
for port in worker_ports: for port in worker_ports:
if not wait_for_server_health(server_args.host, port): if not wait_for_server_health(server_args.host, port):
print(f"Server on port {port} failed to become healthy") logger.error(f"Server on port {port} failed to become healthy")
all_healthy = False all_healthy = False
break break
if not all_healthy: if not all_healthy:
print("Not all servers are healthy. Shutting down...") logger.error("Not all servers are healthy. Shutting down...")
cleanup_processes(server_processes) cleanup_processes(server_processes)
sys.exit(1) sys.exit(1)
print("All servers are healthy. Starting router...") logger.info("All servers are healthy. Starting router...")
# Update router args with worker URLs # Update router args with worker URLs
router_args.worker_urls = [ router_args.worker_urls = [
...@@ -162,16 +182,17 @@ def main(): ...@@ -162,16 +182,17 @@ def main():
router = launch_router(router_args) router = launch_router(router_args)
if router is None: if router is None:
print("Failed to start router. Shutting down...") logger.error("Failed to start router. Shutting down...")
cleanup_processes(server_processes) cleanup_processes(server_processes)
sys.exit(1) sys.exit(1)
except KeyboardInterrupt: except KeyboardInterrupt:
print("\nReceived shutdown signal...") logger.info("Received shutdown signal...")
except Exception as e: except Exception as e:
print(f"Error occurred: {e}") logger.error(f"Error occurred: {e}")
print(get_exception_traceback()) logger.error(get_exception_traceback())
finally: finally:
logger.info("Cleaning up processes...")
cleanup_processes(server_processes) cleanup_processes(server_processes)
......
...@@ -27,6 +27,7 @@ class Router: ...@@ -27,6 +27,7 @@ class Router:
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
routing. Default: 60 routing. Default: 60
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24 max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
verbose: Enable verbose logging. Default: False
""" """
def __init__( def __init__(
...@@ -40,6 +41,7 @@ class Router: ...@@ -40,6 +41,7 @@ class Router:
balance_rel_threshold: float = 1.0001, balance_rel_threshold: float = 1.0001,
eviction_interval_secs: int = 60, eviction_interval_secs: int = 60,
max_tree_size: int = 2**24, max_tree_size: int = 2**24,
verbose: bool = False,
): ):
self._router = _Router( self._router = _Router(
worker_urls=worker_urls, worker_urls=worker_urls,
...@@ -51,6 +53,7 @@ class Router: ...@@ -51,6 +53,7 @@ class Router:
balance_rel_threshold=balance_rel_threshold, balance_rel_threshold=balance_rel_threshold,
eviction_interval_secs=eviction_interval_secs, eviction_interval_secs=eviction_interval_secs,
max_tree_size=max_tree_size, max_tree_size=max_tree_size,
verbose=verbose,
) )
def start(self) -> None: def start(self) -> None:
......
...@@ -22,6 +22,7 @@ struct Router { ...@@ -22,6 +22,7 @@ struct Router {
balance_rel_threshold: f32, balance_rel_threshold: f32,
eviction_interval_secs: u64, eviction_interval_secs: u64,
max_tree_size: usize, max_tree_size: usize,
verbose: bool,
} }
#[pymethods] #[pymethods]
...@@ -36,7 +37,8 @@ impl Router { ...@@ -36,7 +37,8 @@ impl Router {
balance_abs_threshold = 32, balance_abs_threshold = 32,
balance_rel_threshold = 1.0001, balance_rel_threshold = 1.0001,
eviction_interval_secs = 60, eviction_interval_secs = 60,
max_tree_size = 2usize.pow(24) max_tree_size = 2usize.pow(24),
verbose = false
))] ))]
fn new( fn new(
worker_urls: Vec<String>, worker_urls: Vec<String>,
...@@ -48,6 +50,7 @@ impl Router { ...@@ -48,6 +50,7 @@ impl Router {
balance_rel_threshold: f32, balance_rel_threshold: f32,
eviction_interval_secs: u64, eviction_interval_secs: u64,
max_tree_size: usize, max_tree_size: usize,
verbose: bool,
) -> PyResult<Self> { ) -> PyResult<Self> {
Ok(Router { Ok(Router {
host, host,
...@@ -59,14 +62,11 @@ impl Router { ...@@ -59,14 +62,11 @@ impl Router {
balance_rel_threshold, balance_rel_threshold,
eviction_interval_secs, eviction_interval_secs,
max_tree_size, max_tree_size,
verbose,
}) })
} }
fn start(&self) -> PyResult<()> { fn start(&self) -> PyResult<()> {
let host = self.host.clone();
let port = self.port;
let worker_urls = self.worker_urls.clone();
let policy_config = match &self.policy { let policy_config = match &self.policy {
PolicyType::Random => router::PolicyConfig::RandomConfig, PolicyType::Random => router::PolicyConfig::RandomConfig,
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig, PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig,
...@@ -80,9 +80,15 @@ impl Router { ...@@ -80,9 +80,15 @@ impl Router {
}; };
actix_web::rt::System::new().block_on(async move { actix_web::rt::System::new().block_on(async move {
server::startup(host, port, worker_urls, policy_config) server::startup(server::ServerConfig {
.await host: self.host.clone(),
.unwrap(); port: self.port,
worker_urls: self.worker_urls.clone(),
policy_config,
verbose: self.verbose,
})
.await
.unwrap();
}); });
Ok(()) Ok(())
......
use clap::Parser; use clap::Parser;
use clap::ValueEnum; use clap::ValueEnum;
use sglang_router_rs::{router::PolicyConfig, server}; use sglang_router_rs::{router::PolicyConfig, server, server::ServerConfig};
#[derive(Debug, Clone, ValueEnum)] #[derive(Debug, Clone, ValueEnum)]
pub enum PolicyType { pub enum PolicyType {
...@@ -89,6 +89,9 @@ struct Args { ...@@ -89,6 +89,9 @@ struct Args {
help = "Maximum size of the approximation tree for cache-aware routing. Default: 2^24" help = "Maximum size of the approximation tree for cache-aware routing. Default: 2^24"
)] )]
max_tree_size: usize, max_tree_size: usize,
#[arg(long, default_value_t = false, help = "Enable verbose logging")]
verbose: bool,
} }
impl Args { impl Args {
...@@ -111,5 +114,12 @@ impl Args { ...@@ -111,5 +114,12 @@ impl Args {
async fn main() -> std::io::Result<()> { async fn main() -> std::io::Result<()> {
let args = Args::parse(); let args = Args::parse();
let policy_config = args.get_policy_config(); let policy_config = args.get_policy_config();
server::startup(args.host, args.port, args.worker_urls, policy_config).await server::startup(ServerConfig {
host: args.host,
port: args.port,
worker_urls: args.worker_urls,
policy_config,
verbose: args.verbose,
})
.await
} }
...@@ -3,6 +3,7 @@ use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; ...@@ -3,6 +3,7 @@ use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse}; use actix_web::{HttpRequest, HttpResponse};
use bytes::Bytes; use bytes::Bytes;
use futures_util::{Stream, StreamExt, TryStreamExt}; use futures_util::{Stream, StreamExt, TryStreamExt};
use log::{debug, info};
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Debug; use std::fmt::Debug;
use std::hash::Hash; use std::hash::Hash;
...@@ -171,11 +172,11 @@ impl Router { ...@@ -171,11 +172,11 @@ impl Router {
// Print the process queue // Print the process queue
let locked_processed_queue = processed_queue_clone.lock().unwrap(); let locked_processed_queue = processed_queue_clone.lock().unwrap();
println!("Processed Queue: {:?}", locked_processed_queue); info!("Processed Queue: {:?}", locked_processed_queue);
// Print the running queue // Print the running queue
let locked_running_queue = running_queue_clone.lock().unwrap(); let locked_running_queue = running_queue_clone.lock().unwrap();
println!("Running Queue: {:?}", locked_running_queue); info!("Running Queue: {:?}", locked_running_queue);
} }
}); });
...@@ -266,7 +267,7 @@ impl Router { ...@@ -266,7 +267,7 @@ impl Router {
let selected_url = if is_imbalanced { let selected_url = if is_imbalanced {
// Log load balancing trigger and current queue state // Log load balancing trigger and current queue state
println!( info!(
"Load balancing triggered due to workload imbalance:\n\ "Load balancing triggered due to workload imbalance:\n\
Max load: {}, Min load: {}\n\ Max load: {}, Min load: {}\n\
Current running queue: {:?}", Current running queue: {:?}",
...@@ -368,8 +369,7 @@ impl Router { ...@@ -368,8 +369,7 @@ impl Router {
let mut locked_queue = running_queue.lock().unwrap(); let mut locked_queue = running_queue.lock().unwrap();
let count = locked_queue.get_mut(&worker_url).unwrap(); let count = locked_queue.get_mut(&worker_url).unwrap();
*count = count.saturating_sub(1); *count = count.saturating_sub(1);
// print debug!("streaming is done!!")
// println!("streaming is done!!")
} }
}), }),
) )
......
...@@ -2,6 +2,9 @@ use crate::router::PolicyConfig; ...@@ -2,6 +2,9 @@ use crate::router::PolicyConfig;
use crate::router::Router; use crate::router::Router;
use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder}; use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
use bytes::Bytes; use bytes::Bytes;
use env_logger::Builder;
use log::{debug, info, LevelFilter};
use std::io::Write;
#[derive(Debug)] #[derive(Debug)]
pub struct AppState { pub struct AppState {
...@@ -125,23 +128,49 @@ async fn v1_completions( ...@@ -125,23 +128,49 @@ async fn v1_completions(
.await .await
} }
pub async fn startup( pub struct ServerConfig {
host: String, pub host: String,
port: u16, pub port: u16,
worker_urls: Vec<String>, pub worker_urls: Vec<String>,
policy_config: PolicyConfig, pub policy_config: PolicyConfig,
) -> std::io::Result<()> { pub verbose: bool,
println!("Starting server on {}:{}", host, port); }
println!("Worker URLs: {:?}", worker_urls);
println!("Policy Config: {:?}", policy_config); pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
Builder::new()
// Create client once with configuration .format(|buf, record| {
use chrono::Local;
writeln!(
buf,
"[Router (Rust)] {} - {} - {}",
Local::now().format("%Y-%m-%d %H:%M:%S"),
record.level(),
record.args()
)
})
.filter(
None,
if config.verbose {
LevelFilter::Debug
} else {
LevelFilter::Info
},
)
.init();
info!("Starting server on {}:{}", config.host, config.port);
info!("Worker URLs: {:?}", config.worker_urls);
info!("Policy Config: {:?}", config.policy_config);
let client = reqwest::Client::builder() let client = reqwest::Client::builder()
.build() .build()
.expect("Failed to create HTTP client"); .expect("Failed to create HTTP client");
// Store both worker_urls and client in AppState let app_state = web::Data::new(AppState::new(
let app_state = web::Data::new(AppState::new(worker_urls, client, policy_config)); config.worker_urls,
client,
config.policy_config,
));
HttpServer::new(move || { HttpServer::new(move || {
App::new() App::new()
...@@ -155,7 +184,7 @@ pub async fn startup( ...@@ -155,7 +184,7 @@ pub async fn startup(
.service(health_generate) .service(health_generate)
.service(get_server_info) .service(get_server_info)
}) })
.bind((host, port))? .bind((config.host, config.port))?
.run() .run()
.await .await
} }
use dashmap::mapref::entry::Entry; use dashmap::mapref::entry::Entry;
use dashmap::DashMap; use dashmap::DashMap;
use log::info;
use rand::distributions::{Alphanumeric, DistString}; use rand::distributions::{Alphanumeric, DistString};
use rand::thread_rng; use rand::thread_rng;
use std::cmp::min; use std::cmp::min;
...@@ -434,9 +435,9 @@ impl Tree { ...@@ -434,9 +435,9 @@ impl Tree {
} }
} }
println!("Before eviction - Used size per tenant:"); info!("Before eviction - Used size per tenant:");
for (tenant, size) in &used_size_per_tenant { for (tenant, size) in &used_size_per_tenant {
println!("Tenant: {}, Size: {}", tenant, size); info!("Tenant: {}, Size: {}", tenant, size);
} }
// Process eviction // Process eviction
...@@ -490,9 +491,9 @@ impl Tree { ...@@ -490,9 +491,9 @@ impl Tree {
} }
} }
println!("\nAfter eviction - Used size per tenant:"); info!("After eviction - Used size per tenant:");
for (tenant, size) in &used_size_per_tenant { for (tenant, size) in &used_size_per_tenant {
println!("Tenant: {}, Size: {}", tenant, size); info!("Tenant: {}, Size: {}", tenant, size);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment