"git@developer.sourcefind.cn:OpenDAS/torch-cluster.git" did not exist on "6450147e76db9c1c8d921c443149078dcf5854b2"
Commit f9633fa9 authored by Byron Hsu, committed by GitHub

[rust] cache-aware DP - approx tree (#1934)

parent 087ab832
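In brief: the new `approx_tree` policy keeps one radix tree of previously routed token sequences per worker. Each incoming request is tokenized, matched against every worker's tree, and sent to the worker with the highest prefix-match rate if that rate beats `cache_threshold`; otherwise it falls back to the worker with the fewest in-flight requests. The diff adds a multi-turn chat benchmark, a launch script for the router plus workers, and the Rust implementation of the router, its CLI, and the radix tree. A minimal Python sketch of the routing decision (names here are illustrative stand-ins; the real logic is the Rust `dispatch` further down):

```python
# Illustrative sketch of the approx-tree routing decision, mirroring the Rust
# dispatch() below. `trees`, `counts`, and `prefix_match_len` are hypothetical
# stand-ins for the per-worker RadixTree and in-flight request counters.
def route(input_ids, workers, trees, counts, cache_threshold=0.5):
    # 1. Find the worker whose cached radix tree shares the longest prefix.
    best_rate, best_worker = 0.0, workers[0]
    for w in workers:
        matched = trees[w].prefix_match_len(input_ids)
        rate = matched / len(input_ids) if input_ids else 0.0
        if rate > best_rate:
            best_rate, best_worker = rate, w

    # 2. Route to it only if the match rate beats the threshold;
    #    otherwise pick the worker with the shortest queue.
    if best_rate > cache_threshold:
        return best_worker
    return min(workers, key=lambda w: counts[w])
```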
import itertools
import json
import random
import string
import threading
import time
from argparse import ArgumentParser

import sglang as sgl
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
from sglang.utils import dump_state_text

random.seed(42)


def gen_prompt(tokenizer, token_num):
    """Generate a random prompt of roughly `token_num` tokens."""
    all_available_tokens = list(tokenizer.get_vocab().values())
    selected_tokens = random.choices(all_available_tokens, k=token_num)
    ret = tokenizer.decode(selected_tokens)
    return ret


def gen_arguments(args, tokenizer):
    """Build one (system prompt, QA turns) spec per request."""
    multi_qas = [
        {"system_prompt": gen_prompt(tokenizer, args.system_prompt_len), "qas": []}
        for _ in range(args.num_qa)
    ]
    for i in range(args.num_qa):
        qas = multi_qas[i]["qas"]
        for j in range(args.turns):
            qas.append(
                {
                    "prompt": gen_prompt(tokenizer, args.len_q),
                    "new_tokens": args.len_a,
                }
            )
    return multi_qas


@sgl.function
def multi_turns(s, system_prompt, qas):
    s += system_prompt
    for qa in qas:
        s += qa["prompt"]
        s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True)


def main(args):
    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
    multi_qas = gen_arguments(args, tokenizer)
    backend = select_sglang_backend(args)

    tic = time.time()
    states = multi_turns.run_batch(
        multi_qas,
        temperature=0,
        backend=backend,
        num_threads=args.parallel,
        progress_bar=True,
    )
    latency = time.time() - tic

    print(f"Latency: {latency:.3f}")
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "multi_turn_system_prompt_chat",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": args.num_qa,
            "num_turns": args.turns,
            "other": {
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--turns", type=int, default=8)
    parser.add_argument("--num-qa", type=int, default=128)
    parser.add_argument("--system-prompt-len", type=int, default=2048)
    parser.add_argument("--len-q", type=int, default=32)
    parser.add_argument("--len-a", type=int, default=128)
    parser.add_argument(
        "--tokenizer", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct"
    )
    parser.add_argument("--trust-remote-code", action="store_true")
    args = add_common_sglang_args_and_parse(parser)
    print(args)
    main(args)
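Each element `gen_arguments` yields pairs one random ~2048-token system prompt with `--turns` question specs, which `run_batch` replays as independent chat sessions; the shared system prompt across turns is exactly what a cache-aware router can exploit. Illustrative shape under the defaults above (the strings are hypothetical placeholders):

```python
# Illustrative shape of one multi_qas element with the default arguments
# (system_prompt_len=2048, turns=8, len_q=32, len_a=128).
example = {
    "system_prompt": "<text decoded from 2048 random vocab ids>",
    "qas": [
        {"prompt": "<text decoded from 32 random vocab ids>", "new_tokens": 128},
        # ... 7 more turns
    ],
}
```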
@@ -21,5 +21,6 @@ bytes = "1.8.0"
 rand = "0.8.5"
 reqwest = { version = "0.12.8", features = ["stream"] }
 futures-util = "0.3"
-serde_json = "=1.0.1"
+serde_json = "1.0"
 pyo3 = { version = "0.22.5", features = ["extension-module"] }
+tokenizers = { version = "0.20.3", features = ["http"] }
import argparse
import os
import signal
import subprocess
import sys
import time
from typing import Dict, List

import requests
from sglang_router import PolicyType, Router

# Global processes list for cleanup
_processes: List[subprocess.Popen] = []


def cleanup_processes(signum=None, frame=None):
    """Cleanup function to kill all worker processes."""
    print("\nCleaning up processes...")
    for process in _processes:
        try:
            # Kill the entire process group
            pgid = os.getpgid(process.pid)
            os.killpg(pgid, signal.SIGKILL)
            process.wait()
        except:
            pass
    sys.exit(1)


# Register signal handlers
signal.signal(signal.SIGINT, cleanup_processes)
signal.signal(signal.SIGTERM, cleanup_processes)


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="Launch SGLang Router Server")
    parser.add_argument(
        "--host", type=str, default="localhost", help="Host address to bind the server"
    )
    parser.add_argument(
        "--port", type=int, default=30000, help="Base port number for workers"
    )
    parser.add_argument(
        "--dp",
        type=int,
        default=2,
        help="Number of worker processes (degree of parallelism)",
    )
    parser.add_argument(
        "--model-path", type=str, required=True, help="Path to the model"
    )
    parser.add_argument(
        "--local-tokenizer-path",
        type=str,
        required=True,
        help="Path to the local tokenizer",
    )
    return parser.parse_args()


def launch_workers(args) -> tuple[List[subprocess.Popen], List[str]]:
    """Launch all worker processes concurrently using subprocess."""
    processes = []
    worker_urls = []

    # Launch each worker process, pinning worker i to GPU i
    for i in range(args.dp):
        port = args.port + i
        url = f"http://{args.host}:{port}"
        worker_urls.append(url)

        # TODO: replace this with launch_server, and move this file to sglang/ because it depends on sglang
        # We don't
        command = f"export CUDA_VISIBLE_DEVICES={i}; python -m sglang.launch_server --model-path {args.model_path} --host {args.host} --port {port}"
        print(command)
        process = subprocess.Popen(command, shell=True)
        processes.append(process)
        _processes.append(process)  # Add to global list for cleanup

    return processes, worker_urls


def wait_for_healthy_workers(worker_urls: List[str], timeout: int = 300) -> bool:
    """Block until all workers are healthy or the timeout is reached."""
    start_time = time.time()
    healthy_workers: Dict[str, bool] = {url: False for url in worker_urls}

    while time.time() - start_time < timeout:
        print("checking healthiness...")
        all_healthy = True
        for url in worker_urls:
            if not healthy_workers[url]:  # Only check workers that aren't healthy yet
                try:
                    response = requests.get(f"{url}/health")
                    if response.status_code == 200:
                        print(f"Worker at {url} is healthy")
                        healthy_workers[url] = True
                    else:
                        all_healthy = False
                except requests.RequestException:
                    all_healthy = False

        if all_healthy:
            print("All workers are healthy!")
            return True

        time.sleep(5)

    # If we get here, we've timed out
    unhealthy_workers = [url for url, healthy in healthy_workers.items() if not healthy]
    print(f"Timeout waiting for workers: {unhealthy_workers}")
    return False


def main():
    """Main function to launch the router and workers."""
    args = parse_args()
    processes = None

    try:
        # Launch all workers concurrently
        processes, worker_urls = launch_workers(args)

        # Block until all workers are healthy
        if not wait_for_healthy_workers(worker_urls):
            raise RuntimeError("Failed to start all workers")

        # Initialize and start the router
        router = Router(
            worker_urls=worker_urls,
            policy=PolicyType.ApproxTree,
            tokenizer_path=args.local_tokenizer_path,
        )
        print("Starting router...")
        router.start()

        # Keep the main process running
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            print("\nShutting down...")

    except Exception as e:
        print(f"Error: {e}")
    finally:
        # Cleanup: kill all worker processes
        if processes:
            for process in processes:
                process.kill()


if __name__ == "__main__":
    main()
from sglang_router import PolicyType, Router

router = Router(
    worker_urls=[
        "http://localhost:30000",
        "http://localhost:30001",
    ],
    policy=PolicyType.ApproxTree,
    tokenizer_path="/shared/public/elr-models/meta-llama/Meta-Llama-3.1-8B-Instruct/07eb05b21d191a58c577b4a45982fe0c049d0693/tokenizer.json",
)
router.start()
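The snippet above pins every argument explicitly. Per the pyo3 constructor signature in `lib.rs` further down, `host` defaults to `127.0.0.1`, `port` to `3001`, and `policy` to `PolicyType.RoundRobin`, so a non-cache-aware router needs no tokenizer. A minimal sketch:

```python
# Minimal sketch using the binding's defaults (host 127.0.0.1, port 3001,
# RoundRobin policy); no tokenizer_path is needed for RoundRobin/Random.
from sglang_router import PolicyType, Router

router = Router(
    worker_urls=["http://localhost:30000", "http://localhost:30001"],
    policy=PolicyType.RoundRobin,
)
router.start()  # blocks, serving on 127.0.0.1:3001
```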
@@ -2,6 +2,11 @@
SGLang router is a standalone module implemented in Rust to achieve data parallelism across SGLang instances.

## Architecture

1. `src/`: Rust implementation of the router
2. `py_src/`: lightweight Python interface on top of the Rust Python binding. This will be published as the `sglang-router` PyPI package

## Installation

WIP. Ideally just
@@ -83,6 +88,23 @@
$ maturin develop
🛠 Installed sglang_router-0.0.0
```

4. Alternatively, if you don't want to create a venv, you can also build the binding as a wheel and install it:

```bash
$ maturin build --interpreter python
...
   Compiling pyo3 v0.22.6
   Compiling pyo3-macros v0.22.6
   Compiling sglang_router v0.0.0 (/home/jobuser/sglang/rust)
    Finished `dev` profile [unoptimized + debuginfo] target(s) in 9.67s
🖨 Copied external shared libraries to package sglang_router.libs directory:
    /usr/lib/libssl.so.1.1.1k
    /usr/lib/libcrypto.so.1.1.1k
📦 Built wheel for CPython 3.10 to <wheel path>
$ pip install <wheel path>
```

## Usage

1. Launch worker instances
...
sglang @ 760552e0
Subproject commit 760552e068edb58d9cd6e68aa1b714c247027d92
// Python Binding
use pyo3::prelude::*;

pub mod router;
mod server;
pub mod tree;

// Python binding of the routing policy enum
#[pyclass(eq)]
#[derive(Clone, PartialEq)]
pub enum PolicyType {
    Random,
    RoundRobin,
    ApproxTree,
}

#[pyclass]
struct Router {
    host: String,
    port: u16,
    worker_urls: Vec<String>,
    policy: PolicyType,
    tokenizer_path: Option<String>,
    cache_threshold: Option<f32>,
}

#[pymethods]
impl Router {
    #[new]
    #[pyo3(signature = (
        worker_urls,
        policy = PolicyType::RoundRobin,
        host = String::from("127.0.0.1"),
        port = 3001,
        tokenizer_path = None,
        cache_threshold = Some(0.50)
    ))]
    fn new(
        worker_urls: Vec<String>,
        policy: PolicyType,
        host: String,
        port: u16,
        tokenizer_path: Option<String>,
        cache_threshold: Option<f32>,
    ) -> PyResult<Self> {
        // Validate required parameters for approx_tree policy
        if matches!(policy, PolicyType::ApproxTree) {
            if tokenizer_path.is_none() {
                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                    "tokenizer_path is required for approx_tree policy",
                ));
            }
        }

        Ok(Router {
            host,
            port,
            worker_urls,
            policy,
            tokenizer_path,
            cache_threshold,
        })
    }

    fn start(&self) -> PyResult<()> {
        let host = self.host.clone();
        let port = self.port;
        let worker_urls = self.worker_urls.clone();

        let policy_config = match &self.policy {
            PolicyType::Random => router::PolicyConfig::RandomConfig,
            PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig,
            PolicyType::ApproxTree => router::PolicyConfig::ApproxTreeConfig {
                tokenizer_path: self
                    .tokenizer_path
                    .clone()
                    .expect("tokenizer_path is required for approx_tree policy"),
                cache_threshold: self
                    .cache_threshold
                    .expect("cache_threshold is required for approx_tree policy"),
            },
        };

        actix_web::rt::System::new().block_on(async move {
            server::startup(host, port, worker_urls, policy_config)
                .await
                .unwrap();
        });
        // ...
    }
}

// python usage: `from sglang_router import Router`
#[pymodule]
fn sglang_router(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PolicyType>()?;
    m.add_class::<Router>()?;
    Ok(())
}
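As a usage note, the `PyValueError` raised in `Router::new` above surfaces in Python as a plain `ValueError`; a small sketch of the failure mode:

```python
# Sketch: the PyValueError raised in Router::new maps to Python's ValueError.
from sglang_router import PolicyType, Router

try:
    Router(worker_urls=["http://localhost:30000"], policy=PolicyType.ApproxTree)
except ValueError as e:
    print(e)  # tokenizer_path is required for approx_tree policy
```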
// src/main.rs
use clap::Parser;
use clap::ValueEnum;

// declare child modules
mod router;
mod server;
mod tree;

use crate::router::PolicyConfig;

#[derive(Debug, Clone, ValueEnum)]
pub enum PolicyType {
    Random,
    RoundRobin,
    ApproxTree,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    #[arg(
        long,
        default_value = "127.0.0.1",
        help = "Host address to bind the server to"
    )]
    host: String,

    #[arg(long, default_value_t = 3001, help = "Port number to listen on")]
    port: u16,

    #[arg(
        long,
        value_delimiter = ',',
        help = "Comma-separated list of worker URLs to distribute requests to"
    )]
    worker_urls: Vec<String>,

    #[arg(
        long,
        default_value_t = PolicyType::RoundRobin,
        value_enum,
        help = "Load balancing policy to use: random, round_robin, or approx_tree"
    )]
    policy: PolicyType,

    #[arg(
        long,
        requires = "policy",
        required_if_eq("policy", "approx_tree"),
        help = "Path to the tokenizer file, required when using approx_tree policy"
    )]
    tokenizer_path: Option<String>,

    #[arg(
        long,
        default_value = "0.50",
        requires = "policy",
        required_if_eq("policy", "approx_tree"),
        help = "Cache threshold (0.0-1.0) for approx_tree routing. Routes to the cached worker if the match rate exceeds the threshold, otherwise routes to the shortest-queue worker"
    )]
    cache_threshold: Option<f32>,
}

impl Args {
    fn get_policy_config(&self) -> PolicyConfig {
        match self.policy {
            PolicyType::Random => PolicyConfig::RandomConfig,
            PolicyType::RoundRobin => PolicyConfig::RoundRobinConfig,
            PolicyType::ApproxTree => PolicyConfig::ApproxTreeConfig {
                tokenizer_path: self
                    .tokenizer_path
                    .clone()
                    .expect("tokenizer_path is required for approx_tree policy"),
                cache_threshold: self
                    .cache_threshold
                    .expect("cache_threshold is required for approx_tree policy"),
            },
        }
    }
}

#[actix_web::main]
async fn main() -> std::io::Result<()> {
    let args = Args::parse();
    let policy_config = args.get_policy_config();
    server::startup(args.host, args.port, args.worker_urls, policy_config).await
}
use crate::tree::RadixTree;
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
use bytes::Bytes;
use futures_util::TryStreamExt;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex};
use tokenizers::tokenizer::Tokenizer;

#[derive(Debug)]
pub enum Router {
    RoundRobin {
        worker_urls: Vec<String>,
        current_index: AtomicUsize,
    },
    Random {
        worker_urls: Vec<String>,
    },
    ApproxTree {
        worker_urls: Vec<String>,
        // TODO: don't lock the whole tree
        url_to_tree: Arc<Mutex<HashMap<String, RadixTree>>>,
        tokenizer: Tokenizer,
        url_to_count: Arc<Mutex<HashMap<String, usize>>>,
        cache_threshold: f32,
    },
}

pub enum PolicyConfig {
    RandomConfig,
    RoundRobinConfig,
    ApproxTreeConfig {
        tokenizer_path: String,
        cache_threshold: f32,
    },
}

fn get_token_ids_from_request(body: &Bytes, tokenizer: &Tokenizer) -> Vec<u32> {
    // 1. convert body to json
    let json = serde_json::from_slice::<serde_json::Value>(body).unwrap();
    // 2. get the text field
    let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
    // 3. tokenize the text field
    let tokens = tokenizer.encode(text, false).unwrap();
    tokens.get_ids().to_vec()
}

impl Router {
    pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Self {
        match policy_config {
            PolicyConfig::RandomConfig => Router::Random { worker_urls },
            PolicyConfig::RoundRobinConfig => Router::RoundRobin {
                worker_urls,
                current_index: std::sync::atomic::AtomicUsize::new(0),
            },
            PolicyConfig::ApproxTreeConfig {
                tokenizer_path,
                cache_threshold,
            } => {
                let mut url_to_tree = HashMap::new();
                let mut url_to_count = HashMap::new();
                for url in &worker_urls {
                    url_to_tree.insert(url.clone(), RadixTree::new());
                    url_to_count.insert(url.clone(), 0);
                }
                Router::ApproxTree {
                    worker_urls,
                    url_to_tree: Arc::new(Mutex::new(url_to_tree)),
                    // TODO: rust ::from_pretrained cannot load from a local file, so use ::from_file instead
                    tokenizer: Tokenizer::from_file(tokenizer_path).unwrap(),
                    url_to_count: Arc::new(Mutex::new(url_to_count)),
                    cache_threshold,
                }
            }
        }
    }

    pub fn get_first(&self) -> Option<String> {
        match self {
            Router::RoundRobin { worker_urls, .. }
            | Router::Random { worker_urls }
            | Router::ApproxTree { worker_urls, .. } => {
                if worker_urls.is_empty() {
                    None
                } else {
                    // ...

        req: HttpRequest,
        body: Bytes,
    ) -> HttpResponse {
        let mut input_ids: Vec<u32> = Vec::new();
        if let Router::ApproxTree { tokenizer, .. } = self {
            input_ids = get_token_ids_from_request(&body, tokenizer);
        }

        let worker_url = match self {
            Router::RoundRobin {
                worker_urls,
                current_index,
            } => {
                let idx = current_index
                    .fetch_update(
                        std::sync::atomic::Ordering::SeqCst,
                        std::sync::atomic::Ordering::SeqCst,
                        |x| Some((x + 1) % worker_urls.len()),
                    )
                    .unwrap();
                worker_urls[idx].clone()
            }
            Router::Random { worker_urls } => {
                worker_urls[rand::random::<usize>() % worker_urls.len()].clone()
            }
            Router::ApproxTree {
                worker_urls,
                url_to_tree,
                url_to_count,
                cache_threshold,
                ..
            } => {
                // TODO: pipeline the locks. Release one earlier.
                let mut max_matched_rate = 0.0;
                let mut max_matched_idx = 0;

                let locked_url_to_tree = url_to_tree.lock().unwrap();

                // 1. Find the worker with the highest prefix-match rate
                for (i, url) in worker_urls.iter().enumerate() {
                    let tree = locked_url_to_tree.get(url).unwrap();
                    let matched = tree.prefix_match(&input_ids[..]).len();
                    let matched_rate = matched as f32 / input_ids.len() as f32;
                    if matched_rate > max_matched_rate {
                        max_matched_rate = matched_rate;
                        max_matched_idx = i;
                    }
                }

                // 2. If the rate is higher than the threshold, select that worker.
                //    If not, select the worker with the shortest queue.
                if max_matched_rate > *cache_threshold {
                    worker_urls[max_matched_idx].clone()
                } else {
                    let locked_url_to_count = url_to_count.lock().unwrap();
                    let mut min_count = std::usize::MAX;
                    let mut min_count_id = 0;
                    for (i, url) in worker_urls.iter().enumerate() {
                        let count = locked_url_to_count.get(url).unwrap();
                        if *count < min_count {
                            min_count = *count;
                            min_count_id = i;
                        }
                    }
                    worker_urls[min_count_id].clone()
                }
            }
        };

        if let Router::ApproxTree {
            url_to_tree,
            url_to_count,
            ..
        } = self
        {
            // Insert input_ids into the selected worker's tree and bump its queue count
            let mut locked_url_to_tree = url_to_tree.lock().unwrap();
            let selected_tree = locked_url_to_tree.get_mut(&worker_url).unwrap();
            selected_tree.insert(&input_ids[..]);

            let mut locked_url_to_count = url_to_count.lock().unwrap();
            let count = locked_url_to_count.get_mut(&worker_url).unwrap();
            *count += 1;
        }

        // Check if client requested streaming
        let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
            .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
            // ...
            .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

        if !is_stream {
            // TODO: do the correction on the tree based on the cached input_ids
            if let Router::ApproxTree { url_to_count, .. } = self {
                let mut locked_url_to_count = url_to_count.lock().unwrap();
                let count = locked_url_to_count.get_mut(&worker_url).unwrap();
                *count -= 1;
            }
            match res.bytes().await {
                Ok(body) => HttpResponse::build(status).body(body.to_vec()),
                Err(_) => HttpResponse::InternalServerError().finish(),
            }
        } else {
            // TODO: do the correction on the tree based on the cached input_ids. Streaming might be trickier to handle
            HttpResponse::build(status)
                .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                .streaming(res.bytes_stream().map_err(|_| {
                    // ...
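`get_token_ids_from_request` pulls the `text` field out of the JSON body and tokenizes it with the Hugging Face `tokenizers` library the router embeds. A rough Python counterpart, assuming the `tokenizers` package and a local `tokenizer.json` (both assumptions, for illustration):

```python
# Rough Python counterpart of get_token_ids_from_request above; assumes the
# `tokenizers` package is installed and a tokenizer.json file is available.
import json
from tokenizers import Tokenizer

def get_token_ids_from_request(body: bytes, tokenizer: Tokenizer) -> list:
    payload = json.loads(body)                                    # 1. parse the JSON body
    text = payload.get("text", "")                                # 2. read the "text" field
    encoding = tokenizer.encode(text, add_special_tokens=False)   # 3. tokenize it
    return encoding.ids

# Usage sketch:
# tokenizer = Tokenizer.from_file("tokenizer.json")
# get_token_ids_from_request(b'{"text": "hello"}', tokenizer)
```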
use crate::router::PolicyConfig;
use crate::router::Router;
use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
use bytes::Bytes;
// ...

pub struct AppState {
    // ...
}

impl AppState {
    pub fn new(
        worker_urls: Vec<String>,
        client: reqwest::Client,
        policy_config: PolicyConfig,
    ) -> Self {
        // Create router based on policy
        let router = Router::new(worker_urls, policy_config);
        Self { router, client }
    }
}
// ...

#[get("/v1/models")]
async fn v1_model(data: web::Data<AppState>) -> impl Responder {
    let worker_url = match data.router.get_first() {
        Some(url) => url,
        None => return HttpResponse::InternalServerError().finish(),
    };
    // ...
}

async fn get_model_info(data: web::Data<AppState>) -> impl Responder {
    // ...
    forward_request(&data.client, worker_url, "/get_model_info".to_string()).await
}

#[post("/generate")]
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
    data.router.dispatch(&data.client, req, body).await
}
// ...

pub async fn startup(
    host: String,
    port: u16,
    worker_urls: Vec<String>,
    policy_config: PolicyConfig,
) -> std::io::Result<()> {
    println!("Starting server on {}:{}", host, port);
    println!("Worker URLs: {:?}", worker_urls);
    // ...
        .expect("Failed to create HTTP client");

    // Store both worker_urls and client in AppState
    let app_state = web::Data::new(AppState::new(worker_urls, client, policy_config));

    HttpServer::new(move || {
        App::new()
            // ...
use std::collections::HashMap;
use std::mem;

#[derive(Debug)]
pub struct Node {
    pub children: HashMap<u32, Node>, // the key is the first id of the child because each child must have a unique first id
    pub ids: Vec<u32>,
    pub count: u32,
}

#[derive(Debug)]
pub struct RadixTree {
    pub root: Node,
}

fn common_prefix_len(a: &[u32], b: &[u32]) -> usize {
    let mut i = 0;
    while i < a.len() && i < b.len() && a[i] == b[i] {
        i += 1;
    }
    // ...
}

impl RadixTree {
    // ...

    pub fn insert(&mut self, input_ids: &[u32]) {
        let mut curr = &mut self.root;
        curr.count += 1;
        // ...
    }

    pub fn prefix_match<'a>(&self, input_ids: &'a [u32]) -> &'a [u32] {
        let mut curr = &self.root;
        let mut curr_idx = 0;
        // ...
        &input_ids[..curr_idx]
    }

    pub fn delete(&mut self, input_ids: &[u32]) {
        let mut curr = &mut self.root;
        curr.count -= 1;
        // ...
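To make the tree logic concrete, here is a simplified, self-contained Python sketch of the same radix tree (token-id keyed, with edge splitting); `delete` and the count-correction TODOs are elided. Names mirror the Rust version, but this is illustrative rather than the shipped implementation. The asserts at the end mirror the unit tests that follow.

```python
# Simplified Python sketch of the RadixTree above (illustrative only).
def common_prefix_len(a, b):
    i = 0
    while i < len(a) and i < len(b) and a[i] == b[i]:
        i += 1
    return i


class Node:
    def __init__(self, ids=None):
        self.children = {}     # first token id of a child -> Node
        self.ids = ids or []   # the edge label stored on this node
        self.count = 0         # sequences inserted through this node


class RadixTree:
    def __init__(self):
        self.root = Node()

    def insert(self, input_ids):
        curr = self.root
        curr.count += 1
        i = 0
        while i < len(input_ids):
            key = input_ids[i]
            if key not in curr.children:
                child = Node(list(input_ids[i:]))
                child.count = 1
                curr.children[key] = child
                return
            child = curr.children[key]
            k = common_prefix_len(child.ids, input_ids[i:])
            if k < len(child.ids):
                # Split the edge: keep the shared prefix, push the rest down
                rest = Node(child.ids[k:])
                rest.children, rest.count = child.children, child.count
                child.ids = child.ids[:k]
                child.children = {rest.ids[0]: rest}
            child.count += 1
            curr = child
            i += k

    def prefix_match(self, input_ids):
        curr, i = self.root, 0
        while i < len(input_ids) and input_ids[i] in curr.children:
            child = curr.children[input_ids[i]]
            k = common_prefix_len(child.ids, input_ids[i:])
            i += k
            if k < len(child.ids):
                break
            curr = child
        return input_ids[:i]


# Mirrors the unit tests below:
tree = RadixTree()
tree.insert([1, 2, 3, 4])
assert tree.prefix_match([1, 2, 3]) == [1, 2, 3]
assert tree.prefix_match([5, 6, 7]) == []
```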
// ...

fn test_prefix_match_no_match() {
    let mut tree = RadixTree::new();
    tree.insert(&[1, 2, 3, 4]);
    let empty_slices: &[u32] = &[];
    assert_eq!(tree.prefix_match(&[5, 6, 7]), empty_slices);
}

// ...

#[test]
fn test_empty_input() {
    let mut tree = RadixTree::new();
    let empty_slice: &[u32] = &[];
    tree.insert(empty_slice);
    assert_eq!(tree.prefix_match(empty_slice), empty_slice);
    tree.delete(empty_slice); // Should not panic
}
// ...