Commit e97493eb authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat: sglang backend for tio (#271)

- Setup venv

```
uv venv
source .venv/bin/activate
uv pip install pip
uv pip install sgl-kernel --force-reinstall --no-deps
uv pip install "sglang[all]==0.4.2" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/
```

- Build: `cargo build --release --features sglang`

- Run single node (make sure you're in the venv): `./tio out=sglang ~/llm_models/my_model`

- Run Deepseek multi-gpu / multi-node:

Node 1:
```
tio in=http out=sglang --model-path ~/llm_models/DeepSeek-R1-Distill-Llama-70B/ --tensor-parallel-size 8 --num-nodes 2 --node-rank 0 --dist-init-addr 10.217.98.122:9876
```

Node 2:
```
tio in=none out=sglang --model-path ~/llm_models/DeepSeek-R1-Distill-Llama-70B/ --tensor-parallel-size 8 --num-nodes 2 --node-rank 1 --dist-init-addr 10.217.98.122:9876
```
parent c70de37f
......@@ -1486,6 +1486,12 @@ dependencies = [
"web-time",
]
[[package]]
name = "indoc"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5"
[[package]]
name = "inlinable_string"
version = "0.1.15"
......@@ -1501,6 +1507,12 @@ dependencies = [
"libc",
]
[[package]]
name = "iter-read"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071ed4cc1afd86650602c7b11aa2e1ce30762a1c27193201cb5cee9c6ebb1294"
[[package]]
name = "itertools"
version = "0.11.0"
......@@ -1693,6 +1705,15 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b"
[[package]]
name = "memoffset"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
dependencies = [
"autocfg",
]
[[package]]
name = "mime"
version = "0.3.17"
......@@ -2339,6 +2360,69 @@ version = "2.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94"
[[package]]
name = "pyo3"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872"
dependencies = [
"cfg-if 1.0.0",
"indoc",
"libc",
"memoffset",
"once_cell",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"unindent",
]
[[package]]
name = "pyo3-build-config"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb"
dependencies = [
"once_cell",
"target-lexicon",
]
[[package]]
name = "pyo3-ffi"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d"
dependencies = [
"libc",
"pyo3-build-config",
]
[[package]]
name = "pyo3-macros"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.98",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028"
dependencies = [
"heck",
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.98",
]
[[package]]
name = "quote"
version = "1.0.38"
......@@ -2677,6 +2761,19 @@ dependencies = [
"serde_derive",
]
[[package]]
name = "serde-pickle"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b641fdc8bcf2781ee78b30c599700d64ad4f412976143e4c5d0b9df906bb4843"
dependencies = [
"byteorder",
"iter-read",
"num-bigint",
"num-traits",
"serde",
]
[[package]]
name = "serde_derive"
version = "1.0.218"
......@@ -3425,12 +3522,15 @@ dependencies = [
"galil-seiferas",
"indexmap 2.7.1",
"itertools 0.14.0",
"libc",
"minijinja",
"minijinja-contrib",
"prometheus",
"pyo3",
"regex",
"semver",
"serde",
"serde-pickle",
"serde_json",
"strum",
"thiserror 2.0.11",
......@@ -3569,6 +3669,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "unindent"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
[[package]]
name = "untrusted"
version = "0.9.0"
......
......@@ -1595,6 +1595,12 @@ dependencies = [
"web-time",
]
[[package]]
name = "indoc"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5"
[[package]]
name = "inlinable_string"
version = "0.1.15"
......@@ -1616,6 +1622,12 @@ version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "iter-read"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071ed4cc1afd86650602c7b11aa2e1ce30762a1c27193201cb5cee9c6ebb1294"
[[package]]
name = "itertools"
version = "0.11.0"
......@@ -1822,6 +1834,15 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b"
[[package]]
name = "memoffset"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
dependencies = [
"autocfg",
]
[[package]]
name = "mime"
version = "0.3.17"
......@@ -2479,6 +2500,69 @@ version = "2.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94"
[[package]]
name = "pyo3"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872"
dependencies = [
"cfg-if 1.0.0",
"indoc",
"libc",
"memoffset",
"once_cell",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"unindent",
]
[[package]]
name = "pyo3-build-config"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb"
dependencies = [
"once_cell",
"target-lexicon",
]
[[package]]
name = "pyo3-ffi"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d"
dependencies = [
"libc",
"pyo3-build-config",
]
[[package]]
name = "pyo3-macros"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.98",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028"
dependencies = [
"heck",
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.98",
]
[[package]]
name = "quote"
version = "1.0.38"
......@@ -2817,6 +2901,19 @@ dependencies = [
"serde_derive",
]
[[package]]
name = "serde-pickle"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b641fdc8bcf2781ee78b30c599700d64ad4f412976143e4c5d0b9df906bb4843"
dependencies = [
"byteorder",
"iter-read",
"num-bigint",
"num-traits",
"serde",
]
[[package]]
name = "serde_derive"
version = "1.0.218"
......@@ -3597,12 +3694,15 @@ dependencies = [
"galil-seiferas",
"indexmap 2.7.1",
"itertools 0.14.0",
"libc",
"minijinja",
"minijinja-contrib",
"prometheus",
"pyo3",
"regex",
"semver",
"serde",
"serde-pickle",
"serde_json",
"strum",
"thiserror 2.0.11",
......@@ -3741,6 +3841,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "unindent"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
[[package]]
name = "untrusted"
version = "0.9.0"
......
......@@ -2309,6 +2309,12 @@ dependencies = [
"web-time",
]
[[package]]
name = "indoc"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5"
[[package]]
name = "inlinable_string"
version = "0.1.15"
......@@ -2345,6 +2351,12 @@ version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "iter-read"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071ed4cc1afd86650602c7b11aa2e1ce30762a1c27193201cb5cee9c6ebb1294"
[[package]]
name = "itertools"
version = "0.11.0"
......@@ -2616,6 +2628,15 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b"
[[package]]
name = "memoffset"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
dependencies = [
"autocfg",
]
[[package]]
name = "metal"
version = "0.27.0"
......@@ -3632,6 +3653,69 @@ dependencies = [
"reborrow",
]
[[package]]
name = "pyo3"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872"
dependencies = [
"cfg-if 1.0.0",
"indoc",
"libc",
"memoffset",
"once_cell",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"unindent",
]
[[package]]
name = "pyo3-build-config"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb"
dependencies = [
"once_cell",
"target-lexicon",
]
[[package]]
name = "pyo3-ffi"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d"
dependencies = [
"libc",
"pyo3-build-config",
]
[[package]]
name = "pyo3-macros"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.98",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.98",
]
[[package]]
name = "qoi"
version = "0.4.1"
......@@ -4171,6 +4255,19 @@ dependencies = [
"serde_derive",
]
[[package]]
name = "serde-pickle"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b641fdc8bcf2781ee78b30c599700d64ad4f412976143e4c5d0b9df906bb4843"
dependencies = [
"byteorder",
"iter-read",
"num-bigint",
"num-traits",
"serde",
]
[[package]]
name = "serde_derive"
version = "1.0.218"
......@@ -5181,6 +5278,7 @@ dependencies = [
"anyhow",
"async-stream",
"async-trait",
"async_zmq",
"axum 0.8.1",
"blake3",
"bs62",
......@@ -5193,13 +5291,16 @@ dependencies = [
"galil-seiferas",
"indexmap 2.7.1",
"itertools 0.14.0",
"libc",
"minijinja",
"minijinja-contrib",
"mistralrs",
"prometheus",
"pyo3",
"regex",
"semver",
"serde",
"serde-pickle",
"serde_json",
"strum 0.27.1",
"thiserror 2.0.11",
......@@ -5338,6 +5439,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "unindent"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
[[package]]
name = "unsafe-libyaml"
version = "0.2.11"
......
......@@ -22,6 +22,7 @@ homepage = "https://github.com/triton-inference-server/triton_distributed"
[features]
mistralrs = ["triton-distributed-llm/mistralrs"]
sglang = ["triton-distributed-llm/sglang", "dep:netlink-packet-route", "dep:rtnetlink"]
cuda = ["triton-distributed-llm/cuda"]
metal = ["triton-distributed-llm/metal"]
......
......@@ -62,3 +62,12 @@ The `ns/backend/mistralrs` are purely symbolic, pick anything as long as it has
Run `tio --help` for more options.
## sglang
```
uv venv
source .venv/bin/activate
uv pip install pip
uv pip install sgl-kernel --force-reinstall --no-deps
uv pip install "sglang[all]==0.4.2" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/
```
......@@ -14,6 +14,7 @@
// limitations under the License.
use std::path::PathBuf;
use std::str::FromStr;
use triton_distributed_llm::{
backend::ExecutionContext,
......@@ -29,6 +30,8 @@ use triton_distributed_llm::{
use triton_distributed_runtime::{component::Client, protocols::Endpoint, DistributedRuntime};
mod input;
#[cfg(feature = "sglang")]
mod net;
mod opt;
mod output;
pub use opt::{Input, Output};
......@@ -58,6 +61,53 @@ pub struct Flags {
/// The name of the model we are serving
#[arg(long)]
pub model_name: Option<String>,
/// sglang only
///
/// How many GPUs to use at once, total across all nodes.
/// This must be divisible by num_nodes, and each node must use the same number of GPUs.
#[arg(long, default_value = "1", value_parser = clap::value_parser!(u32).range(1..256))]
pub tensor_parallel_size: u32,
/// sglang only
///
/// Use GPUs from this ID upwards.
/// If your machine has four GPUs but the first two (0 and 1) are in use,
/// pass --base-gpu-id 2 to use the third GPU (and up, if tensor_parallel_size > 1)
#[arg(long, default_value = "0", value_parser = clap::value_parser!(u32).range(0..256))]
pub base_gpu_id: u32,
/// sglang only
///
/// How many nodes/hosts to use
#[arg(long, default_value = "1", value_parser = clap::value_parser!(u32).range(1..256))]
pub num_nodes: u32,
/// sglang only
///
/// This node's unique ID, running from 0 to num_nodes - 1.
#[arg(long, default_value = "0", value_parser = clap::value_parser!(u32).range(0..255))]
pub node_rank: u32,
/// sglang only
///
/// The Torch Distributed init method address, in format <host>:<port>.
/// It becomes "tcp://<host>:<port>" when given to torch.distributed.init_process_group.
/// This expects to use the nccl backend (transparently to us here).
/// All nodes must use the same dist_init_addr, which is node_rank == 0's address.
#[arg(long)]
pub dist_init_addr: Option<String>,
/// Internal use only.
/// Start the sglang Python sub-process.
/// The params in the tuple are:
/// - the fd of the write end of a pipe where sglang will signal that it's ready.
/// - the node rank (0 for first host, 1 for second host, etc)
/// - the worker's rank (globally unique)
/// - the GPU to use (locally unique)
#[arg(long)]
#[clap(hide = true, value_parser = parse_sglang_flags)]
pub internal_sglang_process: Option<SgLangFlags>,
}
pub enum EngineConfig {
......@@ -79,11 +129,36 @@ pub enum EngineConfig {
},
}
/// Values passed to the hidden `--internal-sglang-process` flag as a
/// comma-separated triple: `<pipe_fd>,<tp_rank>,<gpu_id>`.
#[derive(Debug, Clone, Copy)]
pub struct SgLangFlags {
    pub pipe_fd: u32,
    pub tp_rank: u32,
    pub gpu_id: u32,
}

/// Parse the `--internal-sglang-process` value: exactly three
/// comma-separated unsigned integers (pipe fd, tensor-parallel rank, GPU id).
///
/// Returns a human-readable error string if any piece fails to parse as a
/// `u32`, or if the count is not exactly three.
fn parse_sglang_flags(s: &str) -> Result<SgLangFlags, String> {
    // Parse every comma-separated piece first; any bad number short-circuits.
    let nums = s
        .split(',')
        .map(|part| part.parse::<u32>())
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| e.to_string())?;
    // Then enforce the arity with a slice pattern.
    match nums.as_slice() {
        [pipe_fd, tp_rank, gpu_id] => Ok(SgLangFlags {
            pipe_fd: *pipe_fd,
            tp_rank: *tp_rank,
            gpu_id: *gpu_id,
        }),
        _ => Err("Need exactly 3 numbers".into()),
    }
}
pub async fn run(
runtime: triton_distributed_runtime::Runtime,
in_opt: Input,
out_opt: Output,
flags: Flags,
#[allow(unused_variables)] zmq_socket_prefix: Option<String>,
) -> anyhow::Result<()> {
let cancel_token = runtime.primary_token();
......@@ -109,6 +184,9 @@ pub async fn run(
Some(_) | None => None,
};
#[cfg(feature = "sglang")]
let mut extra = None; // sglang sub-process
// Create the engine matching `out`
let engine_config = match out_opt {
Output::EchoFull => {
......@@ -174,6 +252,49 @@ pub async fn run(
.await?,
}
}
#[cfg(feature = "sglang")]
Output::SgLang => {
use triton_distributed_llm::engines::sglang;
let Some(model_path) = model_path else {
anyhow::bail!("out=sglang requires flag --model-path=<full-path-to-model-dir>");
};
if !model_path.is_dir() {
anyhow::bail!("`--model-path should point at a HuggingFace repo checkout");
}
// Safety: Earlier we build maybe_card from model_path, which we checked right above
let card = maybe_card.clone().unwrap();
let Some(sock_prefix) = zmq_socket_prefix else {
anyhow::bail!("sglang requires zmq_socket_prefix");
};
let node_conf = sglang::MultiNodeConfig {
num_nodes: flags.num_nodes,
node_rank: flags.node_rank,
dist_init_addr: flags.dist_init_addr,
};
if node_conf.num_nodes > 1 {
if let Ok(Some(if_name)) = net::get_primary_interface().await {
tracing::info!("If you see 'gloo' errors from sglang try setting these environment variables:");
tracing::info!("export GLOO_SOCKET_IFNAME={if_name}");
tracing::info!("export NCCL_SOCKET_IFNAME={if_name}");
}
}
let (engine, sglang_process) = sglang::make_engine(
cancel_token.clone(),
&model_path,
&sock_prefix,
node_conf,
flags.tensor_parallel_size,
flags.base_gpu_id,
)
.await?;
extra = Some(sglang_process);
EngineConfig::StaticCore {
service_name: card.service_name.clone(),
engine,
card: Box::new(card),
}
}
};
match in_opt {
......@@ -186,6 +307,19 @@ pub async fn run(
Input::Endpoint(path) => {
crate::input::endpoint::run(runtime.clone(), path, engine_config).await?;
}
Input::None => {
// Multi-node setup. The engine sub-process has been started and is talking
// to its node_rank 0 controller. We do nothing.
// TODO: Acquire an etcd lease, we are running
cancel_token.cancelled().await;
}
}
#[cfg(feature = "sglang")]
// Allow engines to ask main thread to wait on an extra future.
// sglang uses this to shut down sub-process
if let Some(extra) = extra {
extra.await?;
}
Ok(())
......
......@@ -39,11 +39,52 @@ const DEFAULT_OUT: Output = Output::MistralRs;
#[cfg(not(feature = "mistralrs"))]
const DEFAULT_OUT: Output = Output::EchoFull;
const USAGE: &str = "USAGE: tio in=[http|text] out=[mistralrs|echo_full] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>]";
const ZMQ_SOCKET_PREFIX: &str = "tio";
const USAGE: &str = "USAGE: tio in=[http|text|tdr://<path>|none] out=[mistralrs|sglang|echo_full|echo_core] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--tensor-parallel-size=1] [--num-nodes=1] [--node-rank=0] [--dist-init-addr=127.0.0.1:9876] [--base-gpu-id=0]";
fn main() -> anyhow::Result<()> {
logging::init();
// Call sub-processes before starting the Runtime machinery
// For anything except sub-process starting try_parse_from will error.
if let Ok(flags) = tio::Flags::try_parse_from(env::args()) {
#[allow(unused_variables)]
if let Some(sglang_flags) = flags.internal_sglang_process {
let Some(model_path) = flags.model_path_flag.as_ref() else {
anyhow::bail!("sglang subprocess requires --model-path");
};
if !model_path.is_dir() {
anyhow::bail!("sglang subprocess requires model path to be a directory containing the safetensors files");
}
if cfg!(feature = "sglang") {
#[cfg(feature = "sglang")]
{
use triton_distributed_llm::engines::sglang;
let gpu_config = sglang::MultiGPUConfig {
tp_size: flags.tensor_parallel_size,
tp_rank: sglang_flags.tp_rank,
gpu_id: sglang_flags.gpu_id,
};
let node_config = sglang::MultiNodeConfig {
num_nodes: flags.num_nodes,
node_rank: flags.node_rank,
dist_init_addr: flags.dist_init_addr,
};
return sglang::run_subprocess(
ZMQ_SOCKET_PREFIX,
model_path,
sglang_flags.pipe_fd as std::os::fd::RawFd,
node_config,
gpu_config,
);
}
} else {
panic!("Rebuild with --features=sglang");
}
}
}
// max_worker_threads and max_blocking_threads from env vars or config file.
let rt_config = triton_distributed_runtime::RuntimeConfig::from_settings()?;
......@@ -103,5 +144,12 @@ async fn tio_wrapper(runtime: triton_distributed_runtime::Runtime) -> anyhow::Re
.chain(env::args().skip(non_flag_params)),
)?;
tio::run(runtime, in_opt, out_opt, flags).await
tio::run(
runtime,
in_opt,
out_opt,
flags,
Some(ZMQ_SOCKET_PREFIX.to_string()),
)
.await
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use futures_util::TryStreamExt;
use netlink_packet_route::address::AddressAttribute;
use netlink_packet_route::link::LinkLayerType;
use netlink_packet_route::link::State as LinkState;
use netlink_packet_route::link::{LinkAttribute, LinkMessage};
use netlink_packet_route::AddressFamily;
use std::collections::HashSet;
use std::collections::VecDeque;
use std::{collections::HashMap, error::Error};
/// Pick an interface likely to carry cluster traffic: an ethernet interface
/// that is up, has a carrier, has an IPv4 address, and whose name starts
/// with "e" (eth0, enp5s0, eno1, ...).
///
/// Returns `Ok(None)` when no interface qualifies. If several interfaces
/// qualify, which one is returned is arbitrary (HashMap iteration order) —
/// this matches the previous collect-then-pop behavior.
pub async fn get_primary_interface() -> Result<Option<String>, LinkDataError> {
    // Lazily take the first qualifying interface instead of collecting every
    // candidate into a queue only to pop a single entry.
    Ok(get_ipv4_interface_links()
        .await?
        .into_iter()
        .find(|(name, data)| {
            data.is_ethernet() && data.link_is_up() && data.has_carrier() && name.starts_with('e')
        })
        .map(|(name, _)| name))
}
#[derive(Clone, Debug)]
// Link-layer information about one network interface, extracted from a
// netlink `LinkMessage`. `state` is an Option because the netlink protocol
// allows the attribute to be absent (even though we have no reason to
// believe it would ever actually be missing).
struct InterfaceLinkData {
    // Hardware type of the link (ethernet, loopback, ...).
    link_type: LinkLayerType,
    // Operational state reported by the kernel, if the attribute was present.
    state: Option<LinkState>,
    // True when the kernel reported a carrier for this link.
    has_carrier: bool,
}
impl InterfaceLinkData {
    /// True when the kernel reported an operational state of `Up`.
    /// A missing state counts as "not up".
    pub fn link_is_up(&self) -> bool {
        matches!(self.state, Some(LinkState::Up))
    }

    /// True for ethernet-type links.
    pub fn is_ethernet(&self) -> bool {
        matches!(self.link_type, LinkLayerType::Ether)
    }

    /// True when a carrier was reported for this link.
    pub fn has_carrier(&self) -> bool {
        self.has_carrier
    }
}
impl From<LinkMessage> for InterfaceLinkData {
    /// Build `InterfaceLinkData` from a raw netlink link message in a single
    /// pass over its attributes.
    fn from(link_message: LinkMessage) -> Self {
        // Operational state, taken from the first OperState attribute seen.
        let mut state = None;
        // Carrier defaults to false; any explicit Carrier(1) flips it on.
        let mut has_carrier = false;
        for attribute in &link_message.attributes {
            match attribute {
                LinkAttribute::OperState(s) if state.is_none() => state = Some(*s),
                LinkAttribute::Carrier(1) => has_carrier = true,
                _ => {}
            }
        }
        InterfaceLinkData {
            link_type: link_message.header.link_layer_type,
            state,
            has_carrier,
        }
    }
}
#[derive(Debug)]
// Error produced while fetching interface link data over netlink.
// `interface` names the interface the failure relates to, when known; it is
// included in the Display output.
pub struct LinkDataError {
    // Which category of netlink failure occurred.
    kind: LinkDataErrorKind,
    // Interface name to mention in the error message, if known.
    interface: Option<String>,
}
impl LinkDataError {
    /// Wrap a failure to open the netlink connection.
    fn connection(connection_error: std::io::Error) -> Self {
        Self {
            kind: LinkDataErrorKind::Connection(connection_error),
            interface: None,
        }
    }

    /// Wrap a failure that occurred while exchanging netlink messages.
    fn communication(communication_error: rtnetlink::Error) -> Self {
        Self {
            kind: LinkDataErrorKind::Communication(communication_error),
            interface: None,
        }
    }
}
impl std::fmt::Display for LinkDataError {
    /// Render the error, appending the interface name when we know which
    /// interface the failure relates to.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.interface.as_deref() {
            Some(interface) => write!(f, "could not get interface link data for {interface}"),
            None => write!(f, "could not get interface link data"),
        }
    }
}
impl Error for LinkDataError {
    /// Expose the underlying I/O or netlink error as this error's source,
    /// so callers can walk the error chain.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match &self.kind {
            LinkDataErrorKind::Connection(e) => Some(e),
            LinkDataErrorKind::Communication(e) => Some(e),
        }
    }
}
#[derive(Debug)]
// The underlying failure category for LinkDataError.
pub enum LinkDataErrorKind {
    // Failed to open the netlink connection (from rtnetlink::new_connection).
    Connection(std::io::Error),
    // Failure while exchanging netlink messages (address/link queries).
    Communication(rtnetlink::Error),
}
// Retrieve the link data (state, MTU, etc.) for all interfaces, and return
// them as a HashMap keyed by interface name. This is roughly equivalent to `ip
// link show` since we're using the same netlink interface under the hood as
// that command.
// Retrieve link data for every interface that has an IPv4 address, keyed by
// interface name. Two netlink queries: one for addresses (to learn which
// interfaces have IPv4), one for links (to get state/carrier/type).
async fn get_ipv4_interface_links() -> Result<HashMap<String, InterfaceLinkData>, LinkDataError> {
    let (netlink_connection, rtnetlink_handle, _receiver) =
        rtnetlink::new_connection().map_err(LinkDataError::connection)?;
    // We have to spawn off the netlink connection because of the architecture
    // of `netlink_proto::Connection`, which runs in the background and owns
    // the socket. We communicate with it via channel messages, and it will exit
    // when both `rtnetlink_handle` and `_receiver` go out of scope.
    tokio::spawn(netlink_connection);
    // First pass: collect the names (address Labels) of every interface that
    // has at least one IPv4 (AF_INET) address.
    let address_handle = rtnetlink_handle.address().get().execute();
    let ipv4s: HashSet<String> = address_handle
        .try_filter_map(|addr_message| async move {
            if matches!(addr_message.header.family, AddressFamily::Inet) {
                // The interface name rides along in the Label attribute.
                Ok(addr_message
                    .attributes
                    .into_iter()
                    .find(|attr| matches!(attr, AddressAttribute::Label(_)))
                    .and_then(|x| match x {
                        AddressAttribute::Label(label) => Some(label),
                        _ => None,
                    }))
            } else {
                // Not IPv4 — skip this address record.
                Ok(None)
            }
        })
        .try_collect()
        .await
        .map_err(LinkDataError::communication)?;
    // Second pass: fetch link data for all interfaces and keep only those
    // whose name appeared in the IPv4 set above.
    let link_handle = rtnetlink_handle.link().get().execute();
    link_handle
        .try_filter_map(|link_message| async {
            let maybe_interface_data = match extract_interface_name(&link_message) {
                Some(interface_name) => {
                    if ipv4s.contains(&interface_name) {
                        Some((interface_name, InterfaceLinkData::from(link_message)))
                    } else {
                        None
                    }
                }
                None => {
                    // The protocol allows a link record with no IfName
                    // attribute; log it and move on.
                    let idx = link_message.header.index;
                    eprintln!(
                        "Network interface with index {idx} doesn't have a name (no IfName attribute)"
                    );
                    None
                }
            };
            Ok(maybe_interface_data)
        })
        .try_collect()
        .await
        .map_err(LinkDataError::communication)
}
/// Return the interface's name from its `IfName` attribute, or `None` if the
/// link message carries no such attribute.
fn extract_interface_name(link_message: &LinkMessage) -> Option<String> {
    for attribute in &link_message.attributes {
        if let LinkAttribute::IfName(name) = attribute {
            return Some(name.clone());
        }
    }
    None
}
......@@ -26,6 +26,11 @@ pub enum Input {
/// Pull requests from a namespace/component/endpoint path.
Endpoint(String),
/// Start the engine but don't provide any way to talk to it.
/// For multi-node sglang, where the engine connects directly
/// to the co-ordinator via torch distributed / nccl.
None,
}
impl TryFrom<&str> for Input {
......@@ -35,6 +40,7 @@ impl TryFrom<&str> for Input {
match s {
"http" => Ok(Input::Http),
"text" => Ok(Input::Text),
"none" => Ok(Input::None),
endpoint_path if endpoint_path.starts_with(ENDPOINT_SCHEME) => {
let path = endpoint_path.strip_prefix(ENDPOINT_SCHEME).unwrap();
Ok(Input::Endpoint(path.to_string()))
......@@ -50,6 +56,7 @@ impl fmt::Display for Input {
Input::Http => "http",
Input::Text => "text",
Input::Endpoint(path) => path,
Input::None => "none",
};
write!(f, "{s}")
}
......@@ -68,6 +75,10 @@ pub enum Output {
#[cfg(feature = "mistralrs")]
/// Run inference on a model in a GGUF file using mistralrs w/ candle
MistralRs,
#[cfg(feature = "sglang")]
/// Run inference using sglang
SgLang,
}
impl TryFrom<&str> for Output {
......@@ -78,6 +89,9 @@ impl TryFrom<&str> for Output {
#[cfg(feature = "mistralrs")]
"mistralrs" => Ok(Output::MistralRs),
#[cfg(feature = "sglang")]
"sglang" => Ok(Output::SgLang),
"echo_full" => Ok(Output::EchoFull),
"echo_core" => Ok(Output::EchoCore),
......@@ -97,6 +111,9 @@ impl fmt::Display for Output {
#[cfg(feature = "mistralrs")]
Output::MistralRs => "mistralrs",
#[cfg(feature = "sglang")]
Output::SgLang => "sglang",
Output::EchoFull => "echo_full",
Output::EchoCore => "echo_core",
......
......@@ -1581,6 +1581,12 @@ dependencies = [
"web-time",
]
[[package]]
name = "indoc"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5"
[[package]]
name = "inlinable_string"
version = "0.1.15"
......@@ -1602,6 +1608,12 @@ version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "iter-read"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071ed4cc1afd86650602c7b11aa2e1ce30762a1c27193201cb5cee9c6ebb1294"
[[package]]
name = "itertools"
version = "0.11.0"
......@@ -1815,6 +1827,15 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b"
[[package]]
name = "memoffset"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
dependencies = [
"autocfg",
]
[[package]]
name = "mime"
version = "0.3.17"
......@@ -2461,6 +2482,69 @@ version = "2.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94"
[[package]]
name = "pyo3"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872"
dependencies = [
"cfg-if 1.0.0",
"indoc",
"libc",
"memoffset",
"once_cell",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"unindent",
]
[[package]]
name = "pyo3-build-config"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb"
dependencies = [
"once_cell",
"target-lexicon",
]
[[package]]
name = "pyo3-ffi"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d"
dependencies = [
"libc",
"pyo3-build-config",
]
[[package]]
name = "pyo3-macros"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.96",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.96",
]
[[package]]
name = "quote"
version = "1.0.38"
......@@ -2800,6 +2884,19 @@ dependencies = [
"serde_derive",
]
[[package]]
name = "serde-pickle"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b641fdc8bcf2781ee78b30c599700d64ad4f412976143e4c5d0b9df906bb4843"
dependencies = [
"byteorder",
"iter-read",
"num-bigint",
"num-traits",
"serde",
]
[[package]]
name = "serde_derive"
version = "1.0.217"
......@@ -3554,12 +3651,15 @@ dependencies = [
"galil-seiferas",
"indexmap 2.7.0",
"itertools 0.14.0",
"libc",
"minijinja",
"minijinja-contrib",
"prometheus",
"pyo3",
"regex",
"semver",
"serde",
"serde-pickle",
"serde_json",
"strum",
"thiserror 2.0.11",
......@@ -3698,6 +3798,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "unindent"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
[[package]]
name = "untrusted"
version = "0.9.0"
......
......@@ -1634,6 +1634,12 @@ version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "iter-read"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071ed4cc1afd86650602c7b11aa2e1ce30762a1c27193201cb5cee9c6ebb1294"
[[package]]
name = "itertools"
version = "0.11.0"
......@@ -2927,6 +2933,19 @@ dependencies = [
"serde_derive",
]
[[package]]
name = "serde-pickle"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b641fdc8bcf2781ee78b30c599700d64ad4f412976143e4c5d0b9df906bb4843"
dependencies = [
"byteorder",
"iter-read",
"num-bigint",
"num-traits",
"serde",
]
[[package]]
name = "serde_derive"
version = "1.0.217"
......@@ -3681,12 +3700,15 @@ dependencies = [
"galil-seiferas",
"indexmap 2.7.1",
"itertools 0.14.0",
"libc",
"minijinja",
"minijinja-contrib",
"prometheus",
"pyo3",
"regex",
"semver",
"serde",
"serde-pickle",
"serde_json",
"strum",
"thiserror 2.0.11",
......
......@@ -2374,6 +2374,12 @@ dependencies = [
"web-time",
]
[[package]]
name = "indoc"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5"
[[package]]
name = "inlinable_string"
version = "0.1.15"
......@@ -2429,6 +2435,12 @@ version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "iter-read"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071ed4cc1afd86650602c7b11aa2e1ce30762a1c27193201cb5cee9c6ebb1294"
[[package]]
name = "itertools"
version = "0.10.5"
......@@ -2715,6 +2727,15 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b"
[[package]]
name = "memoffset"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
dependencies = [
"autocfg",
]
[[package]]
name = "metal"
version = "0.27.0"
......@@ -3755,6 +3776,69 @@ dependencies = [
"reborrow",
]
[[package]]
name = "pyo3"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872"
dependencies = [
"cfg-if 1.0.0",
"indoc",
"libc",
"memoffset",
"once_cell",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"unindent",
]
[[package]]
name = "pyo3-build-config"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb"
dependencies = [
"once_cell",
"target-lexicon",
]
[[package]]
name = "pyo3-ffi"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d"
dependencies = [
"libc",
"pyo3-build-config",
]
[[package]]
name = "pyo3-macros"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.98",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.98",
]
[[package]]
name = "qoi"
version = "0.4.1"
......@@ -4427,6 +4511,19 @@ dependencies = [
"serde_derive",
]
[[package]]
name = "serde-pickle"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b641fdc8bcf2781ee78b30c599700d64ad4f412976143e4c5d0b9df906bb4843"
dependencies = [
"byteorder",
"iter-read",
"num-bigint",
"num-traits",
"serde",
]
[[package]]
name = "serde_derive"
version = "1.0.217"
......@@ -5434,6 +5531,7 @@ dependencies = [
"anyhow",
"async-stream",
"async-trait",
"async_zmq",
"axum 0.8.1",
"blake3",
"bs62",
......@@ -5448,17 +5546,20 @@ dependencies = [
"indexmap 2.7.1",
"insta",
"itertools 0.14.0",
"libc",
"minijinja",
"minijinja-contrib",
"mistralrs",
"prometheus",
"proptest",
"pyo3",
"regex",
"reqwest",
"rstest",
"semver",
"sentencepiece",
"serde",
"serde-pickle",
"serde_json",
"strum 0.27.1",
"tempfile",
......@@ -5610,6 +5711,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "unindent"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
[[package]]
name = "unsafe-libyaml"
version = "0.2.11"
......
......@@ -34,6 +34,7 @@ mistralrs = ["dep:mistralrs"]
metal = ["mistralrs/metal"]
cuda = ["mistralrs/cuda"]
sentencepiece = ["dep:sentencepiece"]
sglang = ["dep:async_zmq"]
[workspace.dependencies]
# local or crates.io
......@@ -81,6 +82,7 @@ xxhash-rust = { workspace = true }
strum = { workspace = true }
blake3 = "1"
regex = "1"
# protocols
chrono = { version = "0.4", default-features = false, features = [
......@@ -91,7 +93,6 @@ chrono = { version = "0.4", default-features = false, features = [
"serde",
] }
serde_json = { version = "1" }
regex = "1"
unicode-segmentation = "1.12"
# http-service
......@@ -103,6 +104,17 @@ either = { version = "1.13" }
indexmap = { version = "2.6" }
mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", rev = "5e689c9", optional = true }
# sglang
async_zmq = { version = "0.4.0", optional = true }
libc = "0.2"
pyo3 = { version = "0.23.3", default-features = false, features = [
"macros",
"experimental-async",
"experimental-inspect",
"py-clone",
] }
serde-pickle = "1.2.0"
# tokenizers
tokenizers = { version = "0.21.0", default-features = false, features = [
"onig",
......
......@@ -15,3 +15,6 @@
#[cfg(feature = "mistralrs")]
pub mod mistralrs;
#[cfg(feature = "sglang")]
pub mod sglang;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::Path;
use std::sync::Arc;
use crate::backend::ExecutionContext;
use triton_distributed_runtime::pipeline::error as pipeline_error;
use triton_distributed_runtime::CancellationToken;
mod worker;
mod engine;
use engine::SgLangEngine;
mod subprocess;
pub use subprocess::run_subprocess;
/// Build an sglang-backed execution engine.
///
/// Starts the sglang worker subprocess(es), wraps them in a [`SgLangEngine`],
/// and returns it as an `ExecutionContext` together with the join handle of
/// the sglang background process.
pub async fn make_engine(
    cancel_token: CancellationToken,
    // Full path to the model directory
    model_path: &Path,
    // Unique string used to name the zmq sockets
    sock_code: &str,
    // Multi node settings:
    // - num_nodes: How many nodes/hosts we are using
    // - node_rank: Unique consecutive int starting at 0 to identify this node
    // - dist_init_addr: Torch Distributed init method addr:port
    node_conf: MultiNodeConfig,
    // How many GPUs to use
    tensor_parallel_size: u32,
    // The base GPU ID to start allocating GPUs from
    base_gpu_id: u32,
) -> pipeline_error::Result<(ExecutionContext, tokio::task::JoinHandle<()>)> {
    let mut sglang_engine = SgLangEngine::new(
        cancel_token,
        sock_code,
        model_path,
        node_conf,
        tensor_parallel_size,
        base_gpu_id,
    )
    .await?;
    // Take the subprocess handle before the engine is frozen behind an Arc.
    let worker_handle = sglang_engine.take_sglang_worker_handle();
    let execution_ctx: ExecutionContext = Arc::new(sglang_engine);
    Ok((execution_ctx, worker_handle))
}
/// Per-process GPU placement for one sglang scheduler.
#[derive(Debug, Clone, Copy)]
pub struct MultiGPUConfig {
    /// Number of GPUs / worker processes participating in tensor parallelism.
    pub tp_size: u32,
    /// Tensor parallel rank; must be globally unique across all nodes and GPUs.
    pub tp_rank: u32,
    /// Which GPU to run on. In a single-node setup this equals `tp_rank`.
    pub gpu_id: u32,
}

impl Default for MultiGPUConfig {
    /// Single-GPU configuration: one process, rank 0, GPU 0.
    fn default() -> Self {
        Self {
            tp_size: 1,
            tp_rank: 0,
            gpu_id: 0,
        }
    }
}
/// Multi-node (multi-host) topology settings for the sglang cluster.
#[derive(Debug, Clone)]
pub struct MultiNodeConfig {
    /// How many nodes / hosts participate.
    pub num_nodes: u32,
    /// Unique consecutive integer identifying this node, starting at 0.
    pub node_rank: u32,
    /// `host:port` of the head / control node; `None` for single node.
    pub dist_init_addr: Option<String>,
}

impl Default for MultiNodeConfig {
    /// Single-node configuration: one host, rank 0, no distributed init addr.
    fn default() -> Self {
        Self {
            num_nodes: 1,
            node_rank: 0,
            dist_init_addr: None,
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::Path;
use async_stream::stream;
use async_trait::async_trait;
use crate::protocols::common::llm_backend::{BackendInput, LLMEngineOutput};
use triton_distributed_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use triton_distributed_runtime::pipeline::{Error, ManyOut, SingleIn};
use triton_distributed_runtime::protocols::annotated::Annotated;
use triton_distributed_runtime::runtime::CancellationToken;
use crate::engines::sglang::MultiNodeConfig;
/// `AsyncEngine` implementation backed by an sglang Python worker.
/// Requests are forwarded to the worker, which talks to sglang over zmq;
/// responses stream back per-request. See `super::worker` for the transport.
pub struct SgLangEngine {
    /// Fired on global shutdown; open response streams terminate when set.
    cancel_token: CancellationToken,
    /// Owns the zmq plumbing and the sglang subprocess handles.
    worker: super::worker::SgLangWorker,
}
impl SgLangEngine {
    /// Spawn the sglang worker (Python subprocess + zmq plumbing) and wrap
    /// it in an engine. Arguments are forwarded unchanged to
    /// `super::worker::start`; see `make_engine` in mod.rs for their meaning.
    pub async fn new(
        cancel_token: CancellationToken,
        sock_code: &str,
        model_path: &Path,
        node_conf: MultiNodeConfig,
        tensor_parallel_size: u32,
        base_gpu_id: u32,
    ) -> anyhow::Result<Self> {
        let worker = super::worker::start(
            cancel_token.clone(),
            sock_code,
            model_path,
            node_conf,
            tensor_parallel_size,
            base_gpu_id,
        )
        .await?;
        Ok(SgLangEngine {
            cancel_token,
            worker,
        })
    }

    /// Take ownership of the sglang background process join handle.
    /// The worker stores it in an `Option`, so it is only available once.
    pub fn take_sglang_worker_handle(&mut self) -> tokio::task::JoinHandle<()> {
        self.worker.take_sglang_handle()
    }
}
#[async_trait]
impl AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<LLMEngineOutput>>, Error>
    for SgLangEngine
{
    /// Queue `request` on the sglang worker and return a stream of its
    /// incremental outputs. The stream ends when the worker closes the
    /// per-request channel or the global cancel token fires.
    async fn generate(
        &self,
        request: SingleIn<BackendInput>,
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
        let (request, context) = request.into_parts();
        let ctx = context.context();
        let request_id = ctx.id().to_string();
        // The worker sends each incremental LLMEngineOutput over this channel.
        let (resp_tx, mut resp_rx) = tokio::sync::mpsc::channel(128);
        let work_req = super::worker::WorkRequest {
            // Reuse the id we already extracted instead of stringifying the
            // context id a second time.
            request_id: request_id.clone(),
            request,
            response_channel: resp_tx,
        };
        self.worker.enqueue_request(work_req).await?;
        let cancel_token = self.cancel_token.clone();
        let output = stream! {
            loop {
                tokio::select! {
                    _ = cancel_token.cancelled() => {
                        break;
                    }
                    maybe_resp = resp_rx.recv() => {
                        match maybe_resp {
                            Some(out) => {
                                yield out;
                            },
                            None => {
                                // Worker dropped its sender: request finished
                                // (or failed); end the stream cleanly.
                                tracing::trace!(request_id, "generate: response channel closed");
                                break;
                            }
                        }
                    }
                }
            }
        };
        Ok(ResponseStream::new(Box::pin(output), ctx))
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use pyo3::{types::IntoPyDict, Python};
use std::{os::fd::RawFd, path::Path};
/// Python source executed in-process by `run_subprocess` to start one sglang
/// scheduler. It expects these names to be supplied via `locals`:
/// `socket_id`, `model_path`, `pipe_fd`, `tp_size_str`, `tp_rank_str`,
/// `gpu_id_str`, `nnodes_str`, `node_rank_str`, `dist_init_addr`
/// (all strings; the ints are parsed back on the Python side).
/// The IPC socket names built here must stay in sync with worker.rs
/// `zmq_sockets`, which binds the other side.
/// NOTE(review): `nccl_port` is hard-coded to 9876; the inline comment says
/// it is unused, but it could collide with a `--dist-init-addr` using port
/// 9876 (as in the commit examples) — confirm.
const PY_START_ENGINE: &std::ffi::CStr = cr#"
from multiprocessing.connection import Connection
import signal
import tempfile
import logging
from sglang.srt.server_args import ServerArgs, PortArgs
import sglang as sgl
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.entrypoints.engine import _set_envs_and_config
server_args = ServerArgs(
    model_path=f"{model_path}",
    enable_metrics = False,
    log_level = "debug",
    log_requests = True,
    tp_size = int(tp_size_str),
    # Multi-node
    dist_init_addr = dist_init_addr if dist_init_addr != "" else None,
    nnodes = int(nnodes_str),
    node_rank = int(node_rank_str),
)
logging.basicConfig(
    level="DEBUG",
    force=True,
    datefmt="%Y-%m-%d %H:%M:%S",
    format=f"[%(asctime)s] %(message)s",
)
_set_envs_and_config(server_args)
logging.debug(server_args)
ipc_path = f"ipc:///tmp/{socket_id}";
# These must match worker.rs zmq_sockets, which is the other side
port_args = PortArgs(
    # we don't use this one so use anything
    tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
    # Us -> sglang
    scheduler_input_ipc_name=f"{ipc_path}_input_socket",
    # sglang -> us
    detokenizer_ipc_name=f"{ipc_path}_output_socket",
    # The port for nccl initialization (torch.dist), which we don't use
    nccl_port=9876,
)
# Rank must be globally unique across nodes
tp_rank = int(tp_rank_str)
# See nvidia-smi for GPU IDs, they run 0,1,2,etc.
# In a single-node setup this is the same as rank
gpu_id = int(gpu_id_str)
pipe_fd_int = int(pipe_fd)
writer = Connection(handle=pipe_fd_int, readable=False, writable=True)
run_scheduler_process(server_args, port_args, gpu_id, tp_rank, None, writer)
"#;
/// Start the Python sglang engine that listens on a zmq socket.
/// This is called by running `nio --internal-sglang-process`.
/// This does not return until the subprocess exits.
pub fn run_subprocess(
    // The prefix to put on the zmq socket names
    socket_id: &str,
    // Directory containing an HF repo with safetensors files, tokenizer, etc
    model_path: &Path,
    // The write half of a pipe, where sglang will signal when it's ready
    notify_pipe_fd: RawFd,
    // Multi node. Usually Default::default
    node_config: super::MultiNodeConfig,
    // Multi GPU. Usually Default::default
    gpu_config: super::MultiGPUConfig,
) -> anyhow::Result<()> {
    pyo3::prepare_freethreaded_python(); // or enable feature "auto-initialize"
    let dir = model_path.display().to_string();
    Python::with_gil(|py| {
        // Everything is passed as a string so the slice is homogeneous;
        // the Python side (PY_START_ENGINE) parses the ints back.
        let locals = [
            ("socket_id", socket_id),
            ("model_path", dir.as_str()),
            ("pipe_fd", &notify_pipe_fd.to_string()),
            ("tp_size_str", &gpu_config.tp_size.to_string()),
            ("tp_rank_str", &gpu_config.tp_rank.to_string()),
            ("gpu_id_str", &gpu_config.gpu_id.to_string()),
            ("nnodes_str", &node_config.num_nodes.to_string()),
            ("node_rank_str", &node_config.node_rank.to_string()),
            (
                "dist_init_addr",
                // unwrap_or_default already yields an owned String;
                // empty string means "not set" and Python maps it to None.
                &node_config.dist_init_addr.unwrap_or_default(),
            ),
        ]
        .into_py_dict(py)
        .unwrap();
        if let Err(err) = py.run(PY_START_ENGINE, None, Some(&locals)) {
            anyhow::bail!("sglang engine run error: {err}");
        }
        tracing::info!("sglang subprocess exit");
        Ok(())
    })
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
collections::HashMap,
fmt,
os::fd::{FromRawFd as _, RawFd},
path::Path,
process::Stdio,
sync::Arc,
time::Duration,
vec::IntoIter,
};
use anyhow::Context as _;
use async_zmq::{SinkExt, StreamExt};
use libc::c_int;
use pyo3::{
exceptions::PyTypeError,
prelude::*,
types::{IntoPyDict, PyBytes, PyString},
};
use regex::Regex;
use tokio::sync::mpsc::Sender;
use tokio::{io::AsyncBufReadExt, sync::mpsc::error::SendError};
use tokio::{io::AsyncReadExt as _, task::JoinHandle};
use triton_distributed_runtime::protocols::annotated::Annotated;
use triton_distributed_runtime::runtime::CancellationToken;
use crate::engines::sglang::{MultiGPUConfig, MultiNodeConfig};
use crate::protocols::common::llm_backend::LLMEngineOutput;
use crate::protocols::common::preprocessor::PreprocessedRequest;
use crate::protocols::common::FinishReason;
use crate::protocols::TokenIdType;
/// If the user does not provide a max_tokens, limit generation to this many tokens.
const DEFAULT_MAX_TOKENS: u32 = 8192;
/// Wait this long for the sglang sub-process to stop after we send it a KILL
const SGLANG_STOP_TIMEOUT: Duration = Duration::from_millis(1500);
/// Match sglang python log entries, e.g "[2025-01-30 11:23:16] Some text we want".
/// Group 2 captures the message with the (optional) timestamp prefix stripped.
const SGLANG_LOG_RE: &str = r"(\[\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\] )?(.*)";
/// Identify sglang log entries with this prefix
const LOG_PREFIX: &str = "SGLANG";
/// Part of what sglang sends us over its pipe when it's ready
const READY_BYTES: [u8; 5] = [b'r', b'e', b'a', b'd', b'y'];
/// Opaque per-request identifier (matches sglang's "rid" strings).
type RequestID = String;
/// Owns the channels, bridge tasks and subprocess handle that make up one
/// sglang worker.
pub struct SgLangWorker {
    /// How we receive work requests
    tx: Sender<WorkRequest>,
    /// Handle of the task that reads from `tx` and forwards those requests over zmq to sglang
    _input_loop: JoinHandle<()>,
    /// Handle of the task that reads sglang's responses from zmq and dispatches them to the correct
    /// active request.
    _output_loop: JoinHandle<()>,
    /// Handle of the sglang background process. `Option` so it can be taken
    /// exactly once via `take_sglang_handle`.
    sglang: Option<JoinHandle<()>>,
    // We don't need to hold on to this, it's already shared between input_loop and output_loop.
    // But later we'll probably want stats - how many active requests etc, so keep it here.
    _active_requests: Arc<tokio::sync::Mutex<HashMap<RequestID, ActiveRequest>>>,
    // Need to keep this alive so the zmq sockets stay valid.
    // TODO: With async_zmq we possibly don't need this at all
    #[allow(dead_code)]
    zmq_context: async_zmq::Context,
}
/// How we get asked to do some work. These get unpacked and forwarded to sglang.
pub struct WorkRequest {
    /// The preprocessed (tokenized) request to run.
    pub request: PreprocessedRequest,
    /// Caller-supplied unique id; presumably forwarded as sglang's "rid" —
    /// confirm in input_loop (not visible here).
    pub request_id: RequestID,
    /// Where the incremental outputs for this request are sent.
    pub response_channel: Sender<Annotated<LLMEngineOutput>>,
}
/// A request currently being processed by sglang.
struct ActiveRequest {
    /// Channel this request's outputs are forwarded to (the engine's stream).
    tx: Sender<Annotated<LLMEngineOutput>>,
    /// Output tokens counted so far; presumably None until the first batch
    /// arrives — confirm in output_loop (not visible here).
    num_output_tokens_so_far: Option<i32>,
    /// Effective max_tokens for this request (presumably DEFAULT_MAX_TOKENS
    /// when the caller did not set one — confirm where requests are enqueued).
    max_tokens: i32,
}
/// Python objects resolved once at startup and reused for every request.
struct Imports {
    /// The Python `pickle` module; presumably used to (de)serialize sglang's
    /// zmq messages — confirm in input_loop/output_loop (not visible here).
    pickle_module: PyObject,
    /// sglang's `SamplingParams` class.
    sampling_params_type: PyObject,
    /// sglang's `TokenizedGenerateReqInput` class — the request object we send.
    rpc_type: PyObject,
}
/// All the zmq sockets we use. This object exists only to pass them around
/// together and avoid large tuples.
struct Sockets {
    #[allow(dead_code)]
    context: async_zmq::Context, // we have to keep this alive
    // Requests from us to the sglang engine
    // (scheduler_input_ipc_name on the Python side — see subprocess.rs PortArgs)
    input: async_zmq::Push<IntoIter<Vec<u8>>, Vec<u8>>,
    // Responses from the sglang engine back to us
    // (detokenizer_ipc_name on the Python side — see subprocess.rs PortArgs)
    output: async_zmq::Pull,
}
/// What sglang sends us.
/// Mirrors sglang's Python `BatchTokenIDOut`; the derived `FromPyObject`
/// extracts fields by attribute name, so names here must match the Python
/// dataclass (presumably io_struct.py — confirm against the pinned sglang
/// version when upgrading).
#[allow(dead_code)]
#[derive(FromPyObject, Debug)]
pub struct BatchTokenIDOut {
    // The request id
    rids: Vec<String>,
    // The finish reason
    // sglang implements finish reason as subclasses of BaseFinishReason
    // e.g. `class FINISH_LENGTH(BaseFinishReason):` and lots of others
    finished_reasons: Vec<Option<SgLangFinishReason>>,
    // For incremental decoding
    // The version id to sync decode status with in detokenizer_manager
    vids: Vec<i32>,
    decoded_texts: Vec<String>,
    decode_ids: Vec<Vec<u32>>,
    read_offsets: Vec<i32>,
    // Only used when `--skip-tokenizer-init` is on
    output_ids: Option<Vec<i32>>,
    // Detokenization configs
    skip_special_tokens: Vec<bool>,
    spaces_between_special_tokens: Vec<bool>,
    no_stop_trim: Vec<bool>,
    // Token counts
    prompt_tokens: Vec<i32>,
    completion_tokens: Vec<i32>,
    cached_tokens: Vec<i32>,
    spec_verify_ct: Vec<i32>,
    // Logprobs
    input_token_logprobs_val: Option<Vec<f64>>,
    input_token_logprobs_idx: Option<Vec<i32>>,
    output_token_logprobs_val: Option<Vec<f64>>,
    output_token_logprobs_idx: Option<Vec<i32>>,
    // These are `List[List]` in Python with no element type, so the element
    // types here are a best guess — confirm against sglang if extraction fails.
    input_top_logprobs_val: Option<Vec<Vec<f64>>>,
    input_top_logprobs_idx: Option<Vec<Vec<i32>>>,
    output_top_logprobs_val: Option<Vec<Vec<f64>>>,
    output_top_logprobs_idx: Option<Vec<Vec<i32>>>,
}
/// Finish reasons reported by sglang, mapped from the "type" field of its
/// BaseFinishReason JSON (see the `FromPyObject` impl below).
#[derive(Debug, Copy, Clone)]
enum SgLangFinishReason {
    /// "stop": a stop condition matched.
    Matched,
    /// "length": the configured length limit was reached.
    Length,
    /// "abort": the request was aborted (also the fallback for unknown types).
    Abort,
}
impl fmt::Display for SgLangFinishReason {
    /// Human-readable description of the finish reason.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let description = match self {
            SgLangFinishReason::Matched => "Finished due to a successful match",
            SgLangFinishReason::Length => "Finished due to reaching the specified length",
            SgLangFinishReason::Abort => "Operation was aborted",
        };
        write!(f, "{description}")
    }
}
impl<'py> FromPyObject<'py> for SgLangFinishReason {
    /// Convert an sglang BaseFinishReason Python object into our enum by
    /// stringifying it and reading the "type" field out of the resulting JSON.
    fn extract_bound(obj: &Bound<'py, PyAny>) -> PyResult<Self> {
        // The object we have is a subclass of sglang's BaseFinishReason, one subclass
        // per finish reason. I don't know how to identify the class, but if we force
        // it to a string I _think_ it ends up calling `json_str` in the subclass.
        // Also the string uses single quotes in the JSON, I don't know why.
        let json_str = obj.str()?.to_string().replace("'", "\"");
        let as_map: HashMap<String, serde_json::Value> =
            serde_json::from_str(&json_str).map_err(|err| {
                tracing::error!("SgLangFinishReason JSON convert err: {err}. JSON: {json_str}");
                PyTypeError::new_err(format!("serde_json err: {err}. JSON: {json_str}"))
            })?;
        let Some(type_serde) = as_map.get("type") else {
            return Err(PyTypeError::new_err("Finish reason missing 'type' JSON field. See sglang's schedule_batch.py BaseFinishReason"));
        };
        let Some(type_str) = type_serde.as_str() else {
            return Err(PyTypeError::new_err("Finish reason 'type' JSON field is not a string. See sglang's schedule_batch.py BaseFinishReason"));
        };
        match type_str {
            "stop" => Ok(SgLangFinishReason::Matched),
            "length" => Ok(SgLangFinishReason::Length),
            "abort" => Ok(SgLangFinishReason::Abort),
            x => {
                // Be tolerant of new sglang finish-reason types rather than
                // failing the whole response.
                tracing::warn!("Unknown sglang BaseFinishReason type '{x}'. Using Abort instead.");
                Ok(SgLangFinishReason::Abort)
            }
        }
    }
}
impl From<SgLangFinishReason> for FinishReason {
fn from(sfr: SgLangFinishReason) -> Self {
use SgLangFinishReason::*;
match sfr {
Matched => FinishReason::Stop,
Length => FinishReason::Length,
Abort => FinishReason::Cancelled, // or FinishReason::Error ?
}
}
}
/* What we send to sglang
class TokenizedGenerateReqInput:
# The request id
rid: str
# The input text
input_text: str
# The input token ids
input_ids: List[int]
# The image inputs
image_inputs: dict
# The sampling parameters
sampling_params: SamplingParams
# Whether to return the logprobs
return_logprob: bool
# If return logprobs, the start location in the prompt for returning logprobs.
logprob_start_len: int
# If return logprobs, the number of top logprobs to return at each position.
top_logprobs_num: int
# Whether to stream output
stream: bool
# LoRA related
lora_path: Optional[str] = None # None means just use the base model
# The input embeds
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
# Session info for continual prompting
session_params: Optional[SessionParams] = None
# Custom logit processor for advanced sampling control. Must be a serialized instance
# of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
# Use the processor's `to_str()` method to generate the serialized string.
custom_logit_processor: Optional[str] = None
class SamplingParams:
max_new_tokens: int = 128,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
repetition_penalty: float = 1.0,
min_new_tokens: int = 0,
spaces_between_special_tokens: bool = True,
n: int = 1,
json_schema: Optional[str] = None,
regex: Optional[str] = None,
ebnf: Optional[str] = None,
no_stop_trim: bool = False,
ignore_eos: bool = False,
skip_special_tokens: bool = True,
custom_params: Optional[Dict[str, Any]] = None,
*/
/// Main entry point
/// Launch this node's sglang scheduler process(es) and the two bridge tasks
/// (request forwarding and response dispatch), returning the assembled worker.
pub async fn start(
    cancel_token: CancellationToken,
    // Unique string used to name the zmq IPC sockets
    sock_code: &str,
    // Model directory (HF repo layout)
    model_path: &Path,
    // Multi-node settings; Default::default for single node
    node_conf: MultiNodeConfig,
    // Total number of GPUs across all nodes
    tp_size: u32,
    // First GPU id to allocate on this node
    base_gpu_id: u32,
) -> anyhow::Result<SgLangWorker> {
    pyo3::prepare_freethreaded_python(); // or enable feature "auto-initialize"
    let Sockets {
        context,
        input,
        output,
    } = zmq_sockets(sock_code)?;
    let py_imports = Arc::new(python_imports());
    if tp_size < node_conf.num_nodes {
        anyhow::bail!("Need at least as many GPUs as nodes. In nio set --tensor-parallel-size >= --num-nodes.");
    }
    // NOTE(review): integer division — if tp_size is not a multiple of
    // num_nodes, the remainder ranks are never launched on any node.
    // Confirm whether uneven splits should be rejected up front instead.
    let tp_size_per_node = tp_size / node_conf.num_nodes;
    // This node runs the contiguous rank range [tp_rank_start, tp_rank_end).
    let tp_rank_start = tp_size_per_node * node_conf.node_rank;
    let tp_rank_end = tp_size_per_node * (node_conf.node_rank + 1);
    // Start all the sglang workers. They communicate amongst themselves using torch distributed
    // and nccl. They must all start at once.
    let mut sglang_join_handle = None;
    let mut process_group = Vec::with_capacity(tp_size as usize);
    for tp_rank in tp_rank_start..tp_rank_end {
        // Map the global rank onto a local GPU, starting at base_gpu_id.
        let gpu_id = base_gpu_id + tp_rank % tp_size_per_node;
        let gpu_conf = MultiGPUConfig {
            tp_size,
            tp_rank,
            gpu_id,
        };
        let (sglang_process, ready_fd) =
            start_sglang(model_path, node_conf.clone(), gpu_conf).await?;
        process_group.push((tp_rank, ready_fd));
        let watcher_join_handle = watch_sglang(cancel_token.clone(), sglang_process);
        // Only the first watcher's handle is kept; the rest are detached.
        // TODO: Do we want to hold on to this?
        // Do we need it for the other sub-processes?
        if sglang_join_handle.is_none() {
            sglang_join_handle = Some(watcher_join_handle);
        }
    }
    // Everything was launched above; now block until each rank reports ready
    // over its pipe.
    for (tp_rank, ready_fd) in process_group.into_iter() {
        wait_for_sglang(tp_rank, ready_fd).await?;
    }
    let active_requests = Arc::new(tokio::sync::Mutex::new(HashMap::new()));
    let (tx, rx) = tokio::sync::mpsc::channel(8);
    // Task: pull WorkRequests off `rx` and push them to sglang over zmq.
    let input_loop_handle = {
        let cancel_token = cancel_token.clone();
        let py_imports = py_imports.clone();
        let active_requests = active_requests.clone();
        tokio::spawn(input_loop(
            cancel_token,
            py_imports,
            input,
            active_requests,
            rx,
        ))
    };
    // Task: read sglang's zmq responses and route them to active requests.
    let output_loop_handle = {
        let cancel_token = cancel_token.clone();
        let py_imports = py_imports.clone();
        let active_requests = active_requests.clone();
        tokio::spawn(output_loop(
            cancel_token,
            py_imports,
            output,
            active_requests,
        ))
    };
    Ok(SgLangWorker {
        tx,
        zmq_context: context,
        _input_loop: input_loop_handle,
        _output_loop: output_loop_handle,
        sglang: sglang_join_handle,
        _active_requests: active_requests,
    })
}
/// Resolve every Python module and class we need, once, up front.
/// Panics with a clear message when the environment is broken — there is no
/// sglang backend without a working Python venv, so failing fast is correct.
fn python_imports() -> Imports {
    Python::with_gil(|py| {
        // There is no sglang without python
        let pickle_module: PyObject = py
            .import("pickle")
            .unwrap_or_else(|err| {
                panic!("Failed to import python 'pickle' module. Is Python installed? {err}")
            })
            .into();
        // This one is a sanity check
        py.import("sglang").unwrap_or_else(|err| {
            panic!("Failed to import python 'sglang' module. Are we running in the correct venv? {err}")
        });
        let mod_iostruct: PyObject = py
            .import("sglang.srt.managers.io_struct")
            .unwrap_or_else(|err| {
                panic!("Failed to import sglang.srt.managers.io_struct. Did sglang change? {err}")
            })
            .into();
        let rpc_type = mod_iostruct
            .getattr(py, "TokenizedGenerateReqInput")
            .unwrap();
        let mod_sampling: PyObject = py
            .import("sglang.srt.sampling.sampling_params")
            .unwrap_or_else(|err| {
                panic!(
                    "Failed to import sglang.srt.sampling.sampling_params. Did sglang change? {err}"
                )
            })
            .into();
        let sampling_params_type: PyObject = mod_sampling.getattr(py, "SamplingParams").unwrap();
        Imports {
            pickle_module,
            sampling_params_type,
            rpc_type,
        }
    })
}
/// Create all the zmq sockets we're going to use.
///
/// The endpoints are ipc files under /tmp, namespaced by `sock_code` so
/// multiple workers on one host don't collide.
fn zmq_sockets(sock_code: &str) -> anyhow::Result<Sockets> {
    let context = async_zmq::Context::new();
    let input_endpoint = format!("ipc:///tmp/{sock_code}_input_socket");
    let output_endpoint = format!("ipc:///tmp/{sock_code}_output_socket");
    // Scheduler (rank 0) to receive inputs from us
    let input = async_zmq::push(&input_endpoint)?
        .with_context(&context)
        .bind()?;
    // Use to receive replies from scheduler.
    let output = async_zmq::pull(&output_endpoint)?
        .with_context(&context)
        .bind()?;
    Ok(Sockets {
        context,
        input,
        output,
    })
}
/// Start the python sub-process and wait for it to be ready
async fn start_sglang(
model_path: &Path,
node_conf: MultiNodeConfig,
gpu_conf: MultiGPUConfig,
) -> anyhow::Result<(tokio::process::Child, RawFd)> {
// This pipe is how sglang tells us it's ready
let mut pipe_fds: [libc::c_int; 2] = [-1, -1];
unsafe {
let err = libc::pipe2(pipe_fds.as_mut_ptr() as *mut c_int, 0); // libc::O_NONBLOCK);
if err != 0 {
anyhow::bail!("libc::pipe error {err}");
}
}
let sglang_says_hello = pipe_fds[1] as RawFd;
let tp_rank = gpu_conf.tp_rank;
let gpu_id = gpu_conf.gpu_id;
let mut args = vec![
format!("--internal-sglang-process={sglang_says_hello},{tp_rank},{gpu_id}"),
format!("--model-path={}", model_path.display()),
format!("--tensor-parallel-size={}", gpu_conf.tp_size),
format!("--num-nodes={}", node_conf.num_nodes),
format!("--node-rank={}", node_conf.node_rank),
];
if let Some(dist_init_addr) = node_conf.dist_init_addr {
args.push(format!("--dist-init-addr={dist_init_addr}"));
}
let self_path = std::env::current_exe()?;
let mut proc = tokio::process::Command::new(self_path)
.args(args)
.kill_on_drop(false)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()?;
let stdout = tokio::io::BufReader::new(proc.stdout.take().unwrap());
let stderr = tokio::io::BufReader::new(proc.stderr.take().unwrap());
// Log sglang's stdout
// sglang has (almost?) no output on stdout
tokio::spawn(async move {
let mut lines = stdout.lines();
while let Ok(Some(line)) = lines.next_line().await {
tracing::info!("{LOG_PREFIX}{tp_rank} {line}");
}
});
// Log sglang's stderr
tokio::spawn(async move {
// Remove extra date/time entries from stderr, and print with prefix
let line_re = Regex::new(SGLANG_LOG_RE).unwrap();
let mut lines = stderr.lines();
while let Ok(Some(line)) = lines.next_line().await {
if let Some(cap) = line_re.captures(&line) {
match cap.len() {
2 => {
// No date/time, these are usually errors
tracing::warn!("{LOG_PREFIX}{tp_rank} {line}");
}
3 => {
// Normal log line. Skip Python's date/time
tracing::info!("{LOG_PREFIX}{tp_rank} {}", &cap[2]);
}
x => {
unreachable!("sglang log re only has two capture groups, so {x} entries is impossible");
}
}
}
}
});
let ready_fd = pipe_fds[0] as RawFd;
Ok((proc, ready_fd))
}
/// Wait until the sglang sub-process for `tp_rank` writes its ready message
/// into the pipe whose read end is `pipe_fd`.
///
/// # Errors
/// Fails if the pipe can't be read, if the sub-process closed the pipe
/// without writing (i.e. it died), or if the bytes read don't contain the
/// expected ready marker.
async fn wait_for_sglang(tp_rank: u32, pipe_fd: RawFd) -> anyhow::Result<()> {
    tracing::info!("Waiting for sglang{tp_rank} to signal that it's ready");
    // SAFETY: pipe_fd is the read end of a pipe we created and have not
    // closed or wrapped anywhere else; File takes ownership and closes it.
    let mut sglang_ready = unsafe { tokio::fs::File::from_raw_fd(pipe_fd) };
    let mut buf = [0u8; 128]; // Some pickled JSON, about 90 bytes
    let len_read = sglang_ready
        .read(&mut buf)
        .await
        .with_context(|| format!("Failed reading from Rust side of sglang pipe, fd {pipe_fd}",))?;
    // A zero-length read means every write end is closed: the sub-process
    // exited before it could signal ready. Report that distinctly instead of
    // the confusing "did not contain 'ready' bytes" message below.
    if len_read == 0 {
        anyhow::bail!("sglang{tp_rank} closed the ready pipe without signalling ready; did the sub-process die?");
    }
    let received_bytes = &buf[..len_read];
    /* received_bytes is pickled JSON:
    {
        "status": "ready",
        "max_total_num_tokens": scheduler.max_total_num_tokens,
        "max_req_input_len": scheduler.max_req_input_len,
    }
    We could unpickle it, but this is faster.
    */
    if !received_bytes
        .windows(READY_BYTES.len())
        .any(|candidate| candidate == READY_BYTES)
    {
        anyhow::bail!("Expected sglang pipe to signal ready, but did not contain 'ready' bytes");
    }
    // TODO: warm up the engine
    tracing::info!("sglang{tp_rank} is ready");
    Ok(())
}
// Stop the sglang process when we stop, and prevent it going zombie.
fn watch_sglang(
cancel_token: CancellationToken,
mut sglang_process: tokio::process::Child,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
cancel_token.cancelled().await;
tokio::select! {
_ = sglang_process.wait() => {
return;
},
_ = tokio::time::sleep(SGLANG_STOP_TIMEOUT) => { }
}
if let Err(err) = sglang_process.start_kill() {
tracing::error!("Failing killing sglang subprocess: {err}");
return;
}
tokio::select! {
_ = sglang_process.wait() => { },
_ = tokio::time::sleep(SGLANG_STOP_TIMEOUT) => {
tracing::warn!("Timeout waiting for sglang sub-process to stop after kill");
}
}
})
}
/// Pull work requests off `rx`, convert each into sglang's pickled
/// `TokenizedGenerateReqInput` RPC form, register it in `active_requests`
/// (so `output_loop` can route replies back to the caller), and push it to
/// the sglang scheduler over the zmq `input_socket`.
///
/// Runs until `cancel_token` fires or the request channel closes.
async fn input_loop(
    cancel_token: CancellationToken,
    py_imports: Arc<Imports>,
    mut input_socket: async_zmq::Push<IntoIter<Vec<u8>>, Vec<u8>>,
    active_requests: Arc<tokio::sync::Mutex<HashMap<RequestID, ActiveRequest>>>,
    mut rx: tokio::sync::mpsc::Receiver<WorkRequest>,
) {
    loop {
        // Wait for either shutdown or the next request, whichever comes first.
        let work_request = tokio::select! {
            _ = cancel_token.cancelled() => {
                tracing::trace!("SgLangWorker.main_loop exit");
                break;
            }
            req = rx.recv() => {
                match req {
                    Some(req) => req,
                    None => {
                        // All senders dropped; no more work will ever arrive.
                        tracing::trace!("SgLangWorker input_loop socket closed");
                        break;
                    }
                }
            }
        };
        let request_id = work_request.request_id;
        let token_ids = work_request.request.token_ids.clone();
        // Missing temperature defaults to 0.0.
        let temperature: f64 = work_request
            .request
            .sampling_options
            .temperature
            .unwrap_or(0.0)
            .into();
        let max_tokens = work_request
            .request
            .stop_conditions
            .max_tokens
            .unwrap_or(DEFAULT_MAX_TOKENS);
        tracing::trace!("Received work request: {request_id}");
        // Parts that don't change
        // First GIL section: build the python SamplingParams instance and the
        // python string request id.
        let (py_request_id, sampling_params) = Python::with_gil(|py| {
            let py_temp: PyObject = temperature.into_pyobject(py).unwrap().into();
            let py_max_tokens: PyObject = max_tokens.into_pyobject(py).unwrap().into();
            let sp_kwargs = [("temperature", py_temp), ("max_new_tokens", py_max_tokens)]
                .into_py_dict(py)
                .unwrap();
            let sampling_params = py_imports
                .sampling_params_type
                .call(py, (), Some(&sp_kwargs))
                .unwrap();
            // Call SamplingParams.normalize(None) — presumably required by
            // sglang before the params are usable; TODO(review) confirm
            // against the pinned sglang version.
            sampling_params
                .getattr(py, "normalize")
                .unwrap()
                .call1(py, (py.None(),))
                .unwrap();
            let py_request_id: PyObject = PyString::new(py, &request_id).into();
            (py_request_id, sampling_params)
        });
        // Second GIL section: build TokenizedGenerateReqInput positionally and
        // pickle it into the bytes the scheduler expects on its zmq socket.
        // NOTE(review): the positional order below must match the field order
        // of sglang's TokenizedGenerateReqInput — confirm when bumping sglang.
        let pickled_req: Vec<u8> = Python::with_gil(|py| {
            let input_text: PyObject = "".into_pyobject(py).unwrap().into();
            let input_ids: PyObject = token_ids.into_pyobject(py).unwrap().into();
            let image_inputs: PyObject = py.None();
            let return_logprob: PyObject = false.into_pyobject(py).unwrap().to_owned().into();
            let logprob_start_len: PyObject = 0u32.into_pyobject(py).unwrap().into();
            let top_logprobs_num: PyObject = 0u32.into_pyobject(py).unwrap().into();
            // stream=True so sglang sends incremental output batches.
            let stream: PyObject = true.into_pyobject(py).unwrap().to_owned().into();
            let rpc_pos_args = (
                py_request_id,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                stream,
            );
            //let rpc_kwargs = [].into_py_dict(py).unwrap();
            let req = py_imports
                .rpc_type
                .call(py, rpc_pos_args, None) // Some(&rpc_kwargs))
                .unwrap();
            let pickle_dumps = py_imports.pickle_module.getattr(py, "dumps").unwrap();
            pickle_dumps.call1(py, (req,)).unwrap().extract(py).unwrap()
        });
        // Register the request *before* sending it to the scheduler, so
        // output_loop can never receive a reply for an unknown id.
        let new_active_request = ActiveRequest {
            tx: work_request.response_channel,
            max_tokens: max_tokens as i32,
            num_output_tokens_so_far: None,
        };
        active_requests
            .lock()
            .await
            .insert(request_id, new_active_request);
        //if let Err(err) = input_socket.send(vec![pickled_req].into()).await {
        if let Err(err) = input_socket.send(pickled_req.into()).await {
            tracing::error!("Error sending new request to sglang over zmq: {err}");
        }
    }
}
/// Read from sglang's output zmq socket, find which request it is for and forward over that channel.
///
/// Each zmq message's first frame is a pickled `BatchTokenIDOut` carrying
/// results for several requests at once (parallel arrays indexed by position
/// in `rids`). Finished requests are removed from `active_requests` after
/// the batch has been processed.
async fn output_loop(
    cancel_token: CancellationToken,
    py_imports: Arc<Imports>,
    mut output_socket: async_zmq::Pull,
    active_requests: Arc<tokio::sync::Mutex<HashMap<RequestID, ActiveRequest>>>,
) {
    loop {
        // Wait for either shutdown or the next multipart message.
        let maybe_bb = tokio::select! {
            _ = cancel_token.cancelled() => {
                break;
            }
            maybe_bb = output_socket.next() => {
                maybe_bb
            }
        };
        let mut bb = match maybe_bb {
            Some(Ok(b)) => b,
            Some(Err(err)) => {
                tracing::error!("Error reading from sglang zmq output: {err}");
                continue; // hope live eternal
            }
            None => {
                tracing::debug!("zmq output socket closed");
                break;
            }
        };
        // Only the first frame of the multipart message is used.
        let frame = bb.remove(0);
        // Unpickle the frame into our Rust-side BatchTokenIDOut mirror.
        let req_out: BatchTokenIDOut = Python::with_gil(|py| {
            let pickle_loads = py_imports.pickle_module.getattr(py, "loads").unwrap();
            let frame_bytes = PyBytes::new(py, &frame);
            let pyobj = pickle_loads.call1(py, (frame_bytes,)).unwrap();
            pyobj.extract(py).unwrap()
        });
        tracing::trace!(?req_out, "from sglang");
        // Requests that finished in this batch; removed after the loop so the
        // map isn't mutated while we're still walking the batch entries.
        let mut remove_after = vec![];
        for (idx, req_id) in req_out.rids.into_iter().enumerate() {
            // decode_ids appears to hold all tokens generated so far for the
            // request (cumulative), so the new tokens are the tail beyond what
            // we've already forwarded.
            let next_total_toks = req_out.decode_ids[idx].len() as i32;
            // NOTE(review): this lock guard is held across the
            // `active.tx.send(..).await` below. It's a tokio Mutex so that is
            // legal, but a full/slow response channel would stall delivery for
            // every other request — confirm receivers drain promptly.
            match active_requests.lock().await.get_mut(&req_id) {
                Some(active) => {
                    // First batch for a request has no running count yet; fall
                    // back to sglang's read_offset for that entry.
                    let previous_total_toks = active
                        .num_output_tokens_so_far
                        .unwrap_or(req_out.read_offsets[idx])
                        as usize;
                    let sglang_finish_reason = req_out.finished_reasons[idx];
                    let token_ids: Vec<TokenIdType> = if sglang_finish_reason.is_none() {
                        req_out.decode_ids[idx][previous_total_toks..].into()
                    } else {
                        // Request is over, sglang says so.
                        // The last token is the eos_token, don't forward it
                        remove_after.push(req_id.clone());
                        vec![]
                    };
                    let out = LLMEngineOutput {
                        token_ids,
                        tokens: None,
                        text: None,
                        cum_log_probs: None,
                        log_probs: None,
                        finish_reason: sglang_finish_reason.map(|x| x.into()),
                    };
                    active.num_output_tokens_so_far = Some(next_total_toks);
                    // Enforce our own max_tokens cap even if sglang keeps going.
                    let out = if next_total_toks <= active.max_tokens {
                        Annotated::from_data(out)
                    } else {
                        // we exceeded max tokens, this request is over
                        remove_after.push(req_id.clone());
                        Annotated::from_data(LLMEngineOutput::length())
                    };
                    // The receiver may already be gone; ignore send failure.
                    let _ = active.tx.send(out).await;
                }
                None => {
                    // sglang sends the finish response twice, I don't know why
                    // so only log if it isn't a finished request
                    if req_out.finished_reasons[idx].is_none() {
                        tracing::warn!(req_id, "Missing active request");
                    }
                }
            }
        }
        for req_id in remove_after {
            let _ = active_requests.lock().await.remove(&req_id);
        }
    }
}
impl SgLangWorker {
    /// Send a request to sglang
    ///
    /// # Errors
    /// Returns the request back in `SendError` if the worker's input loop has
    /// shut down (channel closed).
    pub async fn enqueue_request(&self, r: WorkRequest) -> Result<(), SendError<WorkRequest>> {
        self.tx.send(r).await
    }
    /// Get the sglang sub-process handle, so we can await it and prevent it going zombie.
    ///
    /// # Panics
    /// Panics if called more than once — the handle is moved out on the first
    /// call.
    pub fn take_sglang_handle(&mut self) -> JoinHandle<()> {
        self.sglang
            .take()
            .expect("take_sglang_handle called more than once")
    }
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment