Commit d99b188d authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat(dynamo-run): Upgrade mistral.rs (#97)

- Latest from repo, many improvements
- Support most of the OpenAI request features (temperature, top_p, etc)
- Download models from Hugging Face if necessary
parent e5db9e86
...@@ -638,7 +638,7 @@ dependencies = [ ...@@ -638,7 +638,7 @@ dependencies = [
[[package]] [[package]]
name = "candle-core" name = "candle-core"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe" source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [ dependencies = [
"byteorder", "byteorder",
"candle-kernels", "candle-kernels",
...@@ -684,7 +684,7 @@ dependencies = [ ...@@ -684,7 +684,7 @@ dependencies = [
[[package]] [[package]]
name = "candle-kernels" name = "candle-kernels"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe" source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [ dependencies = [
"bindgen_cuda 0.1.5", "bindgen_cuda 0.1.5",
] ]
...@@ -692,7 +692,7 @@ dependencies = [ ...@@ -692,7 +692,7 @@ dependencies = [
[[package]] [[package]]
name = "candle-metal-kernels" name = "candle-metal-kernels"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe" source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [ dependencies = [
"metal", "metal",
"once_cell", "once_cell",
...@@ -703,7 +703,7 @@ dependencies = [ ...@@ -703,7 +703,7 @@ dependencies = [
[[package]] [[package]]
name = "candle-nn" name = "candle-nn"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe" source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [ dependencies = [
"candle-core", "candle-core",
"candle-metal-kernels", "candle-metal-kernels",
...@@ -1384,6 +1384,12 @@ dependencies = [ ...@@ -1384,6 +1384,12 @@ dependencies = [
"syn 2.0.100", "syn 2.0.100",
] ]
[[package]]
name = "doctest-file"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aac81fa3e28d21450aa4d2ac065992ba96a1d7303efbce51a95f4fd175b67562"
[[package]] [[package]]
name = "dyn-clone" name = "dyn-clone"
version = "1.0.19" version = "1.0.19"
...@@ -2669,6 +2675,19 @@ dependencies = [ ...@@ -2669,6 +2675,19 @@ dependencies = [
"cfg-if 1.0.0", "cfg-if 1.0.0",
] ]
[[package]]
name = "interprocess"
version = "2.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d941b405bd2322993887859a8ee6ac9134945a24ec5ec763a8a962fc64dfec2d"
dependencies = [
"doctest-file",
"libc",
"recvmsg",
"widestring",
"windows-sys 0.52.0",
]
[[package]] [[package]]
name = "inventory" name = "inventory"
version = "0.3.20" version = "0.3.20"
...@@ -3164,7 +3183,7 @@ dependencies = [ ...@@ -3164,7 +3183,7 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs" name = "mistralrs"
version = "0.4.0" version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"candle-core", "candle-core",
...@@ -3185,7 +3204,7 @@ dependencies = [ ...@@ -3185,7 +3204,7 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs-core" name = "mistralrs-core"
version = "0.4.0" version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [ dependencies = [
"akin", "akin",
"anyhow", "anyhow",
...@@ -3213,6 +3232,7 @@ dependencies = [ ...@@ -3213,6 +3232,7 @@ dependencies = [
"image", "image",
"indexmap 2.7.1", "indexmap 2.7.1",
"indicatif", "indicatif",
"interprocess",
"itertools 0.13.0", "itertools 0.13.0",
"llguidance", "llguidance",
"lrtable", "lrtable",
...@@ -3235,6 +3255,7 @@ dependencies = [ ...@@ -3235,6 +3255,7 @@ dependencies = [
"safetensors", "safetensors",
"schemars", "schemars",
"serde", "serde",
"serde-big-array",
"serde_json", "serde_json",
"serde_plain", "serde_plain",
"serde_yaml", "serde_yaml",
...@@ -3257,7 +3278,7 @@ dependencies = [ ...@@ -3257,7 +3278,7 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs-paged-attn" name = "mistralrs-paged-attn"
version = "0.4.0" version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bindgen_cuda 0.1.6", "bindgen_cuda 0.1.6",
...@@ -3272,7 +3293,7 @@ dependencies = [ ...@@ -3272,7 +3293,7 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs-quant" name = "mistralrs-quant"
version = "0.4.0" version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [ dependencies = [
"bindgen_cuda 0.1.5", "bindgen_cuda 0.1.5",
"byteorder", "byteorder",
...@@ -3298,7 +3319,7 @@ dependencies = [ ...@@ -3298,7 +3319,7 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs-vision" name = "mistralrs-vision"
version = "0.4.0" version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [ dependencies = [
"candle-core", "candle-core",
"image", "image",
...@@ -4363,6 +4384,12 @@ version = "0.5.5" ...@@ -4363,6 +4384,12 @@ version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430"
[[package]]
name = "recvmsg"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3edd4d5d42c92f0a659926464d4cce56b562761267ecf0f469d85b7de384175"
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.5.10" version = "0.5.10"
...@@ -4829,6 +4856,15 @@ dependencies = [ ...@@ -4829,6 +4856,15 @@ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]]
name = "serde-big-array"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11fc7cc2c76d73e0f27ee52abbd64eec84d46f370c88371120433196934e4b7f"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "serde-pickle" name = "serde-pickle"
version = "1.2.0" version = "1.2.0"
...@@ -6272,6 +6308,12 @@ dependencies = [ ...@@ -6272,6 +6308,12 @@ dependencies = [
"rustix 0.38.44", "rustix 0.38.44",
] ]
[[package]]
name = "widestring"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7219d36b6eac893fa81e84ebe06485e7dcbb616177469b142df14f1f4deb1311"
[[package]] [[package]]
name = "winapi" name = "winapi"
version = "0.2.8" version = "0.2.8"
......
...@@ -633,7 +633,7 @@ dependencies = [ ...@@ -633,7 +633,7 @@ dependencies = [
[[package]] [[package]]
name = "candle-core" name = "candle-core"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe" source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [ dependencies = [
"byteorder", "byteorder",
"candle-kernels", "candle-kernels",
...@@ -679,7 +679,7 @@ dependencies = [ ...@@ -679,7 +679,7 @@ dependencies = [
[[package]] [[package]]
name = "candle-kernels" name = "candle-kernels"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe" source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [ dependencies = [
"bindgen_cuda 0.1.5", "bindgen_cuda 0.1.5",
] ]
...@@ -687,7 +687,7 @@ dependencies = [ ...@@ -687,7 +687,7 @@ dependencies = [
[[package]] [[package]]
name = "candle-metal-kernels" name = "candle-metal-kernels"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe" source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [ dependencies = [
"metal", "metal",
"once_cell", "once_cell",
...@@ -698,7 +698,7 @@ dependencies = [ ...@@ -698,7 +698,7 @@ dependencies = [
[[package]] [[package]]
name = "candle-nn" name = "candle-nn"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe" source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [ dependencies = [
"candle-core", "candle-core",
"candle-metal-kernels", "candle-metal-kernels",
...@@ -1379,6 +1379,12 @@ dependencies = [ ...@@ -1379,6 +1379,12 @@ dependencies = [
"syn 2.0.98", "syn 2.0.98",
] ]
[[package]]
name = "doctest-file"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aac81fa3e28d21450aa4d2ac065992ba96a1d7303efbce51a95f4fd175b67562"
[[package]] [[package]]
name = "dyn-clone" name = "dyn-clone"
version = "1.0.18" version = "1.0.18"
...@@ -2664,6 +2670,19 @@ dependencies = [ ...@@ -2664,6 +2670,19 @@ dependencies = [
"cfg-if 1.0.0", "cfg-if 1.0.0",
] ]
[[package]]
name = "interprocess"
version = "2.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d941b405bd2322993887859a8ee6ac9134945a24ec5ec763a8a962fc64dfec2d"
dependencies = [
"doctest-file",
"libc",
"recvmsg",
"widestring",
"windows-sys 0.52.0",
]
[[package]] [[package]]
name = "inventory" name = "inventory"
version = "0.3.20" version = "0.3.20"
...@@ -3139,7 +3158,7 @@ dependencies = [ ...@@ -3139,7 +3158,7 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs" name = "mistralrs"
version = "0.4.0" version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"candle-core", "candle-core",
...@@ -3160,7 +3179,7 @@ dependencies = [ ...@@ -3160,7 +3179,7 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs-core" name = "mistralrs-core"
version = "0.4.0" version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [ dependencies = [
"akin", "akin",
"anyhow", "anyhow",
...@@ -3188,6 +3207,7 @@ dependencies = [ ...@@ -3188,6 +3207,7 @@ dependencies = [
"image", "image",
"indexmap 2.7.1", "indexmap 2.7.1",
"indicatif", "indicatif",
"interprocess",
"itertools 0.13.0", "itertools 0.13.0",
"llguidance", "llguidance",
"lrtable", "lrtable",
...@@ -3210,6 +3230,7 @@ dependencies = [ ...@@ -3210,6 +3230,7 @@ dependencies = [
"safetensors", "safetensors",
"schemars", "schemars",
"serde", "serde",
"serde-big-array",
"serde_json", "serde_json",
"serde_plain", "serde_plain",
"serde_yaml", "serde_yaml",
...@@ -3232,7 +3253,7 @@ dependencies = [ ...@@ -3232,7 +3253,7 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs-paged-attn" name = "mistralrs-paged-attn"
version = "0.4.0" version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bindgen_cuda 0.1.6", "bindgen_cuda 0.1.6",
...@@ -3247,7 +3268,7 @@ dependencies = [ ...@@ -3247,7 +3268,7 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs-quant" name = "mistralrs-quant"
version = "0.4.0" version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [ dependencies = [
"bindgen_cuda 0.1.5", "bindgen_cuda 0.1.5",
"byteorder", "byteorder",
...@@ -3273,7 +3294,7 @@ dependencies = [ ...@@ -3273,7 +3294,7 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs-vision" name = "mistralrs-vision"
version = "0.4.0" version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [ dependencies = [
"candle-core", "candle-core",
"image", "image",
...@@ -4327,6 +4348,12 @@ version = "0.5.5" ...@@ -4327,6 +4348,12 @@ version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430"
[[package]]
name = "recvmsg"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3edd4d5d42c92f0a659926464d4cce56b562761267ecf0f469d85b7de384175"
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.5.9" version = "0.5.9"
...@@ -4780,6 +4807,15 @@ dependencies = [ ...@@ -4780,6 +4807,15 @@ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]]
name = "serde-big-array"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11fc7cc2c76d73e0f27ee52abbd64eec84d46f370c88371120433196934e4b7f"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "serde-pickle" name = "serde-pickle"
version = "1.2.0" version = "1.2.0"
...@@ -6200,6 +6236,12 @@ dependencies = [ ...@@ -6200,6 +6236,12 @@ dependencies = [
"rustix", "rustix",
] ]
[[package]]
name = "widestring"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7219d36b6eac893fa81e84ebe06485e7dcbb616177469b142df14f1f4deb1311"
[[package]] [[package]]
name = "winapi" name = "winapi"
version = "0.2.8" version = "0.2.8"
......
...@@ -28,6 +28,13 @@ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh ...@@ -28,6 +28,13 @@ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
`cargo build --release --features mistralrs` `cargo build --release --features mistralrs`
## Quickstart
If you have an `HF_TOKEN` environment variable set, this will download Qwen2.5 3B from Hugging Face (6 GiB download) and start it in interactive mode:
```
./target/release/dynamo-run Qwen/Qwen2.5-3B-Instruct
```
## Download a model from Hugging Face ## Download a model from Hugging Face
For example one of these should be fast and good quality on almost any machine: https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF For example one of these should be fast and good quality on almost any machine: https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF
...@@ -40,7 +47,7 @@ For example one of these should be fast and good quality on almost any machine: ...@@ -40,7 +47,7 @@ For example one of these should be fast and good quality on almost any machine:
*HTTP interface* *HTTP interface*
`./target/release/dynamo-run in=http --model-path Llama-3.2-1B-Instruct-Q4_K_M.gguf` `./target/release/dynamo-run in=http Llama-3.2-1B-Instruct-Q4_K_M.gguf`
List the models: `curl localhost:8080/v1/models` List the models: `curl localhost:8080/v1/models`
...@@ -63,7 +70,7 @@ dynamo-run in=dyn://llama3B_pool out=mistralrs ~/llm_models/Llama-3.2-3B-Instruc ...@@ -63,7 +70,7 @@ dynamo-run in=dyn://llama3B_pool out=mistralrs ~/llm_models/Llama-3.2-3B-Instruc
This will use etcd to auto-discover the model and NATS to talk to it. You can run multiple workers on the same endpoint and it will pick one at random each time. This will use etcd to auto-discover the model and NATS to talk to it. You can run multiple workers on the same endpoint and it will pick one at random each time.
The `ns/backend/mistralrs` are purely symbolic, pick anything as long as it has three parts, and it matches the other node. The `llama3B_pool` name is purely symbolic, pick anything as long as it matches the other node.
Run `dynamo-run --help` for more options. Run `dynamo-run --help` for more options.
......
...@@ -20,12 +20,17 @@ use std::str::FromStr; ...@@ -20,12 +20,17 @@ use std::str::FromStr;
#[derive(clap::Parser, Debug, Clone)] #[derive(clap::Parser, Debug, Clone)]
#[command(version, about, long_about = None)] #[command(version, about, long_about = None)]
pub struct Flags { pub struct Flags {
/// Full path to the model, which can be either a GGUF file or a checked out HF repository. /// The model. The options depend on the engine.
/// For the `echo_full` engine omit the flag. ///
/// The full list - only mistralrs supports all three currently:
/// - Full path to a GGUF file
/// - Full path of a checked out Hugging Face repository containing safetensor files
/// - Name of a Hugging Face repository, e.g 'google/flan-t5-small'. The model will be
/// downloaded and cached.
#[arg(index = 1)] #[arg(index = 1)]
pub model_path_pos: Option<PathBuf>, pub model_path_pos: Option<PathBuf>,
// `--model-path`. The one above is `tio <positional-model-path>` // `--model-path`. The one above is `dynamo-run <positional-model-path>`
#[arg(long = "model-path")] #[arg(long = "model-path")]
pub model_path_flag: Option<PathBuf>, pub model_path_flag: Option<PathBuf>,
......
...@@ -83,7 +83,13 @@ pub async fn run( ...@@ -83,7 +83,13 @@ pub async fn run(
let model_path = flags let model_path = flags
.model_path_pos .model_path_pos
.or(flags.model_path_flag) .or(flags.model_path_flag)
.and_then(|p| p.canonicalize().ok()); .and_then(|p| {
if p.exists() {
p.canonicalize().ok()
} else {
Some(p)
}
});
// Serve the model under the name provided, or the name of the GGUF file or HF repo. // Serve the model under the name provided, or the name of the GGUF file or HF repo.
let model_name = flags let model_name = flags
.model_name .model_name
......
...@@ -658,7 +658,7 @@ dependencies = [ ...@@ -658,7 +658,7 @@ dependencies = [
[[package]] [[package]]
name = "candle-core" name = "candle-core"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe" source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [ dependencies = [
"byteorder", "byteorder",
"candle-kernels", "candle-kernels",
...@@ -704,7 +704,7 @@ dependencies = [ ...@@ -704,7 +704,7 @@ dependencies = [
[[package]] [[package]]
name = "candle-kernels" name = "candle-kernels"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe" source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [ dependencies = [
"bindgen_cuda 0.1.5", "bindgen_cuda 0.1.5",
] ]
...@@ -712,7 +712,7 @@ dependencies = [ ...@@ -712,7 +712,7 @@ dependencies = [
[[package]] [[package]]
name = "candle-metal-kernels" name = "candle-metal-kernels"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe" source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [ dependencies = [
"metal", "metal",
"once_cell", "once_cell",
...@@ -723,7 +723,7 @@ dependencies = [ ...@@ -723,7 +723,7 @@ dependencies = [
[[package]] [[package]]
name = "candle-nn" name = "candle-nn"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe" source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [ dependencies = [
"candle-core", "candle-core",
"candle-metal-kernels", "candle-metal-kernels",
...@@ -1392,6 +1392,12 @@ dependencies = [ ...@@ -1392,6 +1392,12 @@ dependencies = [
"syn 2.0.98", "syn 2.0.98",
] ]
[[package]]
name = "doctest-file"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aac81fa3e28d21450aa4d2ac065992ba96a1d7303efbce51a95f4fd175b67562"
[[package]] [[package]]
name = "dyn-clone" name = "dyn-clone"
version = "1.0.18" version = "1.0.18"
...@@ -2715,6 +2721,19 @@ dependencies = [ ...@@ -2715,6 +2721,19 @@ dependencies = [
"cfg-if 1.0.0", "cfg-if 1.0.0",
] ]
[[package]]
name = "interprocess"
version = "2.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d941b405bd2322993887859a8ee6ac9134945a24ec5ec763a8a962fc64dfec2d"
dependencies = [
"doctest-file",
"libc",
"recvmsg",
"widestring",
"windows-sys 0.52.0",
]
[[package]] [[package]]
name = "inventory" name = "inventory"
version = "0.3.20" version = "0.3.20"
...@@ -3205,7 +3224,7 @@ dependencies = [ ...@@ -3205,7 +3224,7 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs" name = "mistralrs"
version = "0.4.0" version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"candle-core", "candle-core",
...@@ -3226,7 +3245,7 @@ dependencies = [ ...@@ -3226,7 +3245,7 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs-core" name = "mistralrs-core"
version = "0.4.0" version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [ dependencies = [
"akin", "akin",
"anyhow", "anyhow",
...@@ -3254,6 +3273,7 @@ dependencies = [ ...@@ -3254,6 +3273,7 @@ dependencies = [
"image", "image",
"indexmap 2.7.1", "indexmap 2.7.1",
"indicatif", "indicatif",
"interprocess",
"itertools 0.13.0", "itertools 0.13.0",
"llguidance", "llguidance",
"lrtable", "lrtable",
...@@ -3276,6 +3296,7 @@ dependencies = [ ...@@ -3276,6 +3296,7 @@ dependencies = [
"safetensors", "safetensors",
"schemars", "schemars",
"serde", "serde",
"serde-big-array",
"serde_json", "serde_json",
"serde_plain", "serde_plain",
"serde_yaml", "serde_yaml",
...@@ -3298,7 +3319,7 @@ dependencies = [ ...@@ -3298,7 +3319,7 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs-paged-attn" name = "mistralrs-paged-attn"
version = "0.4.0" version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bindgen_cuda 0.1.6", "bindgen_cuda 0.1.6",
...@@ -3313,7 +3334,7 @@ dependencies = [ ...@@ -3313,7 +3334,7 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs-quant" name = "mistralrs-quant"
version = "0.4.0" version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [ dependencies = [
"bindgen_cuda 0.1.5", "bindgen_cuda 0.1.5",
"byteorder", "byteorder",
...@@ -3339,7 +3360,7 @@ dependencies = [ ...@@ -3339,7 +3360,7 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs-vision" name = "mistralrs-vision"
version = "0.4.0" version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [ dependencies = [
"candle-core", "candle-core",
"image", "image",
...@@ -4432,6 +4453,12 @@ version = "0.5.5" ...@@ -4432,6 +4453,12 @@ version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430"
[[package]]
name = "recvmsg"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3edd4d5d42c92f0a659926464d4cce56b562761267ecf0f469d85b7de384175"
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.5.9" version = "0.5.9"
...@@ -4952,6 +4979,15 @@ dependencies = [ ...@@ -4952,6 +4979,15 @@ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]]
name = "serde-big-array"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11fc7cc2c76d73e0f27ee52abbd64eec84d46f370c88371120433196934e4b7f"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "serde-pickle" name = "serde-pickle"
version = "1.2.0" version = "1.2.0"
...@@ -6393,6 +6429,12 @@ dependencies = [ ...@@ -6393,6 +6429,12 @@ dependencies = [
"rustix", "rustix",
] ]
[[package]]
name = "widestring"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7219d36b6eac893fa81e84ebe06485e7dcbb616177469b142df14f1f4deb1311"
[[package]] [[package]]
name = "winapi" name = "winapi"
version = "0.2.8" version = "0.2.8"
......
...@@ -115,7 +115,7 @@ prometheus = { version = "0.13" } ...@@ -115,7 +115,7 @@ prometheus = { version = "0.13" }
# mistralrs # mistralrs
either = { version = "1.13" } either = { version = "1.13" }
indexmap = { version = "2.6" } indexmap = { version = "2.6" }
mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", rev = "5e689c9", optional = true } mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", rev = "a691154bb", optional = true }
# sglang # sglang
async_zmq = { version = "0.4.0", optional = true } async_zmq = { version = "0.4.0", optional = true }
......
...@@ -13,7 +13,8 @@ ...@@ -13,7 +13,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{cmp::min, num::NonZero, path::Path, sync::Arc}; use std::collections::HashMap;
use std::{cmp::min, env, num::NonZero, path::Path, sync::Arc};
use async_openai::types::FinishReason; use async_openai::types::FinishReason;
use async_stream::stream; use async_stream::stream;
...@@ -24,7 +25,8 @@ use mistralrs::{ ...@@ -24,7 +25,8 @@ use mistralrs::{
AutoDeviceMapParams, Constraint, DefaultSchedulerMethod, Device, DeviceMapSetting, AutoDeviceMapParams, Constraint, DefaultSchedulerMethod, Device, DeviceMapSetting,
GGUFLoaderBuilder, GGUFSpecificConfig, MemoryGpuConfig, MistralRs, MistralRsBuilder, GGUFLoaderBuilder, GGUFSpecificConfig, MemoryGpuConfig, MistralRs, MistralRsBuilder,
ModelDType, NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, PagedAttentionConfig, ModelDType, NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, PagedAttentionConfig,
Pipeline, Request, RequestMessage, ResponseOk, SamplingParams, SchedulerConfig, TokenSource, Pipeline, Request, RequestMessage, ResponseOk, SamplingParams, SchedulerConfig, StopTokens,
TokenSource,
}; };
use tokio::sync::mpsc::channel; use tokio::sync::mpsc::channel;
...@@ -41,15 +43,12 @@ use crate::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine ...@@ -41,15 +43,12 @@ use crate::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine
/// If user does not provide a max_tokens limit prompt+output to this many /// If user does not provide a max_tokens limit prompt+output to this many
const DEFAULT_MAX_TOKENS: i32 = 8192; const DEFAULT_MAX_TOKENS: i32 = 8192;
/// TODO: tune. Presumably we read it from model's config.json?
const MAX_SEQ_LEN: usize = 4096;
// TODO: tune, maybe implement batching.
const MAX_BATCH_SIZE: usize = 2;
/// TODO: tune /// TODO: tune
const PAGED_ATTENTION_MAX_NUM_SEQS: usize = 5; const PAGED_ATTENTION_MAX_NUM_SEQS: usize = 5;
/// The environment variable which can hold the Hugging Face token, if any, in order
const HF_TOKEN_VARS: [&str; 3] = ["HF_TOKEN", "HUGGING_FACE_HUB_TOKEN", "HUGGINGFACE_TOKEN"];
pub async fn make_engine( pub async fn make_engine(
gguf_path: &Path, gguf_path: &Path,
) -> pipeline_error::Result<OpenAIChatCompletionsStreamingEngine> { ) -> pipeline_error::Result<OpenAIChatCompletionsStreamingEngine> {
...@@ -77,6 +76,17 @@ struct MistralRsEngine { ...@@ -77,6 +76,17 @@ struct MistralRsEngine {
impl MistralRsEngine { impl MistralRsEngine {
async fn new(model_path: &Path) -> pipeline_error::Result<Self> { async fn new(model_path: &Path) -> pipeline_error::Result<Self> {
let mut hf_token_source = TokenSource::CacheToken;
// We might be trying to download a repo from Hugging Face. See if we have a token.
if !model_path.exists() {
for v_name in HF_TOKEN_VARS {
if env::var(v_name).is_ok() {
tracing::debug!("Using Hugging Face token from {v_name}");
hf_token_source = TokenSource::EnvVar(v_name.to_string());
break;
}
}
}
let loader = if model_path.is_file() { let loader = if model_path.is_file() {
// Load from a GGUF // Load from a GGUF
let Some(model_filename) = model_path.file_name() else { let Some(model_filename) = model_path.file_name() else {
...@@ -117,12 +127,14 @@ impl MistralRsEngine { ...@@ -117,12 +127,14 @@ impl MistralRsEngine {
.build(None)? .build(None)?
}; };
let max_seq_len = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN;
// Paged attention requires cuda // Paged attention requires cuda
let paged_attention_config = if cfg!(feature = "cuda") { let paged_attention_config = if cfg!(feature = "cuda") {
Some(PagedAttentionConfig::new( Some(PagedAttentionConfig::new(
Some(32), None, // Block size, default 32
1024, 512, // CPU memory in MiB
MemoryGpuConfig::Utilization(0.9), MemoryGpuConfig::ContextSize(max_seq_len),
)?) )?)
} else { } else {
None None
...@@ -130,13 +142,13 @@ impl MistralRsEngine { ...@@ -130,13 +142,13 @@ impl MistralRsEngine {
// Load, into a Pipeline // Load, into a Pipeline
let pipeline = loader.load_model_from_hf( let pipeline = loader.load_model_from_hf(
None, None,
TokenSource::CacheToken, hf_token_source,
&ModelDType::Auto, &ModelDType::Auto,
&best_device()?, &best_device()?,
false, false,
DeviceMapSetting::Auto(AutoDeviceMapParams::Text { DeviceMapSetting::Auto(AutoDeviceMapParams::Text {
max_seq_len: MAX_SEQ_LEN, max_seq_len,
max_batch_size: MAX_BATCH_SIZE, max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
}), }),
None, None,
paged_attention_config, paged_attention_config,
...@@ -157,11 +169,11 @@ impl MistralRsEngine { ...@@ -157,11 +169,11 @@ impl MistralRsEngine {
tracing::debug!("Using mistralrs DefaultScheduler"); tracing::debug!("Using mistralrs DefaultScheduler");
SchedulerConfig::DefaultScheduler { SchedulerConfig::DefaultScheduler {
// Safety: unwrap trivially safe here // Safety: unwrap trivially safe here
method: DefaultSchedulerMethod::Fixed(NonZero::new(5).unwrap()), method: DefaultSchedulerMethod::Fixed(NonZero::new(max_seq_len).unwrap()),
} }
}; };
// Create the MistralRs, which is a runner // Create the MistralRs, which is a runner
let builder = MistralRsBuilder::new(pipeline.clone(), scheduler); let builder = MistralRsBuilder::new(pipeline.clone(), scheduler).with_prefix_cache_n(16);
Ok(MistralRsEngine { Ok(MistralRsEngine {
mistralrs: builder.build(), mistralrs: builder.build(),
pipeline, pipeline,
...@@ -225,11 +237,43 @@ impl ...@@ -225,11 +237,43 @@ impl
limit, limit,
); );
let det = SamplingParams::deterministic();
let sampling_params = SamplingParams {
temperature: request
.inner
.temperature
.map(|t| t as f64)
.or(det.temperature),
top_p: request.inner.top_p.map(|t| t as f64).or(det.top_p),
top_n_logprobs: request
.inner
.top_logprobs
.map(|t| t as usize)
.unwrap_or(det.top_n_logprobs),
frequency_penalty: request.inner.frequency_penalty.or(det.frequency_penalty),
presence_penalty: request.inner.presence_penalty.or(det.presence_penalty),
stop_toks: request.inner.stop.map(to_stop_tokens).or(det.stop_toks),
max_len: request
.inner
.max_completion_tokens
.map(|m| m as usize)
.or(det.max_len),
logits_bias: request
.inner
.logit_bias
.map(to_logit_bias)
.or(det.logits_bias),
// These are not in async-openai yet
top_k: det.top_k,
min_p: det.min_p,
n_choices: 1,
dry_params: det.dry_params,
};
let mistralrs_request = Request::Normal(NormalRequest { let mistralrs_request = Request::Normal(NormalRequest {
messages: RequestMessage::Chat(messages), messages: RequestMessage::Chat(messages),
sampling_params: SamplingParams::deterministic(), sampling_params,
response: tx, response: tx,
return_logprobs: false, return_logprobs: request.inner.logprobs.unwrap_or_default(),
is_streaming: true, is_streaming: true,
id: self.mistralrs.next_request_id(), id: self.mistralrs.next_request_id(),
constraint: Constraint::None, constraint: Constraint::None,
...@@ -319,3 +363,34 @@ impl ...@@ -319,3 +363,34 @@ impl
Ok(ResponseStream::new(Box::pin(output), ctx)) Ok(ResponseStream::new(Box::pin(output), ctx))
} }
} }
/// openai stop tokens to mistralrs stop tokens
fn to_stop_tokens(t: async_openai::types::Stop) -> StopTokens {
match t {
async_openai::types::Stop::String(s) => StopTokens::Seqs(vec![s]),
async_openai::types::Stop::StringArray(v) => StopTokens::Seqs(v),
}
}
/// openai logit bias (strings/json) to mistralrs (u32/f32)
/// I think the input looks like this: {"3721": -100, "17765": 100}
fn to_logit_bias(lb: HashMap<String, serde_json::Value>) -> HashMap<u32, f32> {
let mut out = HashMap::new();
for (key, value) in &lb {
let token_id: u32 = match key.parse() {
Ok(t) => t,
Err(err) => {
tracing::warn!(
"Unexpected logit_bias map. Key '{key}' is not an int: {lb:?}. {err}."
);
return HashMap::new();
}
};
let Some(bias) = value.as_f64() else {
tracing::warn!("Unexpected logit_bias map. Value '{value}' is not a float: {lb:?}");
return HashMap::new();
};
out.insert(token_id, bias as f32);
}
out
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment