Commit d99b188d authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat(dynamo-run): Upgrade mistral.rs (#97)

- Latest from repo, many improvements
- Support most of the OpenAI request features (temperature, top_p, etc)
- Download models from Hugging Face if necessary
parent e5db9e86
......@@ -638,7 +638,7 @@ dependencies = [
[[package]]
name = "candle-core"
version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe"
source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [
"byteorder",
"candle-kernels",
......@@ -684,7 +684,7 @@ dependencies = [
[[package]]
name = "candle-kernels"
version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe"
source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [
"bindgen_cuda 0.1.5",
]
......@@ -692,7 +692,7 @@ dependencies = [
[[package]]
name = "candle-metal-kernels"
version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe"
source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [
"metal",
"once_cell",
......@@ -703,7 +703,7 @@ dependencies = [
[[package]]
name = "candle-nn"
version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe"
source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [
"candle-core",
"candle-metal-kernels",
......@@ -1384,6 +1384,12 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "doctest-file"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aac81fa3e28d21450aa4d2ac065992ba96a1d7303efbce51a95f4fd175b67562"
[[package]]
name = "dyn-clone"
version = "1.0.19"
......@@ -2669,6 +2675,19 @@ dependencies = [
"cfg-if 1.0.0",
]
[[package]]
name = "interprocess"
version = "2.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d941b405bd2322993887859a8ee6ac9134945a24ec5ec763a8a962fc64dfec2d"
dependencies = [
"doctest-file",
"libc",
"recvmsg",
"widestring",
"windows-sys 0.52.0",
]
[[package]]
name = "inventory"
version = "0.3.20"
......@@ -3164,7 +3183,7 @@ dependencies = [
[[package]]
name = "mistralrs"
version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [
"anyhow",
"candle-core",
......@@ -3185,7 +3204,7 @@ dependencies = [
[[package]]
name = "mistralrs-core"
version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [
"akin",
"anyhow",
......@@ -3213,6 +3232,7 @@ dependencies = [
"image",
"indexmap 2.7.1",
"indicatif",
"interprocess",
"itertools 0.13.0",
"llguidance",
"lrtable",
......@@ -3235,6 +3255,7 @@ dependencies = [
"safetensors",
"schemars",
"serde",
"serde-big-array",
"serde_json",
"serde_plain",
"serde_yaml",
......@@ -3257,7 +3278,7 @@ dependencies = [
[[package]]
name = "mistralrs-paged-attn"
version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [
"anyhow",
"bindgen_cuda 0.1.6",
......@@ -3272,7 +3293,7 @@ dependencies = [
[[package]]
name = "mistralrs-quant"
version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [
"bindgen_cuda 0.1.5",
"byteorder",
......@@ -3298,7 +3319,7 @@ dependencies = [
[[package]]
name = "mistralrs-vision"
version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [
"candle-core",
"image",
......@@ -4363,6 +4384,12 @@ version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430"
[[package]]
name = "recvmsg"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3edd4d5d42c92f0a659926464d4cce56b562761267ecf0f469d85b7de384175"
[[package]]
name = "redox_syscall"
version = "0.5.10"
......@@ -4829,6 +4856,15 @@ dependencies = [
"serde_derive",
]
[[package]]
name = "serde-big-array"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11fc7cc2c76d73e0f27ee52abbd64eec84d46f370c88371120433196934e4b7f"
dependencies = [
"serde",
]
[[package]]
name = "serde-pickle"
version = "1.2.0"
......@@ -6272,6 +6308,12 @@ dependencies = [
"rustix 0.38.44",
]
[[package]]
name = "widestring"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7219d36b6eac893fa81e84ebe06485e7dcbb616177469b142df14f1f4deb1311"
[[package]]
name = "winapi"
version = "0.2.8"
......
......@@ -633,7 +633,7 @@ dependencies = [
[[package]]
name = "candle-core"
version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe"
source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [
"byteorder",
"candle-kernels",
......@@ -679,7 +679,7 @@ dependencies = [
[[package]]
name = "candle-kernels"
version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe"
source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [
"bindgen_cuda 0.1.5",
]
......@@ -687,7 +687,7 @@ dependencies = [
[[package]]
name = "candle-metal-kernels"
version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe"
source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [
"metal",
"once_cell",
......@@ -698,7 +698,7 @@ dependencies = [
[[package]]
name = "candle-nn"
version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe"
source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [
"candle-core",
"candle-metal-kernels",
......@@ -1379,6 +1379,12 @@ dependencies = [
"syn 2.0.98",
]
[[package]]
name = "doctest-file"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aac81fa3e28d21450aa4d2ac065992ba96a1d7303efbce51a95f4fd175b67562"
[[package]]
name = "dyn-clone"
version = "1.0.18"
......@@ -2664,6 +2670,19 @@ dependencies = [
"cfg-if 1.0.0",
]
[[package]]
name = "interprocess"
version = "2.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d941b405bd2322993887859a8ee6ac9134945a24ec5ec763a8a962fc64dfec2d"
dependencies = [
"doctest-file",
"libc",
"recvmsg",
"widestring",
"windows-sys 0.52.0",
]
[[package]]
name = "inventory"
version = "0.3.20"
......@@ -3139,7 +3158,7 @@ dependencies = [
[[package]]
name = "mistralrs"
version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [
"anyhow",
"candle-core",
......@@ -3160,7 +3179,7 @@ dependencies = [
[[package]]
name = "mistralrs-core"
version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [
"akin",
"anyhow",
......@@ -3188,6 +3207,7 @@ dependencies = [
"image",
"indexmap 2.7.1",
"indicatif",
"interprocess",
"itertools 0.13.0",
"llguidance",
"lrtable",
......@@ -3210,6 +3230,7 @@ dependencies = [
"safetensors",
"schemars",
"serde",
"serde-big-array",
"serde_json",
"serde_plain",
"serde_yaml",
......@@ -3232,7 +3253,7 @@ dependencies = [
[[package]]
name = "mistralrs-paged-attn"
version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [
"anyhow",
"bindgen_cuda 0.1.6",
......@@ -3247,7 +3268,7 @@ dependencies = [
[[package]]
name = "mistralrs-quant"
version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [
"bindgen_cuda 0.1.5",
"byteorder",
......@@ -3273,7 +3294,7 @@ dependencies = [
[[package]]
name = "mistralrs-vision"
version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [
"candle-core",
"image",
......@@ -4327,6 +4348,12 @@ version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430"
[[package]]
name = "recvmsg"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3edd4d5d42c92f0a659926464d4cce56b562761267ecf0f469d85b7de384175"
[[package]]
name = "redox_syscall"
version = "0.5.9"
......@@ -4780,6 +4807,15 @@ dependencies = [
"serde_derive",
]
[[package]]
name = "serde-big-array"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11fc7cc2c76d73e0f27ee52abbd64eec84d46f370c88371120433196934e4b7f"
dependencies = [
"serde",
]
[[package]]
name = "serde-pickle"
version = "1.2.0"
......@@ -6200,6 +6236,12 @@ dependencies = [
"rustix",
]
[[package]]
name = "widestring"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7219d36b6eac893fa81e84ebe06485e7dcbb616177469b142df14f1f4deb1311"
[[package]]
name = "winapi"
version = "0.2.8"
......
......@@ -28,6 +28,13 @@ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
`cargo build --release --features mistralrs`
## Quickstart
If you have an `HF_TOKEN` environment variable set, this will download Qwen2.5 3B from Hugging Face (6 GiB download) and start it in interactive mode:
```
./target/release/dynamo-run Qwen/Qwen2.5-3B-Instruct
```
## Download a model from Hugging Face
For example, one of these should be fast and good quality on almost any machine: https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF
......@@ -40,7 +47,7 @@ For example one of these should be fast and good quality on almost any machine:
*HTTP interface*
`./target/release/dynamo-run in=http --model-path Llama-3.2-1B-Instruct-Q4_K_M.gguf`
`./target/release/dynamo-run in=http Llama-3.2-1B-Instruct-Q4_K_M.gguf`
List the models: `curl localhost:8080/v1/models`
......@@ -63,7 +70,7 @@ dynamo-run in=dyn://llama3B_pool out=mistralrs ~/llm_models/Llama-3.2-3B-Instruc
This will use etcd to auto-discover the model and NATS to talk to it. You can run multiple workers on the same endpoint and it will pick one at random each time.
The `ns/backend/mistralrs` are purely symbolic, pick anything as long as it has three parts, and it matches the other node.
The `llama3B_pool` name is purely symbolic, pick anything as long as it matches the other node.
Run `dynamo-run --help` for more options.
......
......@@ -20,12 +20,17 @@ use std::str::FromStr;
#[derive(clap::Parser, Debug, Clone)]
#[command(version, about, long_about = None)]
pub struct Flags {
/// Full path to the model, which can be either a GGUF file or a checked out HF repository.
/// For the `echo_full` engine omit the flag.
/// The model. The options depend on the engine.
///
/// The full list - only mistralrs supports all three currently:
/// - Full path to a GGUF file
/// - Full path of a checked out Hugging Face repository containing safetensor files
/// - Name of a Hugging Face repository, e.g. 'google/flan-t5-small'. The model will be
/// downloaded and cached.
#[arg(index = 1)]
pub model_path_pos: Option<PathBuf>,
// `--model-path`. The one above is `tio <positional-model-path>`
// `--model-path`. The one above is `dynamo-run <positional-model-path>`
#[arg(long = "model-path")]
pub model_path_flag: Option<PathBuf>,
......
......@@ -83,7 +83,13 @@ pub async fn run(
let model_path = flags
.model_path_pos
.or(flags.model_path_flag)
.and_then(|p| p.canonicalize().ok());
.and_then(|p| {
if p.exists() {
p.canonicalize().ok()
} else {
Some(p)
}
});
// Serve the model under the name provided, or the name of the GGUF file or HF repo.
let model_name = flags
.model_name
......
......@@ -658,7 +658,7 @@ dependencies = [
[[package]]
name = "candle-core"
version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe"
source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [
"byteorder",
"candle-kernels",
......@@ -704,7 +704,7 @@ dependencies = [
[[package]]
name = "candle-kernels"
version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe"
source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [
"bindgen_cuda 0.1.5",
]
......@@ -712,7 +712,7 @@ dependencies = [
[[package]]
name = "candle-metal-kernels"
version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe"
source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [
"metal",
"once_cell",
......@@ -723,7 +723,7 @@ dependencies = [
[[package]]
name = "candle-nn"
version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=fb5cc8c#fb5cc8c0175bc2999ed9fb75b13886aade453dbe"
source = "git+https://github.com/EricLBuehler/candle.git?rev=76819e8#76819e867e6c8464485980677cf188d10e47213f"
dependencies = [
"candle-core",
"candle-metal-kernels",
......@@ -1392,6 +1392,12 @@ dependencies = [
"syn 2.0.98",
]
[[package]]
name = "doctest-file"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aac81fa3e28d21450aa4d2ac065992ba96a1d7303efbce51a95f4fd175b67562"
[[package]]
name = "dyn-clone"
version = "1.0.18"
......@@ -2715,6 +2721,19 @@ dependencies = [
"cfg-if 1.0.0",
]
[[package]]
name = "interprocess"
version = "2.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d941b405bd2322993887859a8ee6ac9134945a24ec5ec763a8a962fc64dfec2d"
dependencies = [
"doctest-file",
"libc",
"recvmsg",
"widestring",
"windows-sys 0.52.0",
]
[[package]]
name = "inventory"
version = "0.3.20"
......@@ -3205,7 +3224,7 @@ dependencies = [
[[package]]
name = "mistralrs"
version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [
"anyhow",
"candle-core",
......@@ -3226,7 +3245,7 @@ dependencies = [
[[package]]
name = "mistralrs-core"
version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [
"akin",
"anyhow",
......@@ -3254,6 +3273,7 @@ dependencies = [
"image",
"indexmap 2.7.1",
"indicatif",
"interprocess",
"itertools 0.13.0",
"llguidance",
"lrtable",
......@@ -3276,6 +3296,7 @@ dependencies = [
"safetensors",
"schemars",
"serde",
"serde-big-array",
"serde_json",
"serde_plain",
"serde_yaml",
......@@ -3298,7 +3319,7 @@ dependencies = [
[[package]]
name = "mistralrs-paged-attn"
version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [
"anyhow",
"bindgen_cuda 0.1.6",
......@@ -3313,7 +3334,7 @@ dependencies = [
[[package]]
name = "mistralrs-quant"
version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [
"bindgen_cuda 0.1.5",
"byteorder",
......@@ -3339,7 +3360,7 @@ dependencies = [
[[package]]
name = "mistralrs-vision"
version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=5e689c9#5e689c97653cdb1a631ca6fcd53ff447095bd8e5"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=a691154bb#a691154bbe924d8a4717c156ae14a751b6a95980"
dependencies = [
"candle-core",
"image",
......@@ -4432,6 +4453,12 @@ version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430"
[[package]]
name = "recvmsg"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3edd4d5d42c92f0a659926464d4cce56b562761267ecf0f469d85b7de384175"
[[package]]
name = "redox_syscall"
version = "0.5.9"
......@@ -4952,6 +4979,15 @@ dependencies = [
"serde_derive",
]
[[package]]
name = "serde-big-array"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11fc7cc2c76d73e0f27ee52abbd64eec84d46f370c88371120433196934e4b7f"
dependencies = [
"serde",
]
[[package]]
name = "serde-pickle"
version = "1.2.0"
......@@ -6393,6 +6429,12 @@ dependencies = [
"rustix",
]
[[package]]
name = "widestring"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7219d36b6eac893fa81e84ebe06485e7dcbb616177469b142df14f1f4deb1311"
[[package]]
name = "winapi"
version = "0.2.8"
......
......@@ -115,7 +115,7 @@ prometheus = { version = "0.13" }
# mistralrs
either = { version = "1.13" }
indexmap = { version = "2.6" }
mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", rev = "5e689c9", optional = true }
mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", rev = "a691154bb", optional = true }
# sglang
async_zmq = { version = "0.4.0", optional = true }
......
......@@ -13,7 +13,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{cmp::min, num::NonZero, path::Path, sync::Arc};
use std::collections::HashMap;
use std::{cmp::min, env, num::NonZero, path::Path, sync::Arc};
use async_openai::types::FinishReason;
use async_stream::stream;
......@@ -24,7 +25,8 @@ use mistralrs::{
AutoDeviceMapParams, Constraint, DefaultSchedulerMethod, Device, DeviceMapSetting,
GGUFLoaderBuilder, GGUFSpecificConfig, MemoryGpuConfig, MistralRs, MistralRsBuilder,
ModelDType, NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, PagedAttentionConfig,
Pipeline, Request, RequestMessage, ResponseOk, SamplingParams, SchedulerConfig, TokenSource,
Pipeline, Request, RequestMessage, ResponseOk, SamplingParams, SchedulerConfig, StopTokens,
TokenSource,
};
use tokio::sync::mpsc::channel;
......@@ -41,15 +43,12 @@ use crate::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine
/// If the user does not provide max_tokens, limit prompt+output to this many tokens
const DEFAULT_MAX_TOKENS: i32 = 8192;
/// TODO: tune. Presumably we read it from model's config.json?
const MAX_SEQ_LEN: usize = 4096;
// TODO: tune, maybe implement batching.
const MAX_BATCH_SIZE: usize = 2;
/// TODO: tune
const PAGED_ATTENTION_MAX_NUM_SEQS: usize = 5;
/// The environment variables which can hold the Hugging Face token, if any, checked in order
const HF_TOKEN_VARS: [&str; 3] = ["HF_TOKEN", "HUGGING_FACE_HUB_TOKEN", "HUGGINGFACE_TOKEN"];
pub async fn make_engine(
gguf_path: &Path,
) -> pipeline_error::Result<OpenAIChatCompletionsStreamingEngine> {
......@@ -77,6 +76,17 @@ struct MistralRsEngine {
impl MistralRsEngine {
async fn new(model_path: &Path) -> pipeline_error::Result<Self> {
let mut hf_token_source = TokenSource::CacheToken;
// We might be trying to download a repo from Hugging Face. See if we have a token.
if !model_path.exists() {
for v_name in HF_TOKEN_VARS {
if env::var(v_name).is_ok() {
tracing::debug!("Using Hugging Face token from {v_name}");
hf_token_source = TokenSource::EnvVar(v_name.to_string());
break;
}
}
}
let loader = if model_path.is_file() {
// Load from a GGUF
let Some(model_filename) = model_path.file_name() else {
......@@ -117,12 +127,14 @@ impl MistralRsEngine {
.build(None)?
};
let max_seq_len = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN;
// Paged attention requires cuda
let paged_attention_config = if cfg!(feature = "cuda") {
Some(PagedAttentionConfig::new(
Some(32),
1024,
MemoryGpuConfig::Utilization(0.9),
None, // Block size, default 32
512, // CPU memory in MiB
MemoryGpuConfig::ContextSize(max_seq_len),
)?)
} else {
None
......@@ -130,13 +142,13 @@ impl MistralRsEngine {
// Load, into a Pipeline
let pipeline = loader.load_model_from_hf(
None,
TokenSource::CacheToken,
hf_token_source,
&ModelDType::Auto,
&best_device()?,
false,
DeviceMapSetting::Auto(AutoDeviceMapParams::Text {
max_seq_len: MAX_SEQ_LEN,
max_batch_size: MAX_BATCH_SIZE,
max_seq_len,
max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
}),
None,
paged_attention_config,
......@@ -157,11 +169,11 @@ impl MistralRsEngine {
tracing::debug!("Using mistralrs DefaultScheduler");
SchedulerConfig::DefaultScheduler {
// Safety: unwrap trivially safe here
method: DefaultSchedulerMethod::Fixed(NonZero::new(5).unwrap()),
method: DefaultSchedulerMethod::Fixed(NonZero::new(max_seq_len).unwrap()),
}
};
// Create the MistralRs, which is a runner
let builder = MistralRsBuilder::new(pipeline.clone(), scheduler);
let builder = MistralRsBuilder::new(pipeline.clone(), scheduler).with_prefix_cache_n(16);
Ok(MistralRsEngine {
mistralrs: builder.build(),
pipeline,
......@@ -225,11 +237,43 @@ impl
limit,
);
let det = SamplingParams::deterministic();
let sampling_params = SamplingParams {
temperature: request
.inner
.temperature
.map(|t| t as f64)
.or(det.temperature),
top_p: request.inner.top_p.map(|t| t as f64).or(det.top_p),
top_n_logprobs: request
.inner
.top_logprobs
.map(|t| t as usize)
.unwrap_or(det.top_n_logprobs),
frequency_penalty: request.inner.frequency_penalty.or(det.frequency_penalty),
presence_penalty: request.inner.presence_penalty.or(det.presence_penalty),
stop_toks: request.inner.stop.map(to_stop_tokens).or(det.stop_toks),
max_len: request
.inner
.max_completion_tokens
.map(|m| m as usize)
.or(det.max_len),
logits_bias: request
.inner
.logit_bias
.map(to_logit_bias)
.or(det.logits_bias),
// These are not in async-openai yet
top_k: det.top_k,
min_p: det.min_p,
n_choices: 1,
dry_params: det.dry_params,
};
let mistralrs_request = Request::Normal(NormalRequest {
messages: RequestMessage::Chat(messages),
sampling_params: SamplingParams::deterministic(),
sampling_params,
response: tx,
return_logprobs: false,
return_logprobs: request.inner.logprobs.unwrap_or_default(),
is_streaming: true,
id: self.mistralrs.next_request_id(),
constraint: Constraint::None,
......@@ -319,3 +363,34 @@ impl
Ok(ResponseStream::new(Box::pin(output), ctx))
}
}
/// Convert the OpenAI stop-token request field into the mistralrs representation.
///
/// Both OpenAI variants (a single string or an array of strings) map onto
/// `StopTokens::Seqs`, which holds a list of stop sequences.
fn to_stop_tokens(t: async_openai::types::Stop) -> StopTokens {
    let sequences = match t {
        async_openai::types::Stop::String(s) => vec![s],
        async_openai::types::Stop::StringArray(v) => v,
    };
    StopTokens::Seqs(sequences)
}
/// Convert an OpenAI logit-bias map (token-id strings -> JSON numbers) into the
/// mistralrs form (u32 token ids -> f32 biases).
///
/// Expected input shape (assumed from the OpenAI API — TODO confirm):
/// {"3721": -100, "17765": 100}
///
/// Best-effort: if any key fails to parse as a token id, or any value is not a
/// number, a warning is logged and an empty map is returned, disabling the
/// bias for the whole request rather than failing it.
fn to_logit_bias(lb: HashMap<String, serde_json::Value>) -> HashMap<u32, f32> {
    let mut converted = HashMap::with_capacity(lb.len());
    for (key, value) in &lb {
        // Keys arrive as strings but are really token ids.
        let token_id = match key.parse::<u32>() {
            Ok(id) => id,
            Err(err) => {
                tracing::warn!(
                    "Unexpected logit_bias map. Key '{key}' is not an int: {lb:?}. {err}."
                );
                return HashMap::new();
            }
        };
        match value.as_f64() {
            Some(bias) => {
                converted.insert(token_id, bias as f32);
            }
            None => {
                tracing::warn!("Unexpected logit_bias map. Value '{value}' is not a float: {lb:?}");
                return HashMap::new();
            }
        }
    }
    converted
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment