Commit e584e96f authored by Graham King, committed by GitHub

feat: llama.cpp engine for tio (#298)

Docs in README
parent b20ef999
......@@ -443,6 +443,29 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
[[package]]
name = "bindgen"
version = "0.69.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
dependencies = [
"bitflags 2.8.0",
"cexpr",
"clang-sys",
"itertools 0.11.0",
"lazy_static",
"lazycell",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash 1.1.0",
"shlex",
"syn 2.0.98",
"which",
]
[[package]]
name = "bindgen_cuda"
version = "0.1.5"
......@@ -686,6 +709,15 @@ dependencies = [
"shlex",
]
[[package]]
name = "cexpr"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
dependencies = [
"nom",
]
[[package]]
name = "cfg-expr"
version = "0.15.8"
......@@ -743,6 +775,17 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "clang-sys"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
dependencies = [
"glob",
"libc",
"libloading",
]
[[package]]
name = "clap"
version = "4.5.30"
......@@ -784,6 +827,15 @@ version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6"
[[package]]
name = "cmake"
version = "0.1.54"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
dependencies = [
"cc",
]
[[package]]
name = "color_quant"
version = "1.1.0"
......@@ -1398,6 +1450,26 @@ dependencies = [
"syn 2.0.98",
]
[[package]]
name = "enumflags2"
version = "0.7.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba2f4b465f5318854c6f8dd686ede6c0a9dc67d4b1ac241cf0eb51521a309147"
dependencies = [
"enumflags2_derive",
]
[[package]]
name = "enumflags2_derive"
version = "0.7.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc4caf64a58d7a6d65ab00639b046ff54399a39f5f2554728895ace4b297cd79"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.98",
]
[[package]]
name = "equivalent"
version = "1.0.2"
......@@ -1513,6 +1585,15 @@ dependencies = [
"version_check",
]
[[package]]
name = "find_cuda_helper"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9f9e65c593dd01ac77daad909ea4ad17f0d6d1776193fc8ea766356177abdad"
dependencies = [
"glob",
]
[[package]]
name = "fixedbitset"
version = "0.5.7"
......@@ -1996,6 +2077,15 @@ dependencies = [
"ureq",
]
[[package]]
name = "home"
version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "http"
version = "1.2.0"
......@@ -2509,6 +2599,12 @@ version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]]
name = "lebe"
version = "0.5.2"
......@@ -2559,6 +2655,33 @@ version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104"
[[package]]
name = "llama-cpp-2"
version = "0.1.102"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a419bb48efa0f8389a82301f1f64e2874568a3fbf6f62f8ddab5324382b82768"
dependencies = [
"enumflags2",
"llama-cpp-sys-2",
"thiserror 1.0.69",
"tracing",
"tracing-core",
]
[[package]]
name = "llama-cpp-sys-2"
version = "0.1.102"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0522f9894e22dd988dd2e34222bda7acba53a0dcce744ca6d8ddce905ba33a4e"
dependencies = [
"bindgen",
"cc",
"cmake",
"find_cuda_helper",
"glob",
"walkdir",
]
[[package]]
name = "llguidance"
version = "0.4.1"
......@@ -2571,7 +2694,7 @@ dependencies = [
"instant",
"referencing",
"regex-syntax 0.8.5",
"rustc-hash",
"rustc-hash 2.1.1",
"serde",
"serde_json",
"toktrie 0.1.0",
......@@ -2889,7 +3012,7 @@ dependencies = [
"regex",
"regex-automata 0.4.9",
"reqwest",
"rustc-hash",
"rustc-hash 2.1.1",
"safetensors",
"schemars",
"serde",
......@@ -3810,7 +3933,7 @@ dependencies = [
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustc-hash 2.1.1",
"rustls",
"socket2",
"thiserror 2.0.11",
......@@ -3828,7 +3951,7 @@ dependencies = [
"getrandom 0.2.15",
"rand",
"ring",
"rustc-hash",
"rustc-hash 2.1.1",
"rustls",
"rustls-pki-types",
"slab",
......@@ -4170,6 +4293,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustc-hash"
version = "2.1.1"
......@@ -5186,7 +5315,7 @@ dependencies = [
"anyhow",
"bytemuck",
"bytemuck_derive",
"rustc-hash",
"rustc-hash 2.1.1",
"serde",
"serde_json",
]
......@@ -5211,7 +5340,7 @@ source = "git+https://github.com/microsoft/llguidance?rev=cfef3df97372a7b84d7497
dependencies = [
"anyhow",
"log",
"rustc-hash",
"rustc-hash 2.1.1",
"serde",
"serde_json",
"tokenizers",
......@@ -5467,6 +5596,7 @@ dependencies = [
"indexmap 2.7.1",
"itertools 0.14.0",
"libc",
"llama-cpp-2",
"minijinja",
"minijinja-contrib",
"mistralrs",
......@@ -5938,6 +6068,18 @@ version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082"
[[package]]
name = "which"
version = "4.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
dependencies = [
"either",
"home",
"once_cell",
"rustix",
]
[[package]]
name = "winapi"
version = "0.2.8"
......
......@@ -24,6 +24,7 @@ license = "Apache-2.0"
[features]
mistralrs = ["triton-distributed-llm/mistralrs"]
sglang = ["triton-distributed-llm/sglang", "dep:netlink-packet-route", "dep:rtnetlink"]
llamacpp = ["triton-distributed-llm/llamacpp"]
cuda = ["triton-distributed-llm/cuda"]
metal = ["triton-distributed-llm/metal"]
......
......@@ -48,12 +48,12 @@ curl -d '{"model": "Llama-3.2-1B-Instruct-Q4_K_M", "max_tokens": 2049, "messages
Node 1:
```
tio in=http out=tdr://ns/backend/mistralrs
tio in=http out=tdr://llama3B_pool
```
Node 2:
```
tio in=tdr://ns/backend/mistralrs out=mistralrs ~/llm_models/Llama-3.2-3B-Instruct
tio in=tdr://llama3B_pool out=mistralrs ~/llm_models/Llama-3.2-3B-Instruct
```
This uses etcd to auto-discover the model and NATS to talk to it. You can run multiple workers on the same endpoint, and the frontend will pick one at random for each request.
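For illustration, a second worker on another machine would simply repeat the worker command; both register on the same endpoint and share traffic:
```
tio in=tdr://llama3B_pool out=mistralrs ~/llm_models/Llama-3.2-3B-Instruct
```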
......@@ -64,6 +64,8 @@ Run `tio --help` for more options.
## sglang
1. Setup the python virtual env:
```
uv venv
source .venv/bin/activate
......@@ -71,3 +73,36 @@ uv pip install pip
uv pip install sgl-kernel --force-reinstall --no-deps
uv pip install "sglang[all]==0.4.2" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/
```
2. Build
```
cargo build --release --features sglang
```
3. Run
Any example above using `out=sglang` will work, but our sglang backend is also multi-GPU and multi-node.
Node 1:
```
tio in=http out=sglang --model-path ~/llm_models/DeepSeek-R1-Distill-Llama-70B/ --tensor-parallel-size 8 --num-nodes 2 --node-rank 0 --dist-init-addr 10.217.98.122:9876
```
Node 2:
```
tio in=none out=sglang --model-path ~/llm_models/DeepSeek-R1-Distill-Llama-70B/ --tensor-parallel-size 8 --num-nodes 2 --node-rank 1 --dist-init-addr 10.217.98.122:9876
```
## llama_cpp
- `cargo build --release --features llamacpp,cuda`
- `tio out=llama_cpp --model-path ~/llm_models/Llama-3.2-3B-Instruct-Q6_K.gguf --model-config ~/llm_models/Llama-3.2-3B-Instruct/`
The extra `--model-config` flag is needed because:
- llama_cpp only runs GGUF
- We send it tokens, meaning we do the tokenization ourselves, so we need a tokenizer
- We don't yet read the tokenizer out of the GGUF (TODO), so we need an HF repo checkout with `tokenizer.json` et al (see the sketch below)
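A sketch of what such a checkout might contain; file names other than `tokenizer.json` are illustrative and depend on the model:
```
ls ~/llm_models/Llama-3.2-3B-Instruct/
# config.json  tokenizer.json  tokenizer_config.json  ...
```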
If the build step also builds the llama_cpp shared libraries into `target/release` (`libllama.so`, `libggml.so`, `libggml-base.so`, `libggml-cpu.so`, `libggml-cuda.so`), then `tio` needs to find them at runtime: set `LD_LIBRARY_PATH`, and be sure to deploy them alongside the `tio` binary.
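A minimal sketch of a run that makes those libraries visible, assuming they were built into `target/release` next to `tio`:
```
export LD_LIBRARY_PATH="$PWD/target/release:$LD_LIBRARY_PATH"
./target/release/tio out=llama_cpp \
  --model-path ~/llm_models/Llama-3.2-3B-Instruct-Q6_K.gguf \
  --model-config ~/llm_models/Llama-3.2-3B-Instruct/
```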
......@@ -62,6 +62,15 @@ pub struct Flags {
#[arg(long)]
pub model_name: Option<String>,
/// llamacpp only
///
/// The path to the tokenizer and model config. This is needed because:
/// - llama_cpp only runs GGUF files
/// - our engine is a 'core' engine in that we do the tokenization, so we need the vocab
/// - TODO: we don't yet extract that from the GGUF. Once we do we can remove this flag.
#[arg(long)]
pub model_config: Option<PathBuf>,
/// sglang only
///
/// How many GPUs to use at once, total across all nodes.
......@@ -295,6 +304,43 @@ pub async fn run(
card: Box::new(card),
}
}
#[cfg(feature = "llamacpp")]
Output::LlamaCpp => {
use anyhow::Context;
use triton_distributed_llm::engines::llamacpp;
let Some(model_path) = model_path else {
anyhow::bail!("out=llamacpp requires flag --model-path=<full-path-to-model-gguf>");
};
if !model_path.is_file() {
anyhow::bail!("--model-path should refer to a GGUF file. llama_cpp does not support safetensors.");
}
let card = match flags.model_config {
None => {
anyhow::bail!("Pass --model-config so we can find the tokenizer, should be an HF checkout.");
}
Some(card_path) => {
if !card_path.is_dir() {
anyhow::bail!(
"--model-config should be a Hugging Face repo checkout directory."
);
}
ModelDeploymentCard::from_local_path(&card_path, model_name.as_deref())
.await
.with_context(|| {
format!(
"Failed loading ModelDeploymentCard from {}",
card_path.display()
)
})?
}
};
let engine = llamacpp::make_engine(cancel_token.clone(), &model_path).await?;
EngineConfig::StaticCore {
service_name: card.service_name.clone(),
engine,
card: Box::new(card),
}
}
};
match in_opt {
......
......@@ -41,7 +41,7 @@ const DEFAULT_OUT: Output = Output::EchoFull;
const ZMQ_SOCKET_PREFIX: &str = "tio";
const USAGE: &str = "USAGE: tio in=[http|text|tdr://<path>|none] out=[mistralrs|sglang|echo_full|echo_core] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--tensor-parallel-size=1] [--num-nodes=1] [--node-rank=0] [--dist-init-addr=127.0.0.1:9876] [--base-gpu-id=0]";
const USAGE: &str = "USAGE: tio in=[http|text|tdr://<path>|none] out=[mistralrs|sglang|llamacpp|echo_full|echo_core] [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--tensor-parallel-size=1] [--num-nodes=1] [--node-rank=0] [--dist-init-addr=127.0.0.1:9876] [--base-gpu-id=0]";
fn main() -> anyhow::Result<()> {
logging::init();
......
......@@ -79,6 +79,10 @@ pub enum Output {
#[cfg(feature = "sglang")]
/// Run inference using sglang
SgLang,
#[cfg(feature = "llamacpp")]
/// Run inference using llama.cpp
LlamaCpp,
}
impl TryFrom<&str> for Output {
......@@ -92,6 +96,9 @@ impl TryFrom<&str> for Output {
#[cfg(feature = "sglang")]
"sglang" => Ok(Output::SgLang),
#[cfg(feature = "llamacpp")]
"llamacpp" | "llama_cpp" => Ok(Output::LlamaCpp),
"echo_full" => Ok(Output::EchoFull),
"echo_core" => Ok(Output::EchoCore),
......@@ -114,6 +121,9 @@ impl fmt::Display for Output {
#[cfg(feature = "sglang")]
Output::SgLang => "sglang",
#[cfg(feature = "llamacpp")]
Output::LlamaCpp => "llamacpp",
Output::EchoFull => "echo_full",
Output::EchoCore => "echo_core",
......
......@@ -443,6 +443,29 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
[[package]]
name = "bindgen"
version = "0.69.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
dependencies = [
"bitflags 2.8.0",
"cexpr",
"clang-sys",
"itertools 0.12.1",
"lazy_static",
"lazycell",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash 1.1.0",
"shlex",
"syn 2.0.98",
"which",
]
[[package]]
name = "bindgen_cuda"
version = "0.1.5"
......@@ -711,6 +734,15 @@ dependencies = [
"shlex",
]
[[package]]
name = "cexpr"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
dependencies = [
"nom",
]
[[package]]
name = "cfg-expr"
version = "0.15.8"
......@@ -768,6 +800,17 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "clang-sys"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
dependencies = [
"glob",
"libc",
"libloading",
]
[[package]]
name = "clap"
version = "4.5.31"
......@@ -1420,6 +1463,26 @@ dependencies = [
"syn 2.0.98",
]
[[package]]
name = "enumflags2"
version = "0.7.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba2f4b465f5318854c6f8dd686ede6c0a9dc67d4b1ac241cf0eb51521a309147"
dependencies = [
"enumflags2_derive",
]
[[package]]
name = "enumflags2_derive"
version = "0.7.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc4caf64a58d7a6d65ab00639b046ff54399a39f5f2554728895ace4b297cd79"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.98",
]
[[package]]
name = "equivalent"
version = "1.0.2"
......@@ -1535,6 +1598,15 @@ dependencies = [
"version_check",
]
[[package]]
name = "find_cuda_helper"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9f9e65c593dd01ac77daad909ea4ad17f0d6d1776193fc8ea766356177abdad"
dependencies = [
"glob",
]
[[package]]
name = "fixedbitset"
version = "0.5.7"
......@@ -2053,6 +2125,15 @@ dependencies = [
"ureq",
]
[[package]]
name = "home"
version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "http"
version = "1.2.0"
......@@ -2595,6 +2676,12 @@ version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]]
name = "lebe"
version = "0.5.2"
......@@ -2651,6 +2738,33 @@ version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104"
[[package]]
name = "llama-cpp-2"
version = "0.1.102"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a419bb48efa0f8389a82301f1f64e2874568a3fbf6f62f8ddab5324382b82768"
dependencies = [
"enumflags2",
"llama-cpp-sys-2",
"thiserror 1.0.69",
"tracing",
"tracing-core",
]
[[package]]
name = "llama-cpp-sys-2"
version = "0.1.102"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0522f9894e22dd988dd2e34222bda7acba53a0dcce744ca6d8ddce905ba33a4e"
dependencies = [
"bindgen",
"cc",
"cmake",
"find_cuda_helper",
"glob",
"walkdir",
]
[[package]]
name = "llguidance"
version = "0.4.1"
......@@ -2663,7 +2777,7 @@ dependencies = [
"instant",
"referencing",
"regex-syntax 0.8.5",
"rustc-hash",
"rustc-hash 2.1.1",
"serde",
"serde_json",
"toktrie 0.1.0",
......@@ -2981,7 +3095,7 @@ dependencies = [
"regex",
"regex-automata 0.4.9",
"reqwest",
"rustc-hash",
"rustc-hash 2.1.1",
"safetensors",
"schemars",
"serde",
......@@ -3932,7 +4046,7 @@ dependencies = [
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustc-hash 2.1.1",
"rustls",
"socket2",
"thiserror 2.0.11",
......@@ -3950,7 +4064,7 @@ dependencies = [
"getrandom 0.2.15",
"rand",
"ring",
"rustc-hash",
"rustc-hash 2.1.1",
"rustls",
"rustls-pki-types",
"slab",
......@@ -4319,6 +4433,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustc-hash"
version = "2.1.1"
......@@ -5348,7 +5468,7 @@ dependencies = [
"anyhow",
"bytemuck",
"bytemuck_derive",
"rustc-hash",
"rustc-hash 2.1.1",
"serde",
"serde_json",
]
......@@ -5373,7 +5493,7 @@ source = "git+https://github.com/microsoft/llguidance?rev=cfef3df97372a7b84d7497
dependencies = [
"anyhow",
"log",
"rustc-hash",
"rustc-hash 2.1.1",
"serde",
"serde_json",
"tokenizers",
......@@ -5631,6 +5751,7 @@ dependencies = [
"insta",
"itertools 0.14.0",
"libc",
"llama-cpp-2",
"minijinja",
"minijinja-contrib",
"mistralrs",
......@@ -6128,6 +6249,18 @@ version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082"
[[package]]
name = "which"
version = "4.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
dependencies = [
"either",
"home",
"once_cell",
"rustix",
]
[[package]]
name = "winapi"
version = "0.2.8"
......
......@@ -31,10 +31,13 @@ homepage.workspace = true
[features]
mistralrs = ["dep:mistralrs"]
metal = ["mistralrs/metal"]
cuda = ["mistralrs/cuda"]
sentencepiece = ["dep:sentencepiece"]
llamacpp = ["dep:llama-cpp-2"]
sglang = ["dep:async_zmq"]
sentencepiece = ["dep:sentencepiece"]
cuda = ["mistralrs/cuda", "llama-cpp-2/cuda"]
metal = ["mistralrs/metal", "llama-cpp-2/metal"]
vulkan = ["llama-cpp-2/vulkan"]
[workspace.dependencies]
# local or crates.io
......@@ -116,6 +119,9 @@ pyo3 = { version = "0.23.3", default-features = false, features = [
] }
serde-pickle = "1.2.0"
# llamacpp
llama-cpp-2 = { version = "0.1.86", optional = true }
# tokenizers
tokenizers = { version = "0.21.0", default-features = false, features = [
"onig",
......
......@@ -18,3 +18,6 @@ pub mod mistralrs;
#[cfg(feature = "sglang")]
pub mod sglang;
#[cfg(feature = "llamacpp")]
pub mod llamacpp;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
num::NonZeroU32,
path::Path,
sync::{Arc, Mutex, OnceLock},
};
use anyhow::Context;
use async_stream::stream;
use async_trait::async_trait;
use llama_cpp_2::{
context::{params::LlamaContextParams, LlamaContext},
llama_backend::LlamaBackend,
llama_batch::LlamaBatch,
model::{params::LlamaModelParams, LlamaModel},
sampling::LlamaSampler,
token::LlamaToken,
};
use triton_distributed_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
use triton_distributed_runtime::pipeline::error as pipeline_error;
use triton_distributed_runtime::pipeline::{Error, ManyOut, SingleIn};
use triton_distributed_runtime::protocols::annotated::Annotated;
use triton_distributed_runtime::CancellationToken;
use crate::backend::ExecutionContext;
use crate::protocols::common::llm_backend::{BackendInput, LLMEngineOutput};
use crate::protocols::common::preprocessor::PreprocessedRequest;
/// If the user does not provide max_tokens, limit prompt+output to this many tokens
const DEFAULT_MAX_TOKENS: u32 = 8192;
// I'm not entirely sure what this is. The model context size surely comes from the GGUF??
const CONTEXT_SIZE: u32 = 8192;
static LLAMA_BACKEND: tokio::sync::OnceCell<LlamaBackend> = tokio::sync::OnceCell::const_new();
pub(crate) static LLAMA_MODEL: tokio::sync::OnceCell<LlamaModel> =
tokio::sync::OnceCell::const_new();
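// A small fixed pool of llama contexts. Indices into this array are cycled through an
// mpsc channel, so at most NUM_CONTEXTS requests decode concurrently.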
const NUM_CONTEXTS: usize = 3;
static LLAMA_CONTEXTS: [OnceLock<Mutex<ContextWrapper>>; NUM_CONTEXTS] =
[OnceLock::new(), OnceLock::new(), OnceLock::new()];
// Newtype to simplify LlamaContext lifetime
#[derive(Debug)]
struct ContextWrapper(LlamaContext<'static>);
unsafe impl Send for ContextWrapper {} // LlamaContext has a NonNull which is !Send
unsafe impl Sync for ContextWrapper {} // LlamaContext has a NonNull which is !Sync
pub async fn make_engine(
cancel_token: CancellationToken,
model_path: &Path,
) -> pipeline_error::Result<ExecutionContext> {
let engine = LlamacppEngine::new(cancel_token, model_path).await?;
let engine: ExecutionContext = Arc::new(engine);
Ok(engine)
}
struct WorkRequest {
request: PreprocessedRequest,
response_channel: tokio::sync::mpsc::Sender<Annotated<LLMEngineOutput>>,
}
struct LlamacppEngine {
cancel_token: CancellationToken,
req_tx: tokio::sync::mpsc::Sender<WorkRequest>,
}
impl LlamacppEngine {
async fn new(
cancel_token: CancellationToken,
model_path: &Path,
) -> pipeline_error::Result<Self> {
let backend = LlamaBackend::init()?;
let model = load_model(&backend, model_path)?;
LLAMA_MODEL.set(model)?;
let (ctx_set, ctx_get) = tokio::sync::mpsc::channel(NUM_CONTEXTS);
// Safety: NonZeroU32::new only returns None for zero, and CONTEXT_SIZE is non-zero
let context_size = NonZeroU32::new(CONTEXT_SIZE).unwrap();
let llama_ctx_params = LlamaContextParams::default().with_n_ctx(Some(context_size));
for (i, ctx_holder) in LLAMA_CONTEXTS.iter().enumerate().take(NUM_CONTEXTS) {
let llama_ctx = LLAMA_MODEL
.get()
.unwrap() // Safety: we set it a few lines up, so this cannot fail
.new_context(&backend, llama_ctx_params.clone())
.with_context(|| "unable to create the llama_context")?;
let _ = ctx_holder.set(Mutex::new(ContextWrapper(llama_ctx)));
let _ = ctx_set.send(i).await;
}
LLAMA_BACKEND.set(backend)?;
let (req_tx, req_rx) = tokio::sync::mpsc::channel(2);
let ct = cancel_token.clone();
tokio::task::spawn(worker(ct, req_rx, ctx_get, ctx_set));
Ok(LlamacppEngine {
cancel_token,
req_tx,
})
}
}
fn load_model(backend: &LlamaBackend, model_path: &Path) -> anyhow::Result<LlamaModel> {
let model_params = {
if cfg!(any(feature = "cuda", feature = "vulkan")) {
LlamaModelParams::default().with_n_gpu_layers(1000)
} else {
LlamaModelParams::default()
}
};
LlamaModel::load_from_file(backend, model_path, &model_params)
.with_context(|| "unable to load model")
}
#[async_trait]
impl AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<LLMEngineOutput>>, Error>
for LlamacppEngine
{
async fn generate(
&self,
request: SingleIn<BackendInput>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
let (request, context) = request.into_parts();
let ctx = context.context();
let request_id = ctx.id().to_string();
let (tx, mut rx) = tokio::sync::mpsc::channel(128);
let work_request = WorkRequest {
request,
response_channel: tx,
};
self.req_tx.send(work_request).await?;
let cancel_token = self.cancel_token.clone();
let output = stream! {
loop {
tokio::select! {
_ = cancel_token.cancelled() => {
tracing::trace!(request_id, "LlamacppEngine.generate stopped by cancel token");
break;
}
from_llamacpp = rx.recv() => {
match from_llamacpp {
Some(out) => {
yield out;
},
None => {
tracing::trace!(request_id, "generate: response channel closed");
break;
}
}
}
}
}
};
Ok(ResponseStream::new(Box::pin(output), ctx))
}
}
// Request dispatch loop. Spawned as a tokio task; each request runs on a pooled context via spawn_blocking.
async fn worker(
cancel_token: CancellationToken,
mut req_rx: tokio::sync::mpsc::Receiver<WorkRequest>,
mut ctx_get: tokio::sync::mpsc::Receiver<usize>,
ctx_set: tokio::sync::mpsc::Sender<usize>,
) {
loop {
let maybe_work_request = tokio::select! {
_ = cancel_token.cancelled() => {
break;
}
maybe_work_request = req_rx.recv() => {
maybe_work_request
}
};
let Some(work_request) = maybe_work_request else {
tracing::error!("llamacpp work request sender channel closed. Worker exit");
break;
};
// waits here if there are already NUM_CONTEXTS requests in flight
let Some(ctx_pos) = ctx_get.recv().await else {
unreachable!("We don't close ctx_set");
};
let ct = cancel_token.clone();
let inner_ctx_set = ctx_set.clone();
tokio::task::spawn_blocking(move || {
let mut ctx = LLAMA_CONTEXTS[ctx_pos].get().unwrap().lock().unwrap();
if let Err(err) = run_request(ct, work_request, &mut ctx) {
tracing::error!("run_request error: {err:#}");
}
let _ = inner_ctx_set.blocking_send(ctx_pos);
});
}
}
fn run_request(
cancel_token: CancellationToken,
work_request: WorkRequest,
llama_context: &mut ContextWrapper,
) -> anyhow::Result<()> {
let tokens_list: Vec<LlamaToken> = work_request
.request
.token_ids
.into_iter()
.map(|u| LlamaToken::new(u as i32))
.collect();
let limit = DEFAULT_MAX_TOKENS; // - prompt_tokens;
let max_output_tokens = std::cmp::min(
work_request
.request
.stop_conditions
.max_tokens
.unwrap_or(limit),
limit,
);
// create a llama_batch with size 512
// we use this object to submit token data for decoding
let mut batch = LlamaBatch::new(512, 1);
let last_index: i32 = (tokens_list.len() - 1) as i32;
for (i, token) in (0_i32..).zip(tokens_list.into_iter()) {
// llama_decode will output logits only for the last token of the prompt
let is_last = i == last_index;
batch
.add(token, i, &[0], is_last)
.with_context(|| format!("Failed adding token pos {i} to batch"))?;
}
// "decode" means "run forward pass"
llama_context
.0
.decode(&mut batch)
.with_context(|| "llama_decode failed on first pass")?;
let mut sampler = LlamaSampler::greedy();
let mut n_cur = batch.n_tokens() as u32;
let mut used_output_tokens = 0;
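// Generation loop: sample the next token from the logits of the last decoded position,
// stream it back, then decode it to extend the KV cache.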
while !cancel_token.is_cancelled() {
// sample the next token
let n_tokens = batch.n_tokens();
let token = sampler.sample(&llama_context.0, n_tokens - 1);
sampler.accept(token);
// is it an end of stream?
// This is probably safe for concurrent access
if LLAMA_MODEL.get().unwrap().is_eog_token(token) {
work_request
.response_channel
.blocking_send(Annotated::from_data(LLMEngineOutput::stop()))
.with_context(|| "Failed sending stop to response_channel")?;
break;
}
let engine_out = LLMEngineOutput {
// todo - propagate mdcsum
token_ids: vec![token.0 as u32],
tokens: None,
text: None,
//text: if output.text.is_empty() { None } else { Some(output.text) },
cum_log_probs: None, // TODO output.cumulative_logprob.map(|v| v as f64),
log_probs: None, // TODO output.logprobs
finish_reason: None,
};
work_request
.response_channel
.blocking_send(Annotated::from_data(engine_out))
.with_context(|| "Failed forwarding engine output to response_channel")?;
batch.clear();
if let Err(err) = batch.add(token, n_cur as i32, &[0], true) {
let err_msg = format!(
"batch add error, probably insufficient space in buffer, aborting request. {err}."
);
tracing::error!(err_msg);
let _ = work_request
.response_channel
.blocking_send(Annotated::from_data(LLMEngineOutput::error(err_msg)));
break;
}
n_cur += 1;
used_output_tokens += 1;
if used_output_tokens > max_output_tokens {
let _ = work_request
.response_channel
.blocking_send(Annotated::from_data(LLMEngineOutput::length()));
break;
}
llama_context
.0
.decode(&mut batch)
.with_context(|| "llama_decode failed during loop")?;
}
if cancel_token.is_cancelled() {
let _ = work_request
.response_channel
.blocking_send(Annotated::from_data(LLMEngineOutput::stop()));
}
// Clean context for next use
llama_context.0.clear_kv_cache();
llama_context.0.reset_timings();
Ok(())
}
......@@ -34,7 +34,7 @@ use triton_distributed_runtime::pipeline::{Error, ManyOut, SingleIn};
use triton_distributed_runtime::protocols::annotated::Annotated;
use crate::protocols::openai::chat_completions::{
ChatCompletionRequest, NvCreateChatCompletionStreamResponse,
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
};
use crate::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
......@@ -160,14 +160,14 @@ impl MistralRsEngine {
#[async_trait]
impl
AsyncEngine<
SingleIn<ChatCompletionRequest>,
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
Error,
> for MistralRsEngine
{
async fn generate(
&self,
request: SingleIn<ChatCompletionRequest>,
request: SingleIn<NvCreateChatCompletionRequest>,
) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
let (request, context) = request.transfer(());
let ctx = context.context();
......