"classification/git@developer.sourcefind.cn:OpenDAS/dcnv3.git" did not exist on "43e508f4b990d1316838ec5d323b7a99f24626f4"
Commit e584e96f authored by Graham King, committed by GitHub

feat: llama.cpp engine for tio (#298)

Docs in README
parent b20ef999
@@ -443,6 +443,29 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
[[package]]
name = "bindgen"
version = "0.69.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
dependencies = [
"bitflags 2.8.0",
"cexpr",
"clang-sys",
"itertools 0.11.0",
"lazy_static",
"lazycell",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash 1.1.0",
"shlex",
"syn 2.0.98",
"which",
]
[[package]]
name = "bindgen_cuda"
version = "0.1.5"
@@ -686,6 +709,15 @@ dependencies = [
"shlex",
]
[[package]]
name = "cexpr"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
dependencies = [
"nom",
]
[[package]]
name = "cfg-expr"
version = "0.15.8"
@@ -743,6 +775,17 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "clang-sys"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
dependencies = [
"glob",
"libc",
"libloading",
]
[[package]]
name = "clap"
version = "4.5.30"
@@ -784,6 +827,15 @@ version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6"
[[package]]
name = "cmake"
version = "0.1.54"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
dependencies = [
"cc",
]
[[package]]
name = "color_quant"
version = "1.1.0"
@@ -1398,6 +1450,26 @@ dependencies = [
"syn 2.0.98",
]
[[package]]
name = "enumflags2"
version = "0.7.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba2f4b465f5318854c6f8dd686ede6c0a9dc67d4b1ac241cf0eb51521a309147"
dependencies = [
"enumflags2_derive",
]
[[package]]
name = "enumflags2_derive"
version = "0.7.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc4caf64a58d7a6d65ab00639b046ff54399a39f5f2554728895ace4b297cd79"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.98",
]
[[package]]
name = "equivalent"
version = "1.0.2"
@@ -1513,6 +1585,15 @@ dependencies = [
"version_check",
]
[[package]]
name = "find_cuda_helper"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9f9e65c593dd01ac77daad909ea4ad17f0d6d1776193fc8ea766356177abdad"
dependencies = [
"glob",
]
[[package]]
name = "fixedbitset"
version = "0.5.7"
@@ -1996,6 +2077,15 @@ dependencies = [
"ureq",
]
[[package]]
name = "home"
version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "http"
version = "1.2.0"
@@ -2509,6 +2599,12 @@ version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]]
name = "lebe"
version = "0.5.2"
@@ -2559,6 +2655,33 @@ version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104"
[[package]]
name = "llama-cpp-2"
version = "0.1.102"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a419bb48efa0f8389a82301f1f64e2874568a3fbf6f62f8ddab5324382b82768"
dependencies = [
"enumflags2",
"llama-cpp-sys-2",
"thiserror 1.0.69",
"tracing",
"tracing-core",
]
[[package]]
name = "llama-cpp-sys-2"
version = "0.1.102"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0522f9894e22dd988dd2e34222bda7acba53a0dcce744ca6d8ddce905ba33a4e"
dependencies = [
"bindgen",
"cc",
"cmake",
"find_cuda_helper",
"glob",
"walkdir",
]
[[package]]
name = "llguidance"
version = "0.4.1"
@@ -2571,7 +2694,7 @@ dependencies = [
"instant",
"referencing",
"regex-syntax 0.8.5",
"rustc-hash", "rustc-hash 2.1.1",
"serde",
"serde_json",
"toktrie 0.1.0",
@@ -2889,7 +3012,7 @@ dependencies = [
"regex",
"regex-automata 0.4.9",
"reqwest",
"rustc-hash", "rustc-hash 2.1.1",
"safetensors",
"schemars",
"serde",
@@ -3810,7 +3933,7 @@ dependencies = [
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash", "rustc-hash 2.1.1",
"rustls",
"socket2",
"thiserror 2.0.11",
@@ -3828,7 +3951,7 @@ dependencies = [
"getrandom 0.2.15",
"rand",
"ring",
"rustc-hash", "rustc-hash 2.1.1",
"rustls",
"rustls-pki-types",
"slab",
@@ -4170,6 +4293,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustc-hash"
version = "2.1.1"
@@ -5186,7 +5315,7 @@ dependencies = [
"anyhow",
"bytemuck",
"bytemuck_derive",
"rustc-hash", "rustc-hash 2.1.1",
"serde",
"serde_json",
]
@@ -5211,7 +5340,7 @@ source = "git+https://github.com/microsoft/llguidance?rev=cfef3df97372a7b84d7497
dependencies = [
"anyhow",
"log",
"rustc-hash", "rustc-hash 2.1.1",
"serde",
"serde_json",
"tokenizers",
@@ -5467,6 +5596,7 @@ dependencies = [
"indexmap 2.7.1",
"itertools 0.14.0",
"libc",
"llama-cpp-2",
"minijinja", "minijinja",
"minijinja-contrib", "minijinja-contrib",
"mistralrs", "mistralrs",
...@@ -5938,6 +6068,18 @@ version = "0.1.8" ...@@ -5938,6 +6068,18 @@ version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082"
[[package]]
name = "which"
version = "4.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
dependencies = [
"either",
"home",
"once_cell",
"rustix",
]
[[package]]
name = "winapi"
version = "0.2.8"
...
@@ -24,6 +24,7 @@ license = "Apache-2.0"
[features]
mistralrs = ["triton-distributed-llm/mistralrs"]
sglang = ["triton-distributed-llm/sglang", "dep:netlink-packet-route", "dep:rtnetlink"]
llamacpp = ["triton-distributed-llm/llamacpp"]
cuda = ["triton-distributed-llm/cuda"]
metal = ["triton-distributed-llm/metal"]
...
@@ -48,12 +48,12 @@ curl -d '{"model": "Llama-3.2-1B-Instruct-Q4_K_M", "max_tokens": 2049, "messages
Node 1:
```
tio in=http out=tdr://ns/backend/mistralrs tio in=http out=tdr://llama3B_pool
```
Node 2:
```
tio in=tdr://ns/backend/mistralrs out=mistralrs ~/llm_models/Llama-3.2-3B-Instruct tio in=tdr://llama3B_pool out=mistralrs ~/llm_models/Llama-3.2-3B-Instruct
```
This will use etcd to auto-discover the model and NATS to talk to it. You can run multiple workers on the same endpoint and it will pick one at random each time.
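For example, a minimal sketch (the host split is illustrative): start the same worker command on two machines and both register on `llama3B_pool`; the HTTP node then picks one at random per request.
```
# machine A
tio in=tdr://llama3B_pool out=mistralrs ~/llm_models/Llama-3.2-3B-Instruct
# machine B
tio in=tdr://llama3B_pool out=mistralrs ~/llm_models/Llama-3.2-3B-Instruct
```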
@@ -64,6 +64,8 @@ Run `tio --help` for more options.
## sglang
1. Set up the Python virtual env:
```
uv venv
source .venv/bin/activate
@@ -71,3 +73,36 @@ uv pip install pip
uv pip install sgl-kernel --force-reinstall --no-deps
uv pip install "sglang[all]==0.4.2" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/
```
2. Build
```
cargo build --release --features sglang
```
3. Run
Any of the examples above will work with `out=sglang`; the sglang backend also supports multi-GPU and multi-node operation.
Node 1:
```
tio in=http out=sglang --model-path ~/llm_models/DeepSeek-R1-Distill-Llama-70B/ --tensor-parallel-size 8 --num-nodes 2 --node-rank 0 --dist-init-addr 10.217.98.122:9876
```
Node 2:
```
tio in=none out=sglang --model-path ~/llm_models/DeepSeek-R1-Distill-Llama-70B/ --tensor-parallel-size 8 --num-nodes 2 --node-rank 1 --dist-init-addr 10.217.98.122:9876
```
## llama_cpp
- `cargo build --release --features llamacpp,cuda`
- `tio out=llama_cpp --model-path ~/llm_models/Llama-3.2-3B-Instruct-Q6_K.gguf --model-config ~/llm_models/Llama-3.2-3B-Instruct/`
The extra `--model-config` flag is needed because:
- llama_cpp only runs GGUF files
- We send it tokens, meaning we do the tokenization ourselves, so we need a tokenizer
- We don't yet read the tokenizer out of the GGUF (TODO), so we need an HF repo checkout with `tokenizer.json` et al
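For illustration, the `--model-config` directory is just a Hugging Face checkout; a typical layout (file names beyond `tokenizer.json` are the usual HF ones, listed here as an assumption) looks like:
```
~/llm_models/Llama-3.2-3B-Instruct/
    config.json
    tokenizer.json
    tokenizer_config.json
```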
If the build step also builds the llama_cpp shared libraries into `target/release` (`libllama.so`, `libggml.so`, `libggml-base.so`, `libggml-cpu.so`, `libggml-cuda.so`), then `tio` needs to find them at runtime: set `LD_LIBRARY_PATH`, and be sure to deploy them alongside the `tio` binary.
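A minimal deployment sketch, assuming the libraries ended up in `target/release` and `/opt/tio` is a hypothetical install directory:
```
mkdir -p /opt/tio
cp target/release/tio target/release/lib*.so /opt/tio/
LD_LIBRARY_PATH=/opt/tio /opt/tio/tio out=llama_cpp \
  --model-path ~/llm_models/Llama-3.2-3B-Instruct-Q6_K.gguf \
  --model-config ~/llm_models/Llama-3.2-3B-Instruct/
```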
@@ -62,6 +62,15 @@ pub struct Flags {
#[arg(long)]
pub model_name: Option<String>,
/// llamacpp only
///
/// The path to the tokenizer and model config because:
/// - llama_cpp only runs GGUF files
/// - our engine is a 'core' engine in that we do the tokenization, so we need the vocab
/// - TODO: we don't yet extract that from the GGUF. Once we do we can remove this flag.
#[arg(long)]
pub model_config: Option<PathBuf>,
/// sglang only
///
/// How many GPUs to use at once, total across all nodes.
@@ -295,6 +304,43 @@ pub async fn run(
card: Box::new(card),
}
}
#[cfg(feature = "llamacpp")]
Output::LlamaCpp => {
use anyhow::Context;
use triton_distributed_llm::engines::llamacpp;
let Some(model_path) = model_path else {
anyhow::bail!("out=llamacpp requires flag --model-path=<full-path-to-model-gguf>");
};
if !model_path.is_file() {
anyhow::bail!("--model-path should refer to a GGUF file. llama_cpp does not support safetensors.");
}
let card = match flags.model_config {
None => {
anyhow::bail!("Pass --model-config so we can find the tokenizer, should be an HF checkout.");
}
Some(card_path) => {
if !card_path.is_dir() {
anyhow::bail!(
"--model-config should be a Hugging Face repo checkout directory."
);
}
ModelDeploymentCard::from_local_path(&card_path, model_name.as_deref())
.await
.with_context(|| {
format!(
"Failed loading ModelDeploymentCard from {}",
card_path.display()
)
})?
}
};
let engine = llamacpp::make_engine(cancel_token.clone(), &model_path).await?;
EngineConfig::StaticCore {
service_name: card.service_name.clone(),
engine,
card: Box::new(card),
}
}
};
match in_opt {
...
@@ -41,7 +41,7 @@ const DEFAULT_OUT: Output = Output::EchoFull;
const ZMQ_SOCKET_PREFIX: &str = "tio";
const USAGE: &str = "USAGE: tio in=[http|text|tdr://<path>|none] out=[mistralrs|sglang|echo_full|echo_core] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--tensor-parallel-size=1] [--num-nodes=1] [--node-rank=0] [--dist-init-addr=127.0.0.1:9876] [--base-gpu-id=0]"; const USAGE: &str = "USAGE: tio in=[http|text|tdr://<path>|none] out=[mistralrs|sglang|llamacpp|echo_full|echo_core] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--tensor-parallel-size=1] [--num-nodes=1] [--node-rank=0] [--dist-init-addr=127.0.0.1:9876] [--base-gpu-id=0]";
fn main() -> anyhow::Result<()> {
logging::init();
...
@@ -79,6 +79,10 @@ pub enum Output {
#[cfg(feature = "sglang")]
/// Run inference using sglang
SgLang,
#[cfg(feature = "llamacpp")]
/// Run inference using llama.cpp
LlamaCpp,
}
impl TryFrom<&str> for Output {
@@ -92,6 +96,9 @@ impl TryFrom<&str> for Output {
#[cfg(feature = "sglang")]
"sglang" => Ok(Output::SgLang),
#[cfg(feature = "llamacpp")]
"llamacpp" | "llama_cpp" => Ok(Output::LlamaCpp),
"echo_full" => Ok(Output::EchoFull), "echo_full" => Ok(Output::EchoFull),
"echo_core" => Ok(Output::EchoCore), "echo_core" => Ok(Output::EchoCore),
...@@ -114,6 +121,9 @@ impl fmt::Display for Output { ...@@ -114,6 +121,9 @@ impl fmt::Display for Output {
#[cfg(feature = "sglang")] #[cfg(feature = "sglang")]
Output::SgLang => "sglang", Output::SgLang => "sglang",
#[cfg(feature = "llamacpp")]
Output::LlamaCpp => "llamacpp",
Output::EchoFull => "echo_full",
Output::EchoCore => "echo_core",
...
@@ -443,6 +443,29 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
[[package]]
name = "bindgen"
version = "0.69.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
dependencies = [
"bitflags 2.8.0",
"cexpr",
"clang-sys",
"itertools 0.12.1",
"lazy_static",
"lazycell",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash 1.1.0",
"shlex",
"syn 2.0.98",
"which",
]
[[package]]
name = "bindgen_cuda"
version = "0.1.5"
@@ -711,6 +734,15 @@ dependencies = [
"shlex",
]
[[package]]
name = "cexpr"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
dependencies = [
"nom",
]
[[package]]
name = "cfg-expr"
version = "0.15.8"
@@ -768,6 +800,17 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "clang-sys"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
dependencies = [
"glob",
"libc",
"libloading",
]
[[package]]
name = "clap"
version = "4.5.31"
@@ -1420,6 +1463,26 @@ dependencies = [
"syn 2.0.98",
]
[[package]]
name = "enumflags2"
version = "0.7.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba2f4b465f5318854c6f8dd686ede6c0a9dc67d4b1ac241cf0eb51521a309147"
dependencies = [
"enumflags2_derive",
]
[[package]]
name = "enumflags2_derive"
version = "0.7.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc4caf64a58d7a6d65ab00639b046ff54399a39f5f2554728895ace4b297cd79"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.98",
]
[[package]]
name = "equivalent"
version = "1.0.2"
@@ -1535,6 +1598,15 @@ dependencies = [
"version_check",
]
[[package]]
name = "find_cuda_helper"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9f9e65c593dd01ac77daad909ea4ad17f0d6d1776193fc8ea766356177abdad"
dependencies = [
"glob",
]
[[package]]
name = "fixedbitset"
version = "0.5.7"
@@ -2053,6 +2125,15 @@ dependencies = [
"ureq",
]
[[package]]
name = "home"
version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "http"
version = "1.2.0"
@@ -2595,6 +2676,12 @@ version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]]
name = "lebe"
version = "0.5.2"
@@ -2651,6 +2738,33 @@ version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104"
[[package]]
name = "llama-cpp-2"
version = "0.1.102"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a419bb48efa0f8389a82301f1f64e2874568a3fbf6f62f8ddab5324382b82768"
dependencies = [
"enumflags2",
"llama-cpp-sys-2",
"thiserror 1.0.69",
"tracing",
"tracing-core",
]
[[package]]
name = "llama-cpp-sys-2"
version = "0.1.102"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0522f9894e22dd988dd2e34222bda7acba53a0dcce744ca6d8ddce905ba33a4e"
dependencies = [
"bindgen",
"cc",
"cmake",
"find_cuda_helper",
"glob",
"walkdir",
]
[[package]]
name = "llguidance"
version = "0.4.1"
@@ -2663,7 +2777,7 @@ dependencies = [
"instant",
"referencing",
"regex-syntax 0.8.5",
"rustc-hash", "rustc-hash 2.1.1",
"serde",
"serde_json",
"toktrie 0.1.0",
@@ -2981,7 +3095,7 @@ dependencies = [
"regex",
"regex-automata 0.4.9",
"reqwest",
"rustc-hash", "rustc-hash 2.1.1",
"safetensors",
"schemars",
"serde",
@@ -3932,7 +4046,7 @@ dependencies = [
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash", "rustc-hash 2.1.1",
"rustls",
"socket2",
"thiserror 2.0.11",
@@ -3950,7 +4064,7 @@ dependencies = [
"getrandom 0.2.15",
"rand",
"ring",
"rustc-hash", "rustc-hash 2.1.1",
"rustls",
"rustls-pki-types",
"slab",
@@ -4319,6 +4433,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustc-hash"
version = "2.1.1"
@@ -5348,7 +5468,7 @@ dependencies = [
"anyhow",
"bytemuck",
"bytemuck_derive",
"rustc-hash", "rustc-hash 2.1.1",
"serde",
"serde_json",
]
@@ -5373,7 +5493,7 @@ source = "git+https://github.com/microsoft/llguidance?rev=cfef3df97372a7b84d7497
dependencies = [
"anyhow",
"log",
"rustc-hash", "rustc-hash 2.1.1",
"serde",
"serde_json",
"tokenizers",
@@ -5631,6 +5751,7 @@ dependencies = [
"insta",
"itertools 0.14.0",
"libc",
"llama-cpp-2",
"minijinja", "minijinja",
"minijinja-contrib", "minijinja-contrib",
"mistralrs", "mistralrs",
@@ -6128,6 +6249,18 @@ version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082"
[[package]]
name = "which"
version = "4.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
dependencies = [
"either",
"home",
"once_cell",
"rustix",
]
[[package]]
name = "winapi"
version = "0.2.8"
...
@@ -31,10 +31,13 @@ homepage.workspace = true
[features]
mistralrs = ["dep:mistralrs"]
metal = ["mistralrs/metal"] llamacpp = ["dep:llama-cpp-2"]
cuda = ["mistralrs/cuda"]
sentencepiece = ["dep:sentencepiece"]
sglang = ["dep:async_zmq"] sglang = ["dep:async_zmq"]
sentencepiece = ["dep:sentencepiece"]
cuda = ["mistralrs/cuda", "llama-cpp-2/cuda"]
metal = ["mistralrs/metal", "llama-cpp-2/metal"]
vulkan = ["llama-cpp-2/vulkan"]
[workspace.dependencies]
# local or crates.io
@@ -116,6 +119,9 @@ pyo3 = { version = "0.23.3", default-features = false, features = [
] }
serde-pickle = "1.2.0"
# llamacpp
llama-cpp-2 = { version = "0.1.86", optional = true }
# tokenizers
tokenizers = { version = "0.21.0", default-features = false, features = [
"onig",
...
@@ -18,3 +18,6 @@ pub mod mistralrs;
#[cfg(feature = "sglang")]
pub mod sglang;
#[cfg(feature = "llamacpp")]
pub mod llamacpp;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
num::NonZeroU32,
path::Path,
sync::{Arc, Mutex, OnceLock},
};
use anyhow::Context;
use async_stream::stream;
use async_trait::async_trait;
use llama_cpp_2::{
context::{params::LlamaContextParams, LlamaContext},
llama_backend::LlamaBackend,
llama_batch::LlamaBatch,
model::{params::LlamaModelParams, LlamaModel},
sampling::LlamaSampler,
token::LlamaToken,
};
use triton_distributed_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use triton_distributed_runtime::pipeline::error as pipeline_error;
use triton_distributed_runtime::pipeline::{Error, ManyOut, SingleIn};
use triton_distributed_runtime::protocols::annotated::Annotated;
use triton_distributed_runtime::CancellationToken;
use crate::backend::ExecutionContext;
use crate::protocols::common::llm_backend::{BackendInput, LLMEngineOutput};
use crate::protocols::common::preprocessor::PreprocessedRequest;
/// If user does not provide a max_tokens limit prompt+output to this many
const DEFAULT_MAX_TOKENS: u32 = 8192;
// I'm not entirely sure what this is. The model context size surely comes from the GGUF??
const CONTEXT_SIZE: u32 = 8192;
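// (Assumption, added for clarity: the GGUF metadata stores the model's trained context
// length (n_ctx_train); `n_ctx` below only sets how much KV cache each LlamaContext
// allocates, i.e. a per-context serving limit rather than a model property.)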
static LLAMA_BACKEND: tokio::sync::OnceCell<LlamaBackend> = tokio::sync::OnceCell::const_new();
pub(crate) static LLAMA_MODEL: tokio::sync::OnceCell<LlamaModel> =
tokio::sync::OnceCell::const_new();
const NUM_CONTEXTS: usize = 3;
static LLAMA_CONTEXTS: [OnceLock<Mutex<ContextWrapper>>; NUM_CONTEXTS] =
[OnceLock::new(), OnceLock::new(), OnceLock::new()];
// Newtype to simplify LlamaContext lifetime
#[derive(Debug)]
struct ContextWrapper(LlamaContext<'static>);
unsafe impl Send for ContextWrapper {} // LlamaContext has a NonNull which is !Send
unsafe impl Sync for ContextWrapper {} // LlamaContext has a NonNull which is !Sync
pub async fn make_engine(
cancel_token: CancellationToken,
model_path: &Path,
) -> pipeline_error::Result<ExecutionContext> {
let engine = LlamacppEngine::new(cancel_token, model_path).await?;
let engine: ExecutionContext = Arc::new(engine);
Ok(engine)
}
struct WorkRequest {
request: PreprocessedRequest,
response_channel: tokio::sync::mpsc::Sender<Annotated<LLMEngineOutput>>,
}
struct LlamacppEngine {
cancel_token: CancellationToken,
req_tx: tokio::sync::mpsc::Sender<WorkRequest>,
}
impl LlamacppEngine {
async fn new(
cancel_token: CancellationToken,
model_path: &Path,
) -> pipeline_error::Result<Self> {
let backend = LlamaBackend::init()?;
let model = load_model(&backend, model_path)?;
LLAMA_MODEL.set(model)?;
let (ctx_set, ctx_get) = tokio::sync::mpsc::channel(NUM_CONTEXTS);
// Safety: NonZeroU32::new only returns None for zero, and CONTEXT_SIZE is non-zero
let context_size = NonZeroU32::new(CONTEXT_SIZE).unwrap();
let llama_ctx_params = LlamaContextParams::default().with_n_ctx(Some(context_size));
for (i, ctx_holder) in LLAMA_CONTEXTS.iter().enumerate().take(NUM_CONTEXTS) {
let llama_ctx = LLAMA_MODEL
.get()
.unwrap() // Safety: We put it in a few lines up
.new_context(&backend, llama_ctx_params.clone())
.with_context(|| "unable to create the llama_context")?;
let _ = ctx_holder.set(Mutex::new(ContextWrapper(llama_ctx)));
let _ = ctx_set.send(i).await;
}
LLAMA_BACKEND.set(backend)?;
let (req_tx, req_rx) = tokio::sync::mpsc::channel(2);
let ct = cancel_token.clone();
tokio::task::spawn(worker(ct, req_rx, ctx_get, ctx_set));
Ok(LlamacppEngine {
cancel_token,
req_tx,
})
}
}
fn load_model(backend: &LlamaBackend, model_path: &Path) -> anyhow::Result<LlamaModel> {
let model_params = {
if cfg!(any(feature = "cuda", feature = "vulkan")) {
LlamaModelParams::default().with_n_gpu_layers(1000)
} else {
LlamaModelParams::default()
}
};
LlamaModel::load_from_file(backend, model_path, &model_params)
.with_context(|| "unable to load model")
}
#[async_trait]
impl AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<LLMEngineOutput>>, Error>
for LlamacppEngine
{
async fn generate(
&self,
request: SingleIn<BackendInput>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
let (request, context) = request.into_parts();
let ctx = context.context();
let request_id = ctx.id().to_string();
let (tx, mut rx) = tokio::sync::mpsc::channel(128);
let work_request = WorkRequest {
request,
response_channel: tx,
};
self.req_tx.send(work_request).await?;
let cancel_token = self.cancel_token.clone();
let output = stream! {
loop {
tokio::select! {
_ = cancel_token.cancelled() => {
tracing::trace!(request_id, "LlamacppEngine.generate stopped by cancel token");
break;
}
from_llamacpp = rx.recv() => {
match from_llamacpp {
Some(out) => {
yield out;
},
None => {
tracing::trace!(request_id, "generate: response channel closed");
break;
}
}
}
}
}
};
Ok(ResponseStream::new(Box::pin(output), ctx))
}
}
// Dispatch loop, spawned as a tokio task. Each request is handed a pooled context and run on a blocking thread.
async fn worker(
cancel_token: CancellationToken,
mut req_rx: tokio::sync::mpsc::Receiver<WorkRequest>,
mut ctx_get: tokio::sync::mpsc::Receiver<usize>,
ctx_set: tokio::sync::mpsc::Sender<usize>,
) {
loop {
let maybe_work_request = tokio::select! {
_ = cancel_token.cancelled() => {
break;
}
maybe_work_request = req_rx.recv() => {
maybe_work_request
}
};
let Some(work_request) = maybe_work_request else {
tracing::error!("llamacpp work request sender channel closed. Worker exit");
break;
};
// will block if there are already NUM_CONTEXTS requests in flight
let Some(ctx_pos) = ctx_get.recv().await else {
unreachable!("We don't close ctx_set");
};
let ct = cancel_token.clone();
let inner_ctx_set = ctx_set.clone();
tokio::task::spawn_blocking(move || {
let mut ctx = LLAMA_CONTEXTS[ctx_pos].get().unwrap().lock().unwrap();
if let Err(err) = run_request(ct, work_request, &mut ctx) {
tracing::error!("run_request error: {err:#}");
}
let _ = inner_ctx_set.blocking_send(ctx_pos);
});
}
}
fn run_request(
cancel_token: CancellationToken,
work_request: WorkRequest,
llama_context: &mut ContextWrapper,
) -> anyhow::Result<()> {
let tokens_list: Vec<LlamaToken> = work_request
.request
.token_ids
.into_iter()
.map(|u| LlamaToken::new(u as i32))
.collect();
let limit = DEFAULT_MAX_TOKENS; // - prompt_tokens;
let max_output_tokens = std::cmp::min(
work_request
.request
.stop_conditions
.max_tokens
.unwrap_or(limit),
limit,
);
// create a llama_batch with size 512
// we use this object to submit token data for decoding
let mut batch = LlamaBatch::new(512, 1);
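// Note (assumption): the batch capacity is fixed at 512 tokens, so prompts longer than
// 512 tokens will make `batch.add` below return an error instead of being chunked.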
let last_index: i32 = (tokens_list.len() - 1) as i32;
for (i, token) in (0_i32..).zip(tokens_list.into_iter()) {
// llama_decode will output logits only for the last token of the prompt
let is_last = i == last_index;
batch
.add(token, i, &[0], is_last)
.with_context(|| format!("Failed adding token pos {i} to batch"))?;
}
// "decode" means "run forward pass"
llama_context
.0
.decode(&mut batch)
.with_context(|| "llama_decode failed on first pass")?;
let mut sampler = LlamaSampler::greedy();
let mut n_cur = batch.n_tokens() as u32;
let mut used_output_tokens = 0;
while !cancel_token.is_cancelled() {
// sample the next token
let n_tokens = batch.n_tokens();
let token = sampler.sample(&llama_context.0, n_tokens - 1);
sampler.accept(token);
// is it an end of stream?
// This is probably safe for concurrent access
if LLAMA_MODEL.get().unwrap().is_eog_token(token) {
work_request
.response_channel
.blocking_send(Annotated::from_data(LLMEngineOutput::stop()))
.with_context(|| "Failed sending stop to response_channel")?;
break;
}
let engine_out = LLMEngineOutput {
// todo - propagate mdcsum
token_ids: vec![token.0 as u32],
tokens: None,
text: None,
//text: if output.text.is_empty() { None } else { Some(output.text) },
cum_log_probs: None, // TODO output.cumulative_logprob.map(|v| v as f64),
log_probs: None, // TODO output.logprobs
finish_reason: None,
};
work_request
.response_channel
.blocking_send(Annotated::from_data(engine_out))
.with_context(|| "Failed forwarding engine output to response_channel")?;
batch.clear();
if let Err(err) = batch.add(token, n_cur as i32, &[0], true) {
let err_msg = format!(
"batch add error, probably insufficient space in buffer, aborting request. {err}."
);
tracing::error!(err_msg);
let _ = work_request
.response_channel
.blocking_send(Annotated::from_data(LLMEngineOutput::error(err_msg)));
break;
}
n_cur += 1;
used_output_tokens += 1;
if used_output_tokens > max_output_tokens {
let _ = work_request
.response_channel
.blocking_send(Annotated::from_data(LLMEngineOutput::length()));
break;
}
llama_context
.0
.decode(&mut batch)
.with_context(|| "llama_decode failed during loop")?;
}
if cancel_token.is_cancelled() {
let _ = work_request
.response_channel
.blocking_send(Annotated::from_data(LLMEngineOutput::stop()));
}
// Clean context for next use
llama_context.0.clear_kv_cache();
llama_context.0.reset_timings();
Ok(())
}
@@ -34,7 +34,7 @@ use triton_distributed_runtime::pipeline::{Error, ManyOut, SingleIn};
use triton_distributed_runtime::protocols::annotated::Annotated;
use crate::protocols::openai::chat_completions::{
ChatCompletionRequest, NvCreateChatCompletionStreamResponse, NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
};
use crate::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
@@ -160,14 +160,14 @@ impl MistralRsEngine {
#[async_trait]
impl
AsyncEngine<
SingleIn<ChatCompletionRequest>, SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
Error,
> for MistralRsEngine
{
async fn generate(
&self,
request: SingleIn<ChatCompletionRequest>, request: SingleIn<NvCreateChatCompletionRequest>,
) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
let (request, context) = request.transfer(());
let ctx = context.context();
...