Commit e97493eb authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat: sglang backend for tio (#271)

- Setup venv

```
uv venv
source .venv/bin/activate
uv pip install pip
uv pip install sgl-kernel --force-reinstall --no-deps
uv pip install "sglang[all]==0.4.2" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/
```

- Build: `cargo build --release --features sglang`

- Run single node (make sure you're in the venv): `./tio out=sglang ~/llm_models/my_model`

- Run Deepseek multi-gpu / multi-node:

Node 1:
```
tio in=http out=sglang --model-path ~/llm_models/DeepSeek-R1-Distill-Llama-70B/ --tensor-parallel-size 8 --num-nodes 2 --node-rank 0 --dist-init-addr 10.217.98.122:9876
```

Node 2:
```
tio in=none out=sglang --model-path ~/llm_models/DeepSeek-R1-Distill-Llama-70B/ --tensor-parallel-size 8 --num-nodes 2 --node-rank 1 --dist-init-addr 10.217.98.122:9876
```
parent c70de37f
...@@ -1486,6 +1486,12 @@ dependencies = [ ...@@ -1486,6 +1486,12 @@ dependencies = [
"web-time", "web-time",
] ]
[[package]]
name = "indoc"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5"
[[package]] [[package]]
name = "inlinable_string" name = "inlinable_string"
version = "0.1.15" version = "0.1.15"
...@@ -1501,6 +1507,12 @@ dependencies = [ ...@@ -1501,6 +1507,12 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "iter-read"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071ed4cc1afd86650602c7b11aa2e1ce30762a1c27193201cb5cee9c6ebb1294"
[[package]] [[package]]
name = "itertools" name = "itertools"
version = "0.11.0" version = "0.11.0"
...@@ -1693,6 +1705,15 @@ version = "0.3.3" ...@@ -1693,6 +1705,15 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b" checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b"
[[package]]
name = "memoffset"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
dependencies = [
"autocfg",
]
[[package]] [[package]]
name = "mime" name = "mime"
version = "0.3.17" version = "0.3.17"
...@@ -2339,6 +2360,69 @@ version = "2.28.0" ...@@ -2339,6 +2360,69 @@ version = "2.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94" checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94"
[[package]]
name = "pyo3"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872"
dependencies = [
"cfg-if 1.0.0",
"indoc",
"libc",
"memoffset",
"once_cell",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"unindent",
]
[[package]]
name = "pyo3-build-config"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb"
dependencies = [
"once_cell",
"target-lexicon",
]
[[package]]
name = "pyo3-ffi"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d"
dependencies = [
"libc",
"pyo3-build-config",
]
[[package]]
name = "pyo3-macros"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.98",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028"
dependencies = [
"heck",
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.98",
]
[[package]] [[package]]
name = "quote" name = "quote"
version = "1.0.38" version = "1.0.38"
...@@ -2677,6 +2761,19 @@ dependencies = [ ...@@ -2677,6 +2761,19 @@ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]]
name = "serde-pickle"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b641fdc8bcf2781ee78b30c599700d64ad4f412976143e4c5d0b9df906bb4843"
dependencies = [
"byteorder",
"iter-read",
"num-bigint",
"num-traits",
"serde",
]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.218" version = "1.0.218"
...@@ -3425,12 +3522,15 @@ dependencies = [ ...@@ -3425,12 +3522,15 @@ dependencies = [
"galil-seiferas", "galil-seiferas",
"indexmap 2.7.1", "indexmap 2.7.1",
"itertools 0.14.0", "itertools 0.14.0",
"libc",
"minijinja", "minijinja",
"minijinja-contrib", "minijinja-contrib",
"prometheus", "prometheus",
"pyo3",
"regex", "regex",
"semver", "semver",
"serde", "serde",
"serde-pickle",
"serde_json", "serde_json",
"strum", "strum",
"thiserror 2.0.11", "thiserror 2.0.11",
...@@ -3569,6 +3669,12 @@ version = "0.1.1" ...@@ -3569,6 +3669,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "unindent"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
[[package]] [[package]]
name = "untrusted" name = "untrusted"
version = "0.9.0" version = "0.9.0"
......
...@@ -1595,6 +1595,12 @@ dependencies = [ ...@@ -1595,6 +1595,12 @@ dependencies = [
"web-time", "web-time",
] ]
[[package]]
name = "indoc"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5"
[[package]] [[package]]
name = "inlinable_string" name = "inlinable_string"
version = "0.1.15" version = "0.1.15"
...@@ -1616,6 +1622,12 @@ version = "1.70.1" ...@@ -1616,6 +1622,12 @@ version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "iter-read"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071ed4cc1afd86650602c7b11aa2e1ce30762a1c27193201cb5cee9c6ebb1294"
[[package]] [[package]]
name = "itertools" name = "itertools"
version = "0.11.0" version = "0.11.0"
...@@ -1822,6 +1834,15 @@ version = "0.3.3" ...@@ -1822,6 +1834,15 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b" checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b"
[[package]]
name = "memoffset"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
dependencies = [
"autocfg",
]
[[package]] [[package]]
name = "mime" name = "mime"
version = "0.3.17" version = "0.3.17"
...@@ -2479,6 +2500,69 @@ version = "2.28.0" ...@@ -2479,6 +2500,69 @@ version = "2.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94" checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94"
[[package]]
name = "pyo3"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872"
dependencies = [
"cfg-if 1.0.0",
"indoc",
"libc",
"memoffset",
"once_cell",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"unindent",
]
[[package]]
name = "pyo3-build-config"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb"
dependencies = [
"once_cell",
"target-lexicon",
]
[[package]]
name = "pyo3-ffi"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d"
dependencies = [
"libc",
"pyo3-build-config",
]
[[package]]
name = "pyo3-macros"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.98",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028"
dependencies = [
"heck",
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.98",
]
[[package]] [[package]]
name = "quote" name = "quote"
version = "1.0.38" version = "1.0.38"
...@@ -2817,6 +2901,19 @@ dependencies = [ ...@@ -2817,6 +2901,19 @@ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]]
name = "serde-pickle"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b641fdc8bcf2781ee78b30c599700d64ad4f412976143e4c5d0b9df906bb4843"
dependencies = [
"byteorder",
"iter-read",
"num-bigint",
"num-traits",
"serde",
]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.218" version = "1.0.218"
...@@ -3597,12 +3694,15 @@ dependencies = [ ...@@ -3597,12 +3694,15 @@ dependencies = [
"galil-seiferas", "galil-seiferas",
"indexmap 2.7.1", "indexmap 2.7.1",
"itertools 0.14.0", "itertools 0.14.0",
"libc",
"minijinja", "minijinja",
"minijinja-contrib", "minijinja-contrib",
"prometheus", "prometheus",
"pyo3",
"regex", "regex",
"semver", "semver",
"serde", "serde",
"serde-pickle",
"serde_json", "serde_json",
"strum", "strum",
"thiserror 2.0.11", "thiserror 2.0.11",
...@@ -3741,6 +3841,12 @@ version = "0.1.1" ...@@ -3741,6 +3841,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "unindent"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
[[package]] [[package]]
name = "untrusted" name = "untrusted"
version = "0.9.0" version = "0.9.0"
......
...@@ -2309,6 +2309,12 @@ dependencies = [ ...@@ -2309,6 +2309,12 @@ dependencies = [
"web-time", "web-time",
] ]
[[package]]
name = "indoc"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5"
[[package]] [[package]]
name = "inlinable_string" name = "inlinable_string"
version = "0.1.15" version = "0.1.15"
...@@ -2345,6 +2351,12 @@ version = "1.70.1" ...@@ -2345,6 +2351,12 @@ version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "iter-read"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071ed4cc1afd86650602c7b11aa2e1ce30762a1c27193201cb5cee9c6ebb1294"
[[package]] [[package]]
name = "itertools" name = "itertools"
version = "0.11.0" version = "0.11.0"
...@@ -2616,6 +2628,15 @@ version = "0.3.3" ...@@ -2616,6 +2628,15 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b" checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b"
[[package]]
name = "memoffset"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
dependencies = [
"autocfg",
]
[[package]] [[package]]
name = "metal" name = "metal"
version = "0.27.0" version = "0.27.0"
...@@ -3632,6 +3653,69 @@ dependencies = [ ...@@ -3632,6 +3653,69 @@ dependencies = [
"reborrow", "reborrow",
] ]
[[package]]
name = "pyo3"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872"
dependencies = [
"cfg-if 1.0.0",
"indoc",
"libc",
"memoffset",
"once_cell",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"unindent",
]
[[package]]
name = "pyo3-build-config"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb"
dependencies = [
"once_cell",
"target-lexicon",
]
[[package]]
name = "pyo3-ffi"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d"
dependencies = [
"libc",
"pyo3-build-config",
]
[[package]]
name = "pyo3-macros"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.98",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.98",
]
[[package]] [[package]]
name = "qoi" name = "qoi"
version = "0.4.1" version = "0.4.1"
...@@ -4171,6 +4255,19 @@ dependencies = [ ...@@ -4171,6 +4255,19 @@ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]]
name = "serde-pickle"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b641fdc8bcf2781ee78b30c599700d64ad4f412976143e4c5d0b9df906bb4843"
dependencies = [
"byteorder",
"iter-read",
"num-bigint",
"num-traits",
"serde",
]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.218" version = "1.0.218"
...@@ -5181,6 +5278,7 @@ dependencies = [ ...@@ -5181,6 +5278,7 @@ dependencies = [
"anyhow", "anyhow",
"async-stream", "async-stream",
"async-trait", "async-trait",
"async_zmq",
"axum 0.8.1", "axum 0.8.1",
"blake3", "blake3",
"bs62", "bs62",
...@@ -5193,13 +5291,16 @@ dependencies = [ ...@@ -5193,13 +5291,16 @@ dependencies = [
"galil-seiferas", "galil-seiferas",
"indexmap 2.7.1", "indexmap 2.7.1",
"itertools 0.14.0", "itertools 0.14.0",
"libc",
"minijinja", "minijinja",
"minijinja-contrib", "minijinja-contrib",
"mistralrs", "mistralrs",
"prometheus", "prometheus",
"pyo3",
"regex", "regex",
"semver", "semver",
"serde", "serde",
"serde-pickle",
"serde_json", "serde_json",
"strum 0.27.1", "strum 0.27.1",
"thiserror 2.0.11", "thiserror 2.0.11",
...@@ -5338,6 +5439,12 @@ version = "0.1.1" ...@@ -5338,6 +5439,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "unindent"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
[[package]] [[package]]
name = "unsafe-libyaml" name = "unsafe-libyaml"
version = "0.2.11" version = "0.2.11"
......
...@@ -22,6 +22,7 @@ homepage = "https://github.com/triton-inference-server/triton_distributed" ...@@ -22,6 +22,7 @@ homepage = "https://github.com/triton-inference-server/triton_distributed"
[features] [features]
mistralrs = ["triton-distributed-llm/mistralrs"] mistralrs = ["triton-distributed-llm/mistralrs"]
sglang = ["triton-distributed-llm/sglang", "dep:netlink-packet-route", "dep:rtnetlink"]
cuda = ["triton-distributed-llm/cuda"] cuda = ["triton-distributed-llm/cuda"]
metal = ["triton-distributed-llm/metal"] metal = ["triton-distributed-llm/metal"]
......
...@@ -62,3 +62,12 @@ The `ns/backend/mistralrs` are purely symbolic, pick anything as long as it has ...@@ -62,3 +62,12 @@ The `ns/backend/mistralrs` are purely symbolic, pick anything as long as it has
Run `tio --help` for more options. Run `tio --help` for more options.
## sglang
```
uv venv
source .venv/bin/activate
uv pip install pip
uv pip install sgl-kernel --force-reinstall --no-deps
uv pip install "sglang[all]==0.4.2" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/
```
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
// limitations under the License. // limitations under the License.
use std::path::PathBuf; use std::path::PathBuf;
use std::str::FromStr;
use triton_distributed_llm::{ use triton_distributed_llm::{
backend::ExecutionContext, backend::ExecutionContext,
...@@ -29,6 +30,8 @@ use triton_distributed_llm::{ ...@@ -29,6 +30,8 @@ use triton_distributed_llm::{
use triton_distributed_runtime::{component::Client, protocols::Endpoint, DistributedRuntime}; use triton_distributed_runtime::{component::Client, protocols::Endpoint, DistributedRuntime};
mod input; mod input;
#[cfg(feature = "sglang")]
mod net;
mod opt; mod opt;
mod output; mod output;
pub use opt::{Input, Output}; pub use opt::{Input, Output};
...@@ -58,6 +61,53 @@ pub struct Flags { ...@@ -58,6 +61,53 @@ pub struct Flags {
/// The name of the model we are serving /// The name of the model we are serving
#[arg(long)] #[arg(long)]
pub model_name: Option<String>, pub model_name: Option<String>,
/// sglang only
///
/// How many GPUs to use at once, total across all nodes.
/// This must divide by num_nodes, and each node must use the same number of GPUs.
#[arg(long, default_value = "1", value_parser = clap::value_parser!(u32).range(1..256))]
pub tensor_parallel_size: u32,
/// sglang only
///
/// Use GPUs from this ID upwards.
/// If your machine has four GPUs but the first two (0 and 1) are in use,
/// pass --base-gpu-id 2 to use the third GPU (and up, if tensor_parallel_size > 1)
#[arg(long, default_value = "0", value_parser = clap::value_parser!(u32).range(0..256))]
pub base_gpu_id: u32,
/// sglang only
///
/// How many nodes/hosts to use
#[arg(long, default_value = "1", value_parser = clap::value_parser!(u32).range(1..256))]
pub num_nodes: u32,
/// sglang only
///
/// This node's unique ID, running from 0 to num_nodes - 1.
#[arg(long, default_value = "0", value_parser = clap::value_parser!(u32).range(0..255))]
pub node_rank: u32,
/// sglang only
///
/// The Torch Distributed init method address, in format <host>:<port>.
/// It becomes "tcp://<host>:<port>" when given to torch.distributed.init_process_group.
/// This expects to use the nccl backend (transparently to us here).
/// All nodes must use the same dist_init_addr, which is node_rank == 0's address.
#[arg(long)]
pub dist_init_addr: Option<String>,
/// Internal use only.
/// Start the sglang Python sub-process.
/// The value is three comma-separated numbers:
/// - the fd of the write end of a pipe where sglang will signal that it's ready.
/// - the workers' rank (globally unique)
/// - the GPU to use (locally unique)
/// The node rank (0 for first host, 1 for second host, etc) is passed separately via --node-rank.
#[arg(long)]
#[clap(hide = true, value_parser = parse_sglang_flags)]
pub internal_sglang_process: Option<SgLangFlags>,
} }
pub enum EngineConfig { pub enum EngineConfig {
...@@ -79,11 +129,36 @@ pub enum EngineConfig { ...@@ -79,11 +129,36 @@ pub enum EngineConfig {
}, },
} }
/// Settings handed to an internal sglang sub-process through the hidden
/// `internal_sglang_process` flag.
#[derive(Debug, Clone, Copy)]
pub struct SgLangFlags {
    /// File descriptor of the write end of the pipe on which sglang signals
    /// that it is ready.
    pub pipe_fd: u32,
    /// Rank of this worker (globally unique).
    pub tp_rank: u32,
    /// GPU this worker should use (locally unique).
    pub gpu_id: u32,
}

/// Parses a `<pipe_fd>,<tp_rank>,<gpu_id>` string into [`SgLangFlags`].
///
/// # Errors
///
/// Returns the integer-parse error text for the first comma-separated field
/// that is not a valid `u32`. If every field parses but there are not exactly
/// three of them, returns `"Need exactly 3 numbers"`.
fn parse_sglang_flags(s: &str) -> Result<SgLangFlags, String> {
    // Parse every field before looking at the count, so that a malformed
    // number is reported in preference to a wrong field count.
    let mut fields = Vec::new();
    for part in s.split(',') {
        let value = u32::from_str(part).map_err(|err| err.to_string())?;
        fields.push(value);
    }

    match fields.as_slice() {
        &[pipe_fd, tp_rank, gpu_id] => Ok(SgLangFlags {
            pipe_fd,
            tp_rank,
            gpu_id,
        }),
        _ => Err("Need exactly 3 numbers".into()),
    }
}
pub async fn run( pub async fn run(
runtime: triton_distributed_runtime::Runtime, runtime: triton_distributed_runtime::Runtime,
in_opt: Input, in_opt: Input,
out_opt: Output, out_opt: Output,
flags: Flags, flags: Flags,
#[allow(unused_variables)] zmq_socket_prefix: Option<String>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let cancel_token = runtime.primary_token(); let cancel_token = runtime.primary_token();
...@@ -109,6 +184,9 @@ pub async fn run( ...@@ -109,6 +184,9 @@ pub async fn run(
Some(_) | None => None, Some(_) | None => None,
}; };
#[cfg(feature = "sglang")]
let mut extra = None; // sglang sub-process
// Create the engine matching `out` // Create the engine matching `out`
let engine_config = match out_opt { let engine_config = match out_opt {
Output::EchoFull => { Output::EchoFull => {
...@@ -174,6 +252,49 @@ pub async fn run( ...@@ -174,6 +252,49 @@ pub async fn run(
.await?, .await?,
} }
} }
#[cfg(feature = "sglang")]
Output::SgLang => {
use triton_distributed_llm::engines::sglang;
let Some(model_path) = model_path else {
anyhow::bail!("out=sglang requires flag --model-path=<full-path-to-model-dir>");
};
if !model_path.is_dir() {
anyhow::bail!("`--model-path should point at a HuggingFace repo checkout");
}
// Safety: Earlier we build maybe_card from model_path, which we checked right above
let card = maybe_card.clone().unwrap();
let Some(sock_prefix) = zmq_socket_prefix else {
anyhow::bail!("sglang requires zmq_socket_prefix");
};
let node_conf = sglang::MultiNodeConfig {
num_nodes: flags.num_nodes,
node_rank: flags.node_rank,
dist_init_addr: flags.dist_init_addr,
};
if node_conf.num_nodes > 1 {
if let Ok(Some(if_name)) = net::get_primary_interface().await {
tracing::info!("If you see 'gloo' errors from sglang try setting these environment variables:");
tracing::info!("export GLOO_SOCKET_IFNAME={if_name}");
tracing::info!("export NCCL_SOCKET_IFNAME={if_name}");
}
}
let (engine, sglang_process) = sglang::make_engine(
cancel_token.clone(),
&model_path,
&sock_prefix,
node_conf,
flags.tensor_parallel_size,
flags.base_gpu_id,
)
.await?;
extra = Some(sglang_process);
EngineConfig::StaticCore {
service_name: card.service_name.clone(),
engine,
card: Box::new(card),
}
}
}; };
match in_opt { match in_opt {
...@@ -186,6 +307,19 @@ pub async fn run( ...@@ -186,6 +307,19 @@ pub async fn run(
Input::Endpoint(path) => { Input::Endpoint(path) => {
crate::input::endpoint::run(runtime.clone(), path, engine_config).await?; crate::input::endpoint::run(runtime.clone(), path, engine_config).await?;
} }
Input::None => {
// Multi-node setup. The engine sub-process has been started and is talking
// to its node_rank 0 controller. We do nothing.
// TODO: Acquire an etcd lease, we are running
cancel_token.cancelled().await;
}
}
#[cfg(feature = "sglang")]
// Allow engines to ask main thread to wait on an extra future.
// sglang uses this to shut down sub-process
if let Some(extra) = extra {
extra.await?;
} }
Ok(()) Ok(())
......
...@@ -39,11 +39,52 @@ const DEFAULT_OUT: Output = Output::MistralRs; ...@@ -39,11 +39,52 @@ const DEFAULT_OUT: Output = Output::MistralRs;
#[cfg(not(feature = "mistralrs"))] #[cfg(not(feature = "mistralrs"))]
const DEFAULT_OUT: Output = Output::EchoFull; const DEFAULT_OUT: Output = Output::EchoFull;
const USAGE: &str = "USAGE: tio in=[http|text] out=[mistralrs|echo_full] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>]"; const ZMQ_SOCKET_PREFIX: &str = "tio";
const USAGE: &str = "USAGE: tio in=[http|text|tdr://<path>|none] out=[mistralrs|sglang|echo_full|echo_core] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--tensor-parallel-size=1] [--num-nodes=1] [--node-rank=0] [--dist-init-addr=127.0.0.1:9876] [--base-gpu-id=0]";
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
logging::init(); logging::init();
// Call sub-processes before starting the Runtime machinery
// For anything except sub-process starting try_parse_from will error.
if let Ok(flags) = tio::Flags::try_parse_from(env::args()) {
#[allow(unused_variables)]
if let Some(sglang_flags) = flags.internal_sglang_process {
let Some(model_path) = flags.model_path_flag.as_ref() else {
anyhow::bail!("sglang subprocess requires --model-path");
};
if !model_path.is_dir() {
anyhow::bail!("sglang subprocess requires model path to be a directory containing the safetensors files");
}
if cfg!(feature = "sglang") {
#[cfg(feature = "sglang")]
{
use triton_distributed_llm::engines::sglang;
let gpu_config = sglang::MultiGPUConfig {
tp_size: flags.tensor_parallel_size,
tp_rank: sglang_flags.tp_rank,
gpu_id: sglang_flags.gpu_id,
};
let node_config = sglang::MultiNodeConfig {
num_nodes: flags.num_nodes,
node_rank: flags.node_rank,
dist_init_addr: flags.dist_init_addr,
};
return sglang::run_subprocess(
ZMQ_SOCKET_PREFIX,
model_path,
sglang_flags.pipe_fd as std::os::fd::RawFd,
node_config,
gpu_config,
);
}
} else {
panic!("Rebuild with --features=sglang");
}
}
}
// max_worker_threads and max_blocking_threads from env vars or config file. // max_worker_threads and max_blocking_threads from env vars or config file.
let rt_config = triton_distributed_runtime::RuntimeConfig::from_settings()?; let rt_config = triton_distributed_runtime::RuntimeConfig::from_settings()?;
...@@ -103,5 +144,12 @@ async fn tio_wrapper(runtime: triton_distributed_runtime::Runtime) -> anyhow::Re ...@@ -103,5 +144,12 @@ async fn tio_wrapper(runtime: triton_distributed_runtime::Runtime) -> anyhow::Re
.chain(env::args().skip(non_flag_params)), .chain(env::args().skip(non_flag_params)),
)?; )?;
tio::run(runtime, in_opt, out_opt, flags).await tio::run(
runtime,
in_opt,
out_opt,
flags,
Some(ZMQ_SOCKET_PREFIX.to_string()),
)
.await
} }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use futures_util::TryStreamExt;
use netlink_packet_route::address::AddressAttribute;
use netlink_packet_route::link::LinkLayerType;
use netlink_packet_route::link::State as LinkState;
use netlink_packet_route::link::{LinkAttribute, LinkMessage};
use netlink_packet_route::AddressFamily;
use std::collections::HashSet;
use std::collections::VecDeque;
use std::{collections::HashMap, error::Error};
/// Returns the name of the interface that looks like this host's primary
/// network link, or `None` when no interface qualifies.
///
/// A candidate must be reported by `get_ipv4_interface_links`, be an Ethernet
/// link, be operationally up, have carrier, and have a name starting with
/// `e`.
///
/// When several interfaces qualify, the lexicographically smallest name is
/// returned. The candidates come out of a `HashMap`, whose iteration order is
/// unspecified, so without the sort the chosen interface could differ from
/// one call to the next.
///
/// # Errors
///
/// Propagates the `LinkDataError` produced while querying netlink.
pub async fn get_primary_interface() -> Result<Option<String>, LinkDataError> {
    let mut candidates: VecDeque<String> = get_ipv4_interface_links()
        .await?
        .into_iter()
        .filter(|(name, link)| {
            name.starts_with('e') && link.is_ethernet() && link.link_is_up() && link.has_carrier()
        })
        .map(|(name, _)| name)
        .collect();
    // Sort so the front element does not depend on HashMap iteration order.
    candidates.make_contiguous().sort_unstable();
    Ok(candidates.pop_front())
}
// Link-level facts about a single network interface, built from a netlink
// `LinkMessage` by the `From<LinkMessage>` impl.
#[derive(Clone, Debug)]
// `state` is an Option<T> because the netlink protocol allows the attribute
// to be absent (even though we have no reason to believe it'd ever actually
// be missing).
struct InterfaceLinkData {
// Link layer type taken from the link message header.
link_type: LinkLayerType,
// Operational state, when the message carried an `OperState` attribute.
state: Option<LinkState>,
// True only when the message carried a `Carrier(1)` attribute.
has_carrier: bool,
}
impl InterfaceLinkData {
    /// Whether the interface reported an operational state of `Up`.
    /// An interface with no reported state is treated as not up.
    pub fn link_is_up(&self) -> bool {
        matches!(self.state, Some(LinkState::Up))
    }

    /// Whether the link layer type is Ethernet.
    pub fn is_ethernet(&self) -> bool {
        let kind = &self.link_type;
        matches!(kind, LinkLayerType::Ether)
    }

    /// Whether the interface reported carrier.
    pub fn has_carrier(&self) -> bool {
        let Self { has_carrier, .. } = self;
        *has_carrier
    }
}
impl From<LinkMessage> for InterfaceLinkData {
    /// Extracts the link layer type from the message header, and the
    /// operational state and carrier flag from the message attributes.
    fn from(link_message: LinkMessage) -> Self {
        let attributes = &link_message.attributes;

        // First `OperState` attribute, if the message has one.
        let state = attributes.iter().find_map(|attribute| {
            if let LinkAttribute::OperState(oper_state) = attribute {
                Some(*oper_state)
            } else {
                None
            }
        });

        // Carrier is considered present only when some attribute is exactly
        // `Carrier(1)`; any other value, or no attribute at all, means none.
        let has_carrier = attributes
            .iter()
            .any(|attribute| matches!(attribute, LinkAttribute::Carrier(1)));

        Self {
            link_type: link_message.header.link_layer_type,
            state,
            has_carrier,
        }
    }
}
// Error returned when interface link data could not be retrieved.
#[derive(Debug)]
pub struct LinkDataError {
// The underlying failure; exposed as the error's `source()`.
kind: LinkDataErrorKind,
// Name of the interface involved, included in the Display message when set.
// Both constructors in this file leave it as None.
interface: Option<String>,
}
impl LinkDataError {
    /// Wraps an I/O failure from opening the netlink connection.
    /// No interface name is attached.
    fn connection(connection_error: std::io::Error) -> Self {
        Self {
            kind: LinkDataErrorKind::Connection(connection_error),
            interface: None,
        }
    }

    /// Wraps an rtnetlink failure from a request made over the connection.
    /// No interface name is attached.
    fn communication(communication_error: rtnetlink::Error) -> Self {
        Self {
            kind: LinkDataErrorKind::Communication(communication_error),
            interface: None,
        }
    }
}
impl std::fmt::Display for LinkDataError {
    /// Writes a fixed message, followed by the interface name when one is
    /// recorded. The underlying cause is not included here; it is available
    /// through `source()`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let err_message = "could not get interface link data";
        match self.interface.as_deref() {
            Some(interface) => write!(f, "{err_message} for {interface}"),
            None => write!(f, "{err_message}"),
        }
    }
}
impl Error for LinkDataError {
    /// Returns the wrapped I/O or rtnetlink error as the cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let cause: &(dyn Error + 'static) = match &self.kind {
            LinkDataErrorKind::Connection(io_error) => io_error,
            LinkDataErrorKind::Communication(netlink_error) => netlink_error,
        };
        Some(cause)
    }
}
// The underlying cause carried by a `LinkDataError`.
#[derive(Debug)]
pub enum LinkDataErrorKind {
// Creating the netlink connection failed.
Connection(std::io::Error),
// A request made over the netlink connection failed.
Communication(rtnetlink::Error),
}
// Retrieve the link data (link type, operational state, carrier) for every
// interface that has an IPv4 address label, and return it as a HashMap keyed
// by interface name. This uses the same netlink interface under the hood as
// `ip link show` / `ip addr show`.
//
// Errors: a failure to open the netlink connection is returned as a
// `Connection` error; a failure while streaming addresses or links is
// returned as a `Communication` error.
async fn get_ipv4_interface_links() -> Result<HashMap<String, InterfaceLinkData>, LinkDataError> {
let (netlink_connection, rtnetlink_handle, _receiver) =
rtnetlink::new_connection().map_err(LinkDataError::connection)?;
// We have to spawn off the netlink connection because of the architecture
// of `netlink_proto::Connection`, which runs in the background and owns
// the socket. We communicate with it via channel messages, and it will exit
// when both `rtnetlink_handle` and `_receiver` go out of scope.
tokio::spawn(netlink_connection);
// Pass 1: collect the labels of all IPv4 (`AddressFamily::Inet`) addresses.
// Address messages of other families, or without a Label attribute, are
// skipped.
let address_handle = rtnetlink_handle.address().get().execute();
let ipv4s: HashSet<String> = address_handle
.try_filter_map(|addr_message| async move {
if matches!(addr_message.header.family, AddressFamily::Inet) {
Ok(addr_message
.attributes
.into_iter()
.find(|attr| matches!(attr, AddressAttribute::Label(_)))
.and_then(|x| match x {
AddressAttribute::Label(label) => Some(label),
_ => None,
}))
} else {
Ok(None)
}
})
.try_collect()
.await
.map_err(LinkDataError::communication)?;
// Pass 2: walk all links and keep only those whose interface name appeared
// as a label in pass 1.
let link_handle = rtnetlink_handle.link().get().execute();
link_handle
.try_filter_map(|link_message| async {
let maybe_interface_data = match extract_interface_name(&link_message) {
Some(interface_name) => {
if ipv4s.contains(&interface_name) {
Some((interface_name, InterfaceLinkData::from(link_message)))
} else {
None
}
}
None => {
// A link without a name cannot be keyed in the map; report it on
// stderr and skip it rather than failing the whole query.
let idx = link_message.header.index;
eprintln!(
"Network interface with index {idx} doesn't have a name (no IfName attribute)"
);
None
}
};
Ok(maybe_interface_data)
})
.try_collect()
.await
.map_err(LinkDataError::communication)
}
/// Returns a copy of the first `IfName` attribute found on `link_message`,
/// or `None` when the message carries no such attribute.
fn extract_interface_name(link_message: &LinkMessage) -> Option<String> {
    for attribute in link_message.attributes.iter() {
        if let LinkAttribute::IfName(name) = attribute {
            return Some(name.clone());
        }
    }
    None
}
...@@ -26,6 +26,11 @@ pub enum Input { ...@@ -26,6 +26,11 @@ pub enum Input {
/// Pull requests from a namespace/component/endpoint path. /// Pull requests from a namespace/component/endpoint path.
Endpoint(String), Endpoint(String),
/// Start the engine but don't provide any way to talk to it.
/// For multi-node sglang, where the engine connects directly
/// to the co-ordinator via torch distributed / nccl.
None,
} }
impl TryFrom<&str> for Input { impl TryFrom<&str> for Input {
...@@ -35,6 +40,7 @@ impl TryFrom<&str> for Input { ...@@ -35,6 +40,7 @@ impl TryFrom<&str> for Input {
match s { match s {
"http" => Ok(Input::Http), "http" => Ok(Input::Http),
"text" => Ok(Input::Text), "text" => Ok(Input::Text),
"none" => Ok(Input::None),
endpoint_path if endpoint_path.starts_with(ENDPOINT_SCHEME) => { endpoint_path if endpoint_path.starts_with(ENDPOINT_SCHEME) => {
let path = endpoint_path.strip_prefix(ENDPOINT_SCHEME).unwrap(); let path = endpoint_path.strip_prefix(ENDPOINT_SCHEME).unwrap();
Ok(Input::Endpoint(path.to_string())) Ok(Input::Endpoint(path.to_string()))
...@@ -50,6 +56,7 @@ impl fmt::Display for Input { ...@@ -50,6 +56,7 @@ impl fmt::Display for Input {
Input::Http => "http", Input::Http => "http",
Input::Text => "text", Input::Text => "text",
Input::Endpoint(path) => path, Input::Endpoint(path) => path,
Input::None => "none",
}; };
write!(f, "{s}") write!(f, "{s}")
} }
...@@ -68,6 +75,10 @@ pub enum Output { ...@@ -68,6 +75,10 @@ pub enum Output {
#[cfg(feature = "mistralrs")] #[cfg(feature = "mistralrs")]
/// Run inference on a model in a GGUF file using mistralrs w/ candle /// Run inference on a model in a GGUF file using mistralrs w/ candle
MistralRs, MistralRs,
#[cfg(feature = "sglang")]
/// Run inference using sglang
SgLang,
} }
impl TryFrom<&str> for Output { impl TryFrom<&str> for Output {
...@@ -78,6 +89,9 @@ impl TryFrom<&str> for Output { ...@@ -78,6 +89,9 @@ impl TryFrom<&str> for Output {
#[cfg(feature = "mistralrs")] #[cfg(feature = "mistralrs")]
"mistralrs" => Ok(Output::MistralRs), "mistralrs" => Ok(Output::MistralRs),
#[cfg(feature = "sglang")]
"sglang" => Ok(Output::SgLang),
"echo_full" => Ok(Output::EchoFull), "echo_full" => Ok(Output::EchoFull),
"echo_core" => Ok(Output::EchoCore), "echo_core" => Ok(Output::EchoCore),
...@@ -97,6 +111,9 @@ impl fmt::Display for Output { ...@@ -97,6 +111,9 @@ impl fmt::Display for Output {
#[cfg(feature = "mistralrs")] #[cfg(feature = "mistralrs")]
Output::MistralRs => "mistralrs", Output::MistralRs => "mistralrs",
#[cfg(feature = "sglang")]
Output::SgLang => "sglang",
Output::EchoFull => "echo_full", Output::EchoFull => "echo_full",
Output::EchoCore => "echo_core", Output::EchoCore => "echo_core",
......
...@@ -1581,6 +1581,12 @@ dependencies = [ ...@@ -1581,6 +1581,12 @@ dependencies = [
"web-time", "web-time",
] ]
[[package]]
name = "indoc"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5"
[[package]] [[package]]
name = "inlinable_string" name = "inlinable_string"
version = "0.1.15" version = "0.1.15"
...@@ -1602,6 +1608,12 @@ version = "1.70.1" ...@@ -1602,6 +1608,12 @@ version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "iter-read"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071ed4cc1afd86650602c7b11aa2e1ce30762a1c27193201cb5cee9c6ebb1294"
[[package]] [[package]]
name = "itertools" name = "itertools"
version = "0.11.0" version = "0.11.0"
...@@ -1815,6 +1827,15 @@ version = "0.3.3" ...@@ -1815,6 +1827,15 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b" checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b"
[[package]]
name = "memoffset"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
dependencies = [
"autocfg",
]
[[package]] [[package]]
name = "mime" name = "mime"
version = "0.3.17" version = "0.3.17"
...@@ -2461,6 +2482,69 @@ version = "2.28.0" ...@@ -2461,6 +2482,69 @@ version = "2.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94" checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94"
[[package]]
name = "pyo3"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872"
dependencies = [
"cfg-if 1.0.0",
"indoc",
"libc",
"memoffset",
"once_cell",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"unindent",
]
[[package]]
name = "pyo3-build-config"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb"
dependencies = [
"once_cell",
"target-lexicon",
]
[[package]]
name = "pyo3-ffi"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d"
dependencies = [
"libc",
"pyo3-build-config",
]
[[package]]
name = "pyo3-macros"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.96",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.96",
]
[[package]] [[package]]
name = "quote" name = "quote"
version = "1.0.38" version = "1.0.38"
...@@ -2800,6 +2884,19 @@ dependencies = [ ...@@ -2800,6 +2884,19 @@ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]]
name = "serde-pickle"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b641fdc8bcf2781ee78b30c599700d64ad4f412976143e4c5d0b9df906bb4843"
dependencies = [
"byteorder",
"iter-read",
"num-bigint",
"num-traits",
"serde",
]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.217" version = "1.0.217"
...@@ -3554,12 +3651,15 @@ dependencies = [ ...@@ -3554,12 +3651,15 @@ dependencies = [
"galil-seiferas", "galil-seiferas",
"indexmap 2.7.0", "indexmap 2.7.0",
"itertools 0.14.0", "itertools 0.14.0",
"libc",
"minijinja", "minijinja",
"minijinja-contrib", "minijinja-contrib",
"prometheus", "prometheus",
"pyo3",
"regex", "regex",
"semver", "semver",
"serde", "serde",
"serde-pickle",
"serde_json", "serde_json",
"strum", "strum",
"thiserror 2.0.11", "thiserror 2.0.11",
...@@ -3698,6 +3798,12 @@ version = "0.1.1" ...@@ -3698,6 +3798,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "unindent"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
[[package]] [[package]]
name = "untrusted" name = "untrusted"
version = "0.9.0" version = "0.9.0"
......
...@@ -1634,6 +1634,12 @@ version = "1.70.1" ...@@ -1634,6 +1634,12 @@ version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "iter-read"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071ed4cc1afd86650602c7b11aa2e1ce30762a1c27193201cb5cee9c6ebb1294"
[[package]] [[package]]
name = "itertools" name = "itertools"
version = "0.11.0" version = "0.11.0"
...@@ -2927,6 +2933,19 @@ dependencies = [ ...@@ -2927,6 +2933,19 @@ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]]
name = "serde-pickle"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b641fdc8bcf2781ee78b30c599700d64ad4f412976143e4c5d0b9df906bb4843"
dependencies = [
"byteorder",
"iter-read",
"num-bigint",
"num-traits",
"serde",
]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.217" version = "1.0.217"
...@@ -3681,12 +3700,15 @@ dependencies = [ ...@@ -3681,12 +3700,15 @@ dependencies = [
"galil-seiferas", "galil-seiferas",
"indexmap 2.7.1", "indexmap 2.7.1",
"itertools 0.14.0", "itertools 0.14.0",
"libc",
"minijinja", "minijinja",
"minijinja-contrib", "minijinja-contrib",
"prometheus", "prometheus",
"pyo3",
"regex", "regex",
"semver", "semver",
"serde", "serde",
"serde-pickle",
"serde_json", "serde_json",
"strum", "strum",
"thiserror 2.0.11", "thiserror 2.0.11",
......
...@@ -2374,6 +2374,12 @@ dependencies = [ ...@@ -2374,6 +2374,12 @@ dependencies = [
"web-time", "web-time",
] ]
[[package]]
name = "indoc"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5"
[[package]] [[package]]
name = "inlinable_string" name = "inlinable_string"
version = "0.1.15" version = "0.1.15"
...@@ -2429,6 +2435,12 @@ version = "1.70.1" ...@@ -2429,6 +2435,12 @@ version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "iter-read"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071ed4cc1afd86650602c7b11aa2e1ce30762a1c27193201cb5cee9c6ebb1294"
[[package]] [[package]]
name = "itertools" name = "itertools"
version = "0.10.5" version = "0.10.5"
...@@ -2715,6 +2727,15 @@ version = "0.3.3" ...@@ -2715,6 +2727,15 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b" checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b"
[[package]]
name = "memoffset"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a"
dependencies = [
"autocfg",
]
[[package]] [[package]]
name = "metal" name = "metal"
version = "0.27.0" version = "0.27.0"
...@@ -3755,6 +3776,69 @@ dependencies = [ ...@@ -3755,6 +3776,69 @@ dependencies = [
"reborrow", "reborrow",
] ]
[[package]]
name = "pyo3"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872"
dependencies = [
"cfg-if 1.0.0",
"indoc",
"libc",
"memoffset",
"once_cell",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
"pyo3-macros",
"unindent",
]
[[package]]
name = "pyo3-build-config"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb"
dependencies = [
"once_cell",
"target-lexicon",
]
[[package]]
name = "pyo3-ffi"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d"
dependencies = [
"libc",
"pyo3-build-config",
]
[[package]]
name = "pyo3-macros"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da"
dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.98",
]
[[package]]
name = "pyo3-macros-backend"
version = "0.23.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.98",
]
[[package]] [[package]]
name = "qoi" name = "qoi"
version = "0.4.1" version = "0.4.1"
...@@ -4427,6 +4511,19 @@ dependencies = [ ...@@ -4427,6 +4511,19 @@ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]]
name = "serde-pickle"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b641fdc8bcf2781ee78b30c599700d64ad4f412976143e4c5d0b9df906bb4843"
dependencies = [
"byteorder",
"iter-read",
"num-bigint",
"num-traits",
"serde",
]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.217" version = "1.0.217"
...@@ -5434,6 +5531,7 @@ dependencies = [ ...@@ -5434,6 +5531,7 @@ dependencies = [
"anyhow", "anyhow",
"async-stream", "async-stream",
"async-trait", "async-trait",
"async_zmq",
"axum 0.8.1", "axum 0.8.1",
"blake3", "blake3",
"bs62", "bs62",
...@@ -5448,17 +5546,20 @@ dependencies = [ ...@@ -5448,17 +5546,20 @@ dependencies = [
"indexmap 2.7.1", "indexmap 2.7.1",
"insta", "insta",
"itertools 0.14.0", "itertools 0.14.0",
"libc",
"minijinja", "minijinja",
"minijinja-contrib", "minijinja-contrib",
"mistralrs", "mistralrs",
"prometheus", "prometheus",
"proptest", "proptest",
"pyo3",
"regex", "regex",
"reqwest", "reqwest",
"rstest", "rstest",
"semver", "semver",
"sentencepiece", "sentencepiece",
"serde", "serde",
"serde-pickle",
"serde_json", "serde_json",
"strum 0.27.1", "strum 0.27.1",
"tempfile", "tempfile",
...@@ -5610,6 +5711,12 @@ version = "0.1.1" ...@@ -5610,6 +5711,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
[[package]]
name = "unindent"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce"
[[package]] [[package]]
name = "unsafe-libyaml" name = "unsafe-libyaml"
version = "0.2.11" version = "0.2.11"
......
...@@ -34,6 +34,7 @@ mistralrs = ["dep:mistralrs"] ...@@ -34,6 +34,7 @@ mistralrs = ["dep:mistralrs"]
metal = ["mistralrs/metal"] metal = ["mistralrs/metal"]
cuda = ["mistralrs/cuda"] cuda = ["mistralrs/cuda"]
sentencepiece = ["dep:sentencepiece"] sentencepiece = ["dep:sentencepiece"]
sglang = ["dep:async_zmq"]
[workspace.dependencies] [workspace.dependencies]
# local or crates.io # local or crates.io
...@@ -81,6 +82,7 @@ xxhash-rust = { workspace = true } ...@@ -81,6 +82,7 @@ xxhash-rust = { workspace = true }
strum = { workspace = true } strum = { workspace = true }
blake3 = "1" blake3 = "1"
regex = "1"
# protocols # protocols
chrono = { version = "0.4", default-features = false, features = [ chrono = { version = "0.4", default-features = false, features = [
...@@ -91,7 +93,6 @@ chrono = { version = "0.4", default-features = false, features = [ ...@@ -91,7 +93,6 @@ chrono = { version = "0.4", default-features = false, features = [
"serde", "serde",
] } ] }
serde_json = { version = "1" } serde_json = { version = "1" }
regex = "1"
unicode-segmentation = "1.12" unicode-segmentation = "1.12"
# http-service # http-service
...@@ -103,6 +104,17 @@ either = { version = "1.13" } ...@@ -103,6 +104,17 @@ either = { version = "1.13" }
indexmap = { version = "2.6" } indexmap = { version = "2.6" }
mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", rev = "5e689c9", optional = true } mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", rev = "5e689c9", optional = true }
# sglang
async_zmq = { version = "0.4.0", optional = true }
libc = "0.2"
pyo3 = { version = "0.23.3", default-features = false, features = [
"macros",
"experimental-async",
"experimental-inspect",
"py-clone",
] }
serde-pickle = "1.2.0"
# tokenizers # tokenizers
tokenizers = { version = "0.21.0", default-features = false, features = [ tokenizers = { version = "0.21.0", default-features = false, features = [
"onig", "onig",
......
...@@ -15,3 +15,6 @@ ...@@ -15,3 +15,6 @@
#[cfg(feature = "mistralrs")] #[cfg(feature = "mistralrs")]
pub mod mistralrs; pub mod mistralrs;
#[cfg(feature = "sglang")]
pub mod sglang;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::Path;
use std::sync::Arc;
use crate::backend::ExecutionContext;
use triton_distributed_runtime::pipeline::error as pipeline_error;
use triton_distributed_runtime::CancellationToken;
mod worker;
mod engine;
use engine::SgLangEngine;
mod subprocess;
pub use subprocess::run_subprocess;
/// Creates an sglang engine and returns it together with the join handle of
/// the sglang worker task.
///
/// # Arguments
///
/// * `cancel_token` - Token handed to the engine.
/// * `model_path` - Full path to the model directory.
/// * `sock_code` - Unique string used to name the zmq sockets.
/// * `node_conf` - Multi node settings:
///   - `num_nodes`: how many nodes/hosts we are using,
///   - `node_rank`: unique consecutive int starting at 0 identifying this node,
///   - `dist_init_addr`: Torch Distributed init method `addr:port`.
/// * `tensor_parallel_size` - How many GPUs to use.
/// * `base_gpu_id` - The base GPU ID to start allocating GPUs from.
///
/// # Errors
///
/// Propagates any error returned by `SgLangEngine::new`.
pub async fn make_engine(
    cancel_token: CancellationToken,
    model_path: &Path,
    sock_code: &str,
    node_conf: MultiNodeConfig,
    tensor_parallel_size: u32,
    base_gpu_id: u32,
) -> pipeline_error::Result<(ExecutionContext, tokio::task::JoinHandle<()>)> {
    let mut sglang_engine = SgLangEngine::new(
        cancel_token,
        sock_code,
        model_path,
        node_conf,
        tensor_parallel_size,
        base_gpu_id,
    )
    .await?;

    // The handle must be taken while we still hold the engine mutably, i.e.
    // before it is moved into the `Arc`.
    let worker_handle = sglang_engine.take_sglang_worker_handle();
    let execution_context: ExecutionContext = Arc::new(sglang_engine);
    Ok((execution_context, worker_handle))
}
/// Tensor-parallel / GPU placement settings for one sglang process.
#[derive(Debug, Clone, Copy)]
pub struct MultiGPUConfig {
    /// Number of GPUs in use, which is also the number of processes.
    pub tp_size: u32,
    /// Tensor-parallel rank. Has to be unique across all nodes and GPUs.
    pub tp_rank: u32,
    /// ID of the GPU to run on. In a single-node setup this equals `tp_rank`.
    pub gpu_id: u32,
}

impl Default for MultiGPUConfig {
    /// Single-GPU configuration: one process, rank 0, running on GPU 0.
    fn default() -> Self {
        Self {
            tp_size: 1,
            tp_rank: 0,
            gpu_id: 0,
        }
    }
}
/// Settings describing how this host takes part in a multi-node deployment.
#[derive(Debug, Clone)]
pub struct MultiNodeConfig {
    /// Number of nodes / hosts in use.
    pub num_nodes: u32,
    /// Unique consecutive integer identifying this node.
    pub node_rank: u32,
    /// `host:port` of the head / control node, if any.
    pub dist_init_addr: Option<String>,
}

impl Default for MultiNodeConfig {
    /// Single-node configuration: one node, rank 0, no head node address.
    fn default() -> Self {
        Self {
            num_nodes: 1,
            node_rank: 0,
            dist_init_addr: None,
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::Path;
use async_stream::stream;
use async_trait::async_trait;
use crate::protocols::common::llm_backend::{BackendInput, LLMEngineOutput};
use triton_distributed_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use triton_distributed_runtime::pipeline::{Error, ManyOut, SingleIn};
use triton_distributed_runtime::protocols::annotated::Annotated;
use triton_distributed_runtime::runtime::CancellationToken;
use crate::engines::sglang::MultiNodeConfig;
/// Engine that serves generation requests by enqueueing them on an sglang
/// worker and streaming the worker's responses back to the caller.
pub struct SgLangEngine {
// Observed while streaming responses in `generate`; once cancelled, the
// response stream stops.
cancel_token: CancellationToken,
// The worker that requests are enqueued on. It also holds the sglang join
// handle returned by `take_sglang_worker_handle`.
worker: super::worker::SgLangWorker,
}
impl SgLangEngine {
    /// Starts the sglang worker and wraps it in an engine.
    ///
    /// All arguments are forwarded to `super::worker::start`; a clone of
    /// `cancel_token` goes to the worker while the engine keeps the original.
    ///
    /// # Errors
    ///
    /// Returns whatever error `super::worker::start` reports.
    pub async fn new(
        cancel_token: CancellationToken,
        sock_code: &str,
        model_path: &Path,
        node_conf: MultiNodeConfig,
        tensor_parallel_size: u32,
        base_gpu_id: u32,
    ) -> anyhow::Result<Self> {
        let worker = super::worker::start(
            cancel_token.clone(),
            sock_code,
            model_path,
            node_conf,
            tensor_parallel_size,
            base_gpu_id,
        )
        .await?;

        Ok(Self {
            cancel_token,
            worker,
        })
    }

    /// Hands out the worker's sglang join handle by delegating to
    /// `SgLangWorker::take_sglang_handle`.
    pub fn take_sglang_worker_handle(&mut self) -> tokio::task::JoinHandle<()> {
        self.worker.take_sglang_handle()
    }
}
#[async_trait]
impl AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<LLMEngineOutput>>, Error>
    for SgLangEngine
{
    /// Enqueues `request` on the sglang worker and returns a stream of the
    /// responses the worker sends back.
    ///
    /// The stream finishes when the worker closes the response channel or when
    /// the engine's cancellation token is cancelled, whichever happens first.
    ///
    /// # Errors
    ///
    /// Fails if the request cannot be enqueued on the worker.
    async fn generate(
        &self,
        request: SingleIn<BackendInput>,
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
        let (request, context) = request.into_parts();
        let ctx = context.context();
        let request_id = ctx.id().to_string();

        // The worker pushes responses into this channel; we drain it below.
        let (resp_tx, mut resp_rx) = tokio::sync::mpsc::channel(128);
        self.worker
            .enqueue_request(super::worker::WorkRequest {
                request_id: context.id().to_string(),
                request,
                response_channel: resp_tx,
            })
            .await?;

        let cancel_token = self.cancel_token.clone();
        let output = stream! {
            loop {
                // `None` means "stop": either we were cancelled or the
                // worker closed its side of the channel.
                let next = tokio::select! {
                    _ = cancel_token.cancelled() => None,
                    received = resp_rx.recv() => {
                        if received.is_none() {
                            tracing::trace!(request_id, "generate: response channel closed");
                        }
                        received
                    }
                };
                match next {
                    Some(out) => {
                        yield out;
                    }
                    None => break,
                }
            }
        };
        Ok(ResponseStream::new(Box::pin(output), ctx))
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use pyo3::{types::IntoPyDict, Python};
use std::{os::fd::RawFd, path::Path};
/// Python program run by `run_subprocess` to start the sglang scheduler.
///
/// The script expects these string variables to be supplied as locals:
/// `socket_id`, `model_path`, `pipe_fd`, `tp_size_str`, `tp_rank_str`,
/// `gpu_id_str`, `nnodes_str`, `node_rank_str` and `dist_init_addr` (an empty
/// string is treated as "no address").
///
/// It builds sglang `ServerArgs` and `PortArgs`, with the scheduler input and
/// detokenizer sockets placed under `ipc:///tmp/{socket_id}_input_socket` and
/// `ipc:///tmp/{socket_id}_output_socket`. It then wraps `pipe_fd` in a
/// write-only `multiprocessing.connection.Connection` and calls
/// `run_scheduler_process`.
const PY_START_ENGINE: &std::ffi::CStr = cr#"
from multiprocessing.connection import Connection
import signal
import tempfile
import logging
from sglang.srt.server_args import ServerArgs, PortArgs
import sglang as sgl
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.entrypoints.engine import _set_envs_and_config
server_args = ServerArgs(
model_path=f"{model_path}",
enable_metrics = False,
log_level = "debug",
log_requests = True,
tp_size = int(tp_size_str),
# Multi-node
dist_init_addr = dist_init_addr if dist_init_addr != "" else None,
nnodes = int(nnodes_str),
node_rank = int(node_rank_str),
)
logging.basicConfig(
level="DEBUG",
force=True,
datefmt="%Y-%m-%d %H:%M:%S",
format=f"[%(asctime)s] %(message)s",
)
_set_envs_and_config(server_args)
logging.debug(server_args)
ipc_path = f"ipc:///tmp/{socket_id}";
# These must match worker.rs zmq_sockets, which is the other side
port_args = PortArgs(
# we don't use this one so use anything
tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
# Us -> sglang
scheduler_input_ipc_name=f"{ipc_path}_input_socket",
# sglang -> us
detokenizer_ipc_name=f"{ipc_path}_output_socket",
# The port for nccl initialization (torch.dist), which we don't use
nccl_port=9876,
)
# Rank must be globally unique across nodes
tp_rank = int(tp_rank_str)
# See nvidia-smi for GPU IDs, they run 0,1,2,etc.
# In a single-node setup this is the same as rank
gpu_id = int(gpu_id_str)
pipe_fd_int = int(pipe_fd)
writer = Connection(handle=pipe_fd_int, readable=False, writable=True)
run_scheduler_process(server_args, port_args, gpu_id, tp_rank, None, writer)
"#;
/// Start the Python sglang engine that listens on zmq socket
/// This is called by running `nio --internal-sglang-process
/// This does not return until the subprocess exits.
pub fn run_subprocess(
// The prefix to put on the zmq socket names
socket_id: &str,
// Directory containing an HF repo with safetensors files, tokenizer, etc
model_path: &Path,
// The write half of a pipe, where sglang will signal when it's ready
notify_pipe_fd: RawFd,
// Multi node. Usually Default::default
node_config: super::MultiNodeConfig,
// Multi GPU. Usually Default::default
gpu_config: super::MultiGPUConfig,
) -> anyhow::Result<()> {
pyo3::prepare_freethreaded_python(); // or enable feature "auto-initialize"
let dir = model_path.display().to_string();
Python::with_gil(|py| {
let locals = [
("socket_id", socket_id),
("model_path", dir.as_str()),
("pipe_fd", &notify_pipe_fd.to_string()),
// to_string because slice must all be the same type
("tp_size_str", &gpu_config.tp_size.to_string()),
("tp_rank_str", &gpu_config.tp_rank.to_string()),
("gpu_id_str", &gpu_config.gpu_id.to_string()),
("nnodes_str", &node_config.num_nodes.to_string()),
("node_rank_str", &node_config.node_rank.to_string()),
(
"dist_init_addr",
&node_config.dist_init_addr.unwrap_or_default().to_string(),
),
]
.into_py_dict(py)
.unwrap();
if let Err(err) = py.run(PY_START_ENGINE, None, Some(&locals)) {
anyhow::bail!("sglang engine run error: {err}");
}
tracing::info!("sglang subprocess exit");
Ok(())
})
}
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment