Unverified Commit ceaeba3e authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat: Qwen3, Gemma3 and Llama4 support (#1002)

. New mistralrs and llamacpp version
. mistralrs: Handle Gemma 3 and Llama 4 as vision models
. Update the dynamo-run docs to use Qwen 3
. Our pre-processor now supports Llama 4's newer multi-modal `config.json`
. Upgrade minijinja to handle Qwen 3's prompt template

For Llama 4 we'll need to limit the max seq len. vllm says:
> To serve at least one request with the models's max seq len (10485760), (240.00 GiB KV cache is needed,...

I was able to run Llama 4 with llamacpp and a quantized GGUF, with Dynamo doing the pre-processing.
parent 57402e70
...@@ -672,7 +672,7 @@ dependencies = [ ...@@ -672,7 +672,7 @@ dependencies = [
[[package]] [[package]]
name = "candle-core" name = "candle-core"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=496a8d2b#496a8d2bf8f88e3be4ea27332a209d66e8b404f4" source = "git+https://github.com/EricLBuehler/candle.git?rev=cb2d8f5#cb2d8f59b4fbe1f69ef998233c40a87033a05b0d"
dependencies = [ dependencies = [
"byteorder", "byteorder",
"candle-kernels 0.8.0", "candle-kernels 0.8.0",
...@@ -722,7 +722,7 @@ dependencies = [ ...@@ -722,7 +722,7 @@ dependencies = [
[[package]] [[package]]
name = "candle-kernels" name = "candle-kernels"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=496a8d2b#496a8d2bf8f88e3be4ea27332a209d66e8b404f4" source = "git+https://github.com/EricLBuehler/candle.git?rev=cb2d8f5#cb2d8f59b4fbe1f69ef998233c40a87033a05b0d"
dependencies = [ dependencies = [
"bindgen_cuda 0.1.5", "bindgen_cuda 0.1.5",
] ]
...@@ -739,7 +739,7 @@ dependencies = [ ...@@ -739,7 +739,7 @@ dependencies = [
[[package]] [[package]]
name = "candle-metal-kernels" name = "candle-metal-kernels"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=496a8d2b#496a8d2bf8f88e3be4ea27332a209d66e8b404f4" source = "git+https://github.com/EricLBuehler/candle.git?rev=cb2d8f5#cb2d8f59b4fbe1f69ef998233c40a87033a05b0d"
dependencies = [ dependencies = [
"metal", "metal",
"once_cell", "once_cell",
...@@ -750,7 +750,7 @@ dependencies = [ ...@@ -750,7 +750,7 @@ dependencies = [
[[package]] [[package]]
name = "candle-nn" name = "candle-nn"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=496a8d2b#496a8d2bf8f88e3be4ea27332a209d66e8b404f4" source = "git+https://github.com/EricLBuehler/candle.git?rev=cb2d8f5#cb2d8f59b4fbe1f69ef998233c40a87033a05b0d"
dependencies = [ dependencies = [
"candle-core 0.8.0", "candle-core 0.8.0",
"candle-metal-kernels", "candle-metal-kernels",
...@@ -950,7 +950,7 @@ dependencies = [ ...@@ -950,7 +950,7 @@ dependencies = [
"encode_unicode", "encode_unicode",
"libc", "libc",
"once_cell", "once_cell",
"unicode-width", "unicode-width 0.2.0",
"windows-sys 0.59.0", "windows-sys 0.59.0",
] ]
...@@ -1118,6 +1118,29 @@ dependencies = [ ...@@ -1118,6 +1118,29 @@ dependencies = [
"typenum", "typenum",
] ]
[[package]]
name = "cssparser"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7c66d1cd8ed61bf80b38432613a7a2f09401ab8d0501110655f8b341484a3e3"
dependencies = [
"cssparser-macros",
"dtoa-short",
"itoa",
"phf",
"smallvec",
]
[[package]]
name = "cssparser-macros"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13b588ba4ac1a99f7f2964d24b3d896ddc6bf847ee3855dbd4366f058cfcd331"
dependencies = [
"quote",
"syn 2.0.100",
]
[[package]] [[package]]
name = "csv" name = "csv"
version = "1.3.1" version = "1.3.1"
...@@ -1364,9 +1387,9 @@ dependencies = [ ...@@ -1364,9 +1387,9 @@ dependencies = [
[[package]] [[package]]
name = "derivre" name = "derivre"
version = "0.3.1" version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a3c2606b3ffc46f91fd62d954d55659ba9fb391bb673311b70f50daf9c15e49" checksum = "4a605f30e6a1460a323cc4de7bc62dea81df1d9d67eb92194d3a983a8a9601c4"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bytemuck", "bytemuck",
...@@ -1456,6 +1479,21 @@ version = "1.0.0" ...@@ -1456,6 +1479,21 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aac81fa3e28d21450aa4d2ac065992ba96a1d7303efbce51a95f4fd175b67562" checksum = "aac81fa3e28d21450aa4d2ac065992ba96a1d7303efbce51a95f4fd175b67562"
[[package]]
name = "dtoa"
version = "1.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6add3b8cff394282be81f3fc1a0605db594ed69890078ca6e2cab1c408bcf04"
[[package]]
name = "dtoa-short"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd1511a7b6a56299bd043a9c167a6d2bfb37bf84a6dfceaba651168adfb43c87"
dependencies = [
"dtoa",
]
[[package]] [[package]]
name = "dyn-clone" name = "dyn-clone"
version = "1.0.19" version = "1.0.19"
...@@ -1718,6 +1756,12 @@ dependencies = [ ...@@ -1718,6 +1756,12 @@ dependencies = [
"syn 2.0.100", "syn 2.0.100",
] ]
[[package]]
name = "ego-tree"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2972feb8dffe7bc8c5463b1dacda1b0dfbed3710e50f977d965429692d74cd8"
[[package]] [[package]]
name = "either" name = "either"
version = "1.15.0" version = "1.15.0"
...@@ -2066,6 +2110,16 @@ version = "0.3.3" ...@@ -2066,6 +2110,16 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7" checksum = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7"
[[package]]
name = "futf"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df420e2e84819663797d1ec6544b13c5be84629e7bb00dc960d6917db2987843"
dependencies = [
"mac",
"new_debug_unreachable",
]
[[package]] [[package]]
name = "futures" name = "futures"
version = "0.3.31" version = "0.3.31"
...@@ -2161,6 +2215,15 @@ dependencies = [ ...@@ -2161,6 +2215,15 @@ dependencies = [
"slab", "slab",
] ]
[[package]]
name = "fxhash"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c"
dependencies = [
"byteorder",
]
[[package]] [[package]]
name = "galil-seiferas" name = "galil-seiferas"
version = "0.1.5" version = "0.1.5"
...@@ -2418,6 +2481,15 @@ dependencies = [ ...@@ -2418,6 +2481,15 @@ dependencies = [
"version_check", "version_check",
] ]
[[package]]
name = "getopts"
version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14dbbfd5c71d70241ecf9e6f13737f7b5ce823821063188d7e46c41d371eebd5"
dependencies = [
"unicode-width 0.1.14",
]
[[package]] [[package]]
name = "getrandom" name = "getrandom"
version = "0.2.16" version = "0.2.16"
...@@ -2599,6 +2671,42 @@ dependencies = [ ...@@ -2599,6 +2671,42 @@ dependencies = [
"windows-sys 0.59.0", "windows-sys 0.59.0",
] ]
[[package]]
name = "html2text"
version = "0.14.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1637acec3b965bab873352189d887b12c87b4f8d7571f4d185e796be5654ad8"
dependencies = [
"html5ever 0.31.0",
"tendril",
"thiserror 2.0.12",
"unicode-width 0.2.0",
]
[[package]]
name = "html5ever"
version = "0.29.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b7410cae13cbc75623c98ac4cbfd1f0bedddf3227afc24f370cf0f50a44a11c"
dependencies = [
"log",
"mac",
"markup5ever 0.14.1",
"match_token",
]
[[package]]
name = "html5ever"
version = "0.31.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "953cbbe631aae7fc0a112702ad5d3aaf09da38beaf45ea84610d6e1c358f569c"
dependencies = [
"log",
"mac",
"markup5ever 0.16.1",
"match_token",
]
[[package]] [[package]]
name = "http" name = "http"
version = "0.2.0" version = "0.2.0"
...@@ -3011,7 +3119,7 @@ dependencies = [ ...@@ -3011,7 +3119,7 @@ dependencies = [
"number_prefix", "number_prefix",
"portable-atomic", "portable-atomic",
"rayon", "rayon",
"unicode-width", "unicode-width 0.2.0",
"web-time", "web-time",
] ]
...@@ -3316,9 +3424,9 @@ checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856" ...@@ -3316,9 +3424,9 @@ checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856"
[[package]] [[package]]
name = "llama-cpp-2" name = "llama-cpp-2"
version = "0.1.102" version = "0.1.103"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a419bb48efa0f8389a82301f1f64e2874568a3fbf6f62f8ddab5324382b82768" checksum = "401c708926326b1ee410735dc348882c73deeab78f1f89ff2c9caf148356feb4"
dependencies = [ dependencies = [
"enumflags2", "enumflags2",
"llama-cpp-sys-2", "llama-cpp-sys-2",
...@@ -3329,9 +3437,9 @@ dependencies = [ ...@@ -3329,9 +3437,9 @@ dependencies = [
[[package]] [[package]]
name = "llama-cpp-sys-2" name = "llama-cpp-sys-2"
version = "0.1.102" version = "0.1.103"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0522f9894e22dd988dd2e34222bda7acba53a0dcce744ca6d8ddce905ba33a4e" checksum = "8b4ae3037b7d9b9fab9fd7905aeb04e214acb300599fa1ee698d6f759ee530f9"
dependencies = [ dependencies = [
"bindgen", "bindgen",
"cc", "cc",
...@@ -3343,8 +3451,9 @@ dependencies = [ ...@@ -3343,8 +3451,9 @@ dependencies = [
[[package]] [[package]]
name = "llguidance" name = "llguidance"
version = "0.7.0" version = "0.7.19"
source = "git+https://github.com/EricLBuehler/llguidance?rev=8d71957#8d7195774a209038ddfbb0d1a5348ed17b387386" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20c50686e6724ff55e184447dbc775b94f28f64e5061882b75c84b12b9c1d613"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"derivre", "derivre",
...@@ -3352,7 +3461,7 @@ dependencies = [ ...@@ -3352,7 +3461,7 @@ dependencies = [
"regex-syntax 0.8.5", "regex-syntax 0.8.5",
"serde", "serde",
"serde_json", "serde_json",
"toktrie 0.7.0", "toktrie 0.7.19",
] ]
[[package]] [[package]]
...@@ -3410,6 +3519,12 @@ dependencies = [ ...@@ -3410,6 +3519,12 @@ dependencies = [
"vob", "vob",
] ]
[[package]]
name = "mac"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c41e0c4fef86961ac6d6f8a82609f55f31b05e4fce149ac5710e439df7619ba4"
[[package]] [[package]]
name = "macro_rules_attribute" name = "macro_rules_attribute"
version = "0.2.0" version = "0.2.0"
...@@ -3435,6 +3550,42 @@ dependencies = [ ...@@ -3435,6 +3550,42 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "markup5ever"
version = "0.14.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7a7213d12e1864c0f002f52c2923d4556935a43dec5e71355c2760e0f6e7a18"
dependencies = [
"log",
"phf",
"phf_codegen",
"string_cache",
"string_cache_codegen",
"tendril",
]
[[package]]
name = "markup5ever"
version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0a8096766c229e8c88a3900c9b44b7e06aa7f7343cc229158c3e58ef8f9973a"
dependencies = [
"log",
"tendril",
"web_atoms",
]
[[package]]
name = "match_token"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88a9689d8d44bf9964484516275f5cd4c9b59457a6940c1d5d0ecbb94510a36b"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]] [[package]]
name = "matchers" name = "matchers"
version = "0.1.0" version = "0.1.0"
...@@ -3549,9 +3700,9 @@ dependencies = [ ...@@ -3549,9 +3700,9 @@ dependencies = [
[[package]] [[package]]
name = "minijinja" name = "minijinja"
version = "2.9.0" version = "2.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "98642a6dfca91122779a307b77cd07a4aa951fbe32232aaf5bad9febc66be754" checksum = "dd72e8b4e42274540edabec853f607c015c73436159b06c39c7af85a20433155"
dependencies = [ dependencies = [
"memo-map", "memo-map",
"self_cell", "self_cell",
...@@ -3561,9 +3712,9 @@ dependencies = [ ...@@ -3561,9 +3712,9 @@ dependencies = [
[[package]] [[package]]
name = "minijinja-contrib" name = "minijinja-contrib"
version = "2.9.0" version = "2.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd4a0f6e171c7bb92ed2caf446fa3de4e26561cea1d97085103e9cb42359dd59" checksum = "457f85f9c4c5b17d11fcf9bbe7c0dbba64843c5ee040005956f1a510b6679fe2"
dependencies = [ dependencies = [
"minijinja", "minijinja",
"serde", "serde",
...@@ -3641,8 +3792,8 @@ dependencies = [ ...@@ -3641,8 +3792,8 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs" name = "mistralrs"
version = "0.4.0" version = "0.5.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=aaafc2ef#aaafc2efc6305c1a79eee632b177d76586df1646" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=ebd50e35e#ebd50e35efacb082462b757fb448d4ab54473775"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"candle-core 0.8.0", "candle-core 0.8.0",
...@@ -3662,8 +3813,8 @@ dependencies = [ ...@@ -3662,8 +3813,8 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs-core" name = "mistralrs-core"
version = "0.4.0" version = "0.5.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=aaafc2ef#aaafc2efc6305c1a79eee632b177d76586df1646" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=ebd50e35e#ebd50e35efacb082462b757fb448d4ab54473775"
dependencies = [ dependencies = [
"akin", "akin",
"anyhow", "anyhow",
...@@ -3688,6 +3839,7 @@ dependencies = [ ...@@ -3688,6 +3839,7 @@ dependencies = [
"galil-seiferas", "galil-seiferas",
"half", "half",
"hf-hub", "hf-hub",
"html2text",
"image", "image",
"indexmap 2.9.0", "indexmap 2.9.0",
"indicatif", "indicatif",
...@@ -3703,6 +3855,7 @@ dependencies = [ ...@@ -3703,6 +3855,7 @@ dependencies = [
"mistralrs-vision", "mistralrs-vision",
"objc", "objc",
"once_cell", "once_cell",
"ordered-float",
"radix_trie", "radix_trie",
"rand 0.9.1", "rand 0.9.1",
"rand_isaac", "rand_isaac",
...@@ -3713,6 +3866,7 @@ dependencies = [ ...@@ -3713,6 +3866,7 @@ dependencies = [
"rustc-hash 2.1.1", "rustc-hash 2.1.1",
"safetensors", "safetensors",
"schemars", "schemars",
"scraper",
"serde", "serde",
"serde-big-array", "serde-big-array",
"serde_json", "serde_json",
...@@ -3724,11 +3878,12 @@ dependencies = [ ...@@ -3724,11 +3878,12 @@ dependencies = [
"tokenizers", "tokenizers",
"tokio", "tokio",
"tokio-rayon", "tokio-rayon",
"toktrie_hf_tokenizers 0.7.0", "toktrie_hf_tokenizers 0.7.19",
"toml", "toml",
"tqdm", "tqdm",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
"urlencoding",
"uuid 1.16.0", "uuid 1.16.0",
"variantly", "variantly",
"vob", "vob",
...@@ -3736,8 +3891,8 @@ dependencies = [ ...@@ -3736,8 +3891,8 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs-paged-attn" name = "mistralrs-paged-attn"
version = "0.4.0" version = "0.5.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=aaafc2ef#aaafc2efc6305c1a79eee632b177d76586df1646" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=ebd50e35e#ebd50e35efacb082462b757fb448d4ab54473775"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bindgen_cuda 0.1.6", "bindgen_cuda 0.1.6",
...@@ -3751,8 +3906,8 @@ dependencies = [ ...@@ -3751,8 +3906,8 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs-quant" name = "mistralrs-quant"
version = "0.4.0" version = "0.5.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=aaafc2ef#aaafc2efc6305c1a79eee632b177d76586df1646" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=ebd50e35e#ebd50e35efacb082462b757fb448d4ab54473775"
dependencies = [ dependencies = [
"bindgen_cuda 0.1.5", "bindgen_cuda 0.1.5",
"byteorder", "byteorder",
...@@ -3778,11 +3933,12 @@ dependencies = [ ...@@ -3778,11 +3933,12 @@ dependencies = [
[[package]] [[package]]
name = "mistralrs-vision" name = "mistralrs-vision"
version = "0.4.0" version = "0.5.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=aaafc2ef#aaafc2efc6305c1a79eee632b177d76586df1646" source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=ebd50e35e#ebd50e35efacb082462b757fb448d4ab54473775"
dependencies = [ dependencies = [
"candle-core 0.8.0", "candle-core 0.8.0",
"image", "image",
"rayon",
] ]
[[package]] [[package]]
...@@ -3863,6 +4019,12 @@ dependencies = [ ...@@ -3863,6 +4019,12 @@ dependencies = [
"winapi 0.3.9", "winapi 0.3.9",
] ]
[[package]]
name = "new_debug_unreachable"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086"
[[package]] [[package]]
name = "nibble_vec" name = "nibble_vec"
version = "0.1.0" version = "0.1.0"
...@@ -4154,6 +4316,15 @@ version = "0.2.0" ...@@ -4154,6 +4316,15 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
[[package]]
name = "ordered-float"
version = "5.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2c1f9f56e534ac6a9b8a4600bdf0f530fb393b5f393e7b4d03489c3cf0c3f01"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "overload" name = "overload"
version = "0.1.1" version = "0.1.1"
...@@ -4178,7 +4349,7 @@ checksum = "b915f831b85d984193fdc3d3611505871dc139b2534530fa01c1a6a6707b6723" ...@@ -4178,7 +4349,7 @@ checksum = "b915f831b85d984193fdc3d3611505871dc139b2534530fa01c1a6a6707b6723"
dependencies = [ dependencies = [
"bytecount", "bytecount",
"fnv", "fnv",
"unicode-width", "unicode-width 0.2.0",
] ]
[[package]] [[package]]
...@@ -4309,6 +4480,58 @@ dependencies = [ ...@@ -4309,6 +4480,58 @@ dependencies = [
"indexmap 2.9.0", "indexmap 2.9.0",
] ]
[[package]]
name = "phf"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078"
dependencies = [
"phf_macros",
"phf_shared",
]
[[package]]
name = "phf_codegen"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a"
dependencies = [
"phf_generator",
"phf_shared",
]
[[package]]
name = "phf_generator"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d"
dependencies = [
"phf_shared",
"rand 0.8.5",
]
[[package]]
name = "phf_macros"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f84ac04429c13a7ff43785d75ad27569f2951ce0ffd30a3321230db2fc727216"
dependencies = [
"phf_generator",
"phf_shared",
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "phf_shared"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5"
dependencies = [
"siphasher",
]
[[package]] [[package]]
name = "pin-project" name = "pin-project"
version = "1.1.10" version = "1.1.10"
...@@ -4400,6 +4623,12 @@ dependencies = [ ...@@ -4400,6 +4623,12 @@ dependencies = [
"zerocopy", "zerocopy",
] ]
[[package]]
name = "precomputed-hash"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c"
[[package]] [[package]]
name = "prettyplease" name = "prettyplease"
version = "0.2.32" version = "0.2.32"
...@@ -5423,6 +5652,21 @@ version = "1.2.0" ...@@ -5423,6 +5652,21 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "scraper"
version = "0.23.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "527e65d9d888567588db4c12da1087598d0f6f8b346cc2c5abc91f05fc2dffe2"
dependencies = [
"cssparser",
"ego-tree",
"getopts",
"html5ever 0.29.1",
"precomputed-hash",
"selectors",
"tendril",
]
[[package]] [[package]]
name = "secrecy" name = "secrecy"
version = "0.10.3" version = "0.10.3"
...@@ -5469,6 +5713,25 @@ dependencies = [ ...@@ -5469,6 +5713,25 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "selectors"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd568a4c9bb598e291a08244a5c1f5a8a6650bee243b5b0f8dbb3d9cc1d87fe8"
dependencies = [
"bitflags 2.9.0",
"cssparser",
"derive_more",
"fxhash",
"log",
"new_debug_unreachable",
"phf",
"phf_codegen",
"precomputed-hash",
"servo_arc",
"smallvec",
]
[[package]] [[package]]
name = "self_cell" name = "self_cell"
version = "1.2.0" version = "1.2.0"
...@@ -5639,6 +5902,15 @@ dependencies = [ ...@@ -5639,6 +5902,15 @@ dependencies = [
"unsafe-libyaml", "unsafe-libyaml",
] ]
[[package]]
name = "servo_arc"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae65c4249478a2647db249fb43e23cec56a2c8974a427e7bd8cb5a1d0964921a"
dependencies = [
"stable_deref_trait",
]
[[package]] [[package]]
name = "sha2" name = "sha2"
version = "0.10.8" version = "0.10.8"
...@@ -5735,6 +6007,12 @@ version = "2.7.0" ...@@ -5735,6 +6007,12 @@ version = "2.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa"
[[package]]
name = "siphasher"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d"
[[package]] [[package]]
name = "slab" name = "slab"
version = "0.4.9" version = "0.4.9"
...@@ -5811,6 +6089,31 @@ version = "1.2.0" ...@@ -5811,6 +6089,31 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
[[package]]
name = "string_cache"
version = "0.8.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf776ba3fa74f83bf4b63c3dcbbf82173db2632ed8452cb2d891d33f459de70f"
dependencies = [
"new_debug_unreachable",
"parking_lot",
"phf_shared",
"precomputed-hash",
"serde",
]
[[package]]
name = "string_cache_codegen"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c711928715f1fe0fe509c53b43e993a9a557babc2d0a3567d0a3006f1ac931a0"
dependencies = [
"phf_generator",
"phf_shared",
"proc-macro2",
"quote",
]
[[package]] [[package]]
name = "strsim" name = "strsim"
version = "0.10.0" version = "0.10.0"
...@@ -6049,6 +6352,17 @@ dependencies = [ ...@@ -6049,6 +6352,17 @@ dependencies = [
"windows-sys 0.59.0", "windows-sys 0.59.0",
] ]
[[package]]
name = "tendril"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d24a120c5fc464a3458240ee02c299ebcb9d67b5249c8848b09d639dca8d7bb0"
dependencies = [
"futf",
"mac",
"utf-8",
]
[[package]] [[package]]
name = "terminal_size" name = "terminal_size"
version = "0.4.2" version = "0.4.2"
...@@ -6320,8 +6634,9 @@ dependencies = [ ...@@ -6320,8 +6634,9 @@ dependencies = [
[[package]] [[package]]
name = "toktrie" name = "toktrie"
version = "0.7.0" version = "0.7.19"
source = "git+https://github.com/EricLBuehler/llguidance?rev=8d71957#8d7195774a209038ddfbb0d1a5348ed17b387386" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69200397314cd578aa0623a5161342c44877ac2ebcea1e9d20e2f65120e81e8d"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bytemuck", "bytemuck",
...@@ -6346,15 +6661,16 @@ dependencies = [ ...@@ -6346,15 +6661,16 @@ dependencies = [
[[package]] [[package]]
name = "toktrie_hf_tokenizers" name = "toktrie_hf_tokenizers"
version = "0.7.0" version = "0.7.19"
source = "git+https://github.com/EricLBuehler/llguidance?rev=8d71957#8d7195774a209038ddfbb0d1a5348ed17b387386" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4c03d846b7e3cd072e955583da31732076eed2f7be6881d7605fd4c3e5a8247"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"log", "log",
"serde", "serde",
"serde_json", "serde_json",
"tokenizers", "tokenizers",
"toktrie 0.7.0", "toktrie 0.7.19",
] ]
[[package]] [[package]]
...@@ -6694,6 +7010,12 @@ version = "1.12.0" ...@@ -6694,6 +7010,12 @@ version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
[[package]]
name = "unicode-width"
version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af"
[[package]] [[package]]
name = "unicode-width" name = "unicode-width"
version = "0.2.0" version = "0.2.0"
...@@ -6755,6 +7077,18 @@ dependencies = [ ...@@ -6755,6 +7077,18 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "urlencoding"
version = "2.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]] [[package]]
name = "utf16_iter" name = "utf16_iter"
version = "1.0.5" version = "1.0.5"
...@@ -7011,6 +7345,18 @@ dependencies = [ ...@@ -7011,6 +7345,18 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "web_atoms"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b9c5f0bc545ea3b20b423e33b9b457764de0b3730cd957f6c6aa6c301785f6e"
dependencies = [
"phf",
"phf_codegen",
"string_cache",
"string_cache_codegen",
]
[[package]] [[package]]
name = "webpki-roots" name = "webpki-roots"
version = "0.26.8" version = "0.26.8"
......
...@@ -26,7 +26,7 @@ Usage: ...@@ -26,7 +26,7 @@ Usage:
dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn://<path> [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin] dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn://<path> [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin]
``` ```
Example: `dynamo run Qwen/Qwen2.5-3B-Instruct`. Example: `dynamo run Qwen/Qwen3-0.6B`
Set environment variable `DYN_LOG` to adjust logging level, e.g. `export DYN_LOG=debug`. It has the same syntax as `RUST_LOG`, ask AI for details. Set environment variable `DYN_LOG` to adjust logging level, e.g. `export DYN_LOG=debug`. It has the same syntax as `RUST_LOG`, ask AI for details.
...@@ -38,9 +38,9 @@ The vllm and sglang engines require [etcd](https://etcd.io/) and [nats](https:// ...@@ -38,9 +38,9 @@ The vllm and sglang engines require [etcd](https://etcd.io/) and [nats](https://
### Use model from Hugging Face ### Use model from Hugging Face
This will automatically download Qwen2.5 3B from Hugging Face (6 GiB download) and start it in interactive text mode: This will automatically download Qwen3 4B from Hugging Face (16 GiB download) and start it in interactive text mode:
``` ```
dynamo run out=vllm Qwen/Qwen2.5-3B-Instruct dynamo run out=vllm Qwen/Qwen3-4B
``` ```
General format for HF download: General format for HF download:
...@@ -65,12 +65,12 @@ curl -L -o Llama-3.2-3B-Instruct-Q4_K_M.gguf "https://huggingface.co/bartowski/L ...@@ -65,12 +65,12 @@ curl -L -o Llama-3.2-3B-Instruct-Q4_K_M.gguf "https://huggingface.co/bartowski/L
#### Run model from local file #### Run model from local file
**Text interface** **Text interface**
``` ```
dynamo run out=vllm Llama-3.2-3B-Instruct-Q4_K_M.gguf # or path to a Hugging Face repo checkout instead of the GGUF dynamo run Llama-3.2-3B-Instruct-Q4_K_M.gguf # or path to a Hugging Face repo checkout instead of the GGUF
``` ```
**HTTP interface** **HTTP interface**
``` ```
dynamo run in=http out=vllm Llama-3.2-3B-Instruct-Q4_K_M.gguf dynamo run in=http out=mistralrs Llama-3.2-3B-Instruct-Q4_K_M.gguf
``` ```
**List the models** **List the models**
...@@ -94,7 +94,7 @@ You will need [etcd](https://etcd.io/) and [nats](https://nats.io) with jetstrea ...@@ -94,7 +94,7 @@ You will need [etcd](https://etcd.io/) and [nats](https://nats.io) with jetstrea
OpenAI compliant HTTP server, optional pre-processing, worker discovery. OpenAI compliant HTTP server, optional pre-processing, worker discovery.
``` ```
dynamo run in=http out=dyn://llama3B_pool dynamo-run in=http out=dyn://llama3B_pool
``` ```
**Node 2:** **Node 2:**
...@@ -109,11 +109,11 @@ This will use etcd to auto-discover the model and NATS to talk to it. You can ru ...@@ -109,11 +109,11 @@ This will use etcd to auto-discover the model and NATS to talk to it. You can ru
The `llama3B_pool` name is purely symbolic, pick anything as long as it matches the other node. The `llama3B_pool` name is purely symbolic, pick anything as long as it matches the other node.
Run `dynamo run --help` for more options. Run `dynamo-run --help` for more options.
## Full usage details ## Full usage details
`dynamo-run` is what `dynamo run` executes. It is an example of what you can build in Rust with the `dynamo-llm` and `dynamo-runtime`. The following guide demonstrates how you can build from source with all the features. `dynamo-run` is what `dynamo run` executes. It is also an example of what you can build in Rust with the `dynamo-llm` and `dynamo-runtime` crates. The following guide demonstrates how you can build from source with all the features.
### Setup ### Setup
...@@ -181,13 +181,13 @@ Build with `--release` for a smaller binary and better performance, but longer b ...@@ -181,13 +181,13 @@ Build with `--release` for a smaller binary and better performance, but longer b
[mistral.rs](https://github.com/EricLBuehler/mistral.rs) is a pure Rust engine that is fast to run, fast to load, supports GGUF as well as safetensors, and runs well on CPU as well as GPU. For those reasons it is the default engine. [mistral.rs](https://github.com/EricLBuehler/mistral.rs) is a pure Rust engine that is fast to run, fast to load, supports GGUF as well as safetensors, and runs well on CPU as well as GPU. For those reasons it is the default engine.
``` ```
dynamo-run Qwen/Qwen2.5-3B-Instruct dynamo-run Qwen/Qwen3-4B
``` ```
is equivalent to is equivalent to
``` ```
dynamo-run in=text out=mistralrs Qwen/Qwen2.5-3B-Instruct dynamo-run in=text out=mistralrs Qwen/Qwen3-4B
``` ```
If you have multiple GPUs, mistral.rs does automatic tensor parallelism. You do not need to pass any extra flags to dynamo-run to enable it. If you have multiple GPUs, mistral.rs does automatic tensor parallelism. You do not need to pass any extra flags to dynamo-run to enable it.
...@@ -204,7 +204,14 @@ cargo build --features llamacpp[,cuda|metal|vulkan] -p dynamo-run ...@@ -204,7 +204,14 @@ cargo build --features llamacpp[,cuda|metal|vulkan] -p dynamo-run
dynamo-run out=llamacpp ~/llms/Llama-3.2-3B-Instruct-Q6_K.gguf dynamo-run out=llamacpp ~/llms/Llama-3.2-3B-Instruct-Q6_K.gguf
``` ```
llamacpp is best for single-GPU inference with a quantized GGUF model file. Note that in some cases we are unable to extract the tokenizer from the GGUF, and so a Hugging Face checkout of a matching model must also be passed using `--model-config`. Dynamo will use the weights from the GGUF and the pre-processor files (`tokenizer.json`, etc.) from the `--model-config` checkout:
```
dynamo-run out=llamacpp ~/llms/gemma-3-1b-it-q4_0.gguf --model-config ~/llms/gemma-3-1b-it
dynamo-run out=llamacpp ~/llms/Llama-4-Scout-17B-16E-Instruct-UD-IQ1_S.gguf --model-config ~/llms/Llama-4-Scout-17B-16E-Instruct
```
If you have multiple GPUs, llama.cpp does automatic tensor parallelism. You do not need to pass any extra flags to dynamo-run to enable it.
### sglang ### sglang
...@@ -233,9 +240,13 @@ To pass extra arguments to the sglang engine see *Extra engine arguments* below. ...@@ -233,9 +240,13 @@ To pass extra arguments to the sglang engine see *Extra engine arguments* below.
**Multi-GPU** **Multi-GPU**
Pass `--tensor-parallel-size <NUM-GPUS>` to `dynamo-run`. To specify which GPU to start from pass `--base-gpu-id <num>`. Pass `--tensor-parallel-size <NUM-GPUS>` to `dynamo-run`.
```
dynamo-run out=sglang ~/llms/Llama-4-Scout-17B-16E-Instruct/ --tensor-parallel-size 8
```
For example on a shared eight GPU machine where GPUs 0-3 are already in use: To specify which GPU to start from, pass `--base-gpu-id <num>`. For example, on a shared eight-GPU machine where GPUs 0-3 are already in use:
``` ```
dynamo-run out=sglang <model> --tensor-parallel-size 4 --base-gpu-id 4 dynamo-run out=sglang <model> --tensor-parallel-size 4 --base-gpu-id 4
``` ```
......
...@@ -16,10 +16,6 @@ pub use dynamo_llm::request_template::RequestTemplate; ...@@ -16,10 +16,6 @@ pub use dynamo_llm::request_template::RequestTemplate;
pub use opt::{Input, Output}; pub use opt::{Input, Output};
mod subprocess; mod subprocess;
/// When `in=text` the user doesn't need to know the model name, and doesn't need to provide it on
/// the command line. Hence it's optional, and defaults to this.
const INVISIBLE_MODEL_NAME: &str = "dynamo-run";
const CHILD_STOP_TIMEOUT: Duration = Duration::from_secs(2); const CHILD_STOP_TIMEOUT: Duration = Duration::from_secs(2);
/// How we identify a python string endpoint /// How we identify a python string endpoint
...@@ -64,15 +60,10 @@ pub async fn run( ...@@ -64,15 +60,10 @@ pub async fn run(
_ => { _ => {
match &maybe_path { match &maybe_path {
Some(model_path) => { Some(model_path) => {
let maybe_model_name = if in_opt == Input::Text {
Some(INVISIBLE_MODEL_NAME.to_string())
} else {
flags.model_name.clone()
};
LocalModel::prepare( LocalModel::prepare(
model_path.to_str().context("Invalid UTF-8 in model path")?, model_path.to_str().context("Invalid UTF-8 in model path")?,
flags.model_config.as_deref(), flags.model_config.as_deref(),
maybe_model_name, flags.model_name.clone(),
) )
.await? .await?
} }
...@@ -121,7 +112,7 @@ pub async fn run( ...@@ -121,7 +112,7 @@ pub async fn run(
} }
#[cfg(feature = "mistralrs")] #[cfg(feature = "mistralrs")]
Output::MistralRs => EngineConfig::StaticFull { Output::MistralRs => EngineConfig::StaticFull {
engine: dynamo_engine_mistralrs::make_engine(local_model.path()).await?, engine: dynamo_engine_mistralrs::make_engine(&local_model).await?,
model: Box::new(local_model), model: Box::new(local_model),
}, },
Output::SgLang => { Output::SgLang => {
......
...@@ -2511,9 +2511,9 @@ dependencies = [ ...@@ -2511,9 +2511,9 @@ dependencies = [
[[package]] [[package]]
name = "minijinja" name = "minijinja"
version = "2.9.0" version = "2.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "98642a6dfca91122779a307b77cd07a4aa951fbe32232aaf5bad9febc66be754" checksum = "dd72e8b4e42274540edabec853f607c015c73436159b06c39c7af85a20433155"
dependencies = [ dependencies = [
"memo-map", "memo-map",
"self_cell", "self_cell",
...@@ -2522,9 +2522,9 @@ dependencies = [ ...@@ -2522,9 +2522,9 @@ dependencies = [
[[package]] [[package]]
name = "minijinja-contrib" name = "minijinja-contrib"
version = "2.9.0" version = "2.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd4a0f6e171c7bb92ed2caf446fa3de4e26561cea1d97085103e9cb42359dd59" checksum = "457f85f9c4c5b17d11fcf9bbe7c0dbba64843c5ee040005956f1a510b6679fe2"
dependencies = [ dependencies = [
"minijinja", "minijinja",
"serde", "serde",
......
...@@ -38,4 +38,4 @@ async-stream = { workspace = true } ...@@ -38,4 +38,4 @@ async-stream = { workspace = true }
tokio = { workspace = true } tokio = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
llama-cpp-2 = { version = "0.1.102" } llama-cpp-2 = { version = "0.1.103" }
...@@ -228,9 +228,8 @@ fn run_request( ...@@ -228,9 +228,8 @@ fn run_request(
limit, limit,
); );
// create a llama_batch with size 512
// we use this object to submit token data for decoding // we use this object to submit token data for decoding
let mut batch = LlamaBatch::new(512, 1); let mut batch = LlamaBatch::new(std::cmp::max(512, max_output_tokens as usize), 1);
let last_index: i32 = (tokens_list.len() - 1) as i32; let last_index: i32 = (tokens_list.len() - 1) as i32;
for (i, token) in (0_i32..).zip(tokens_list.into_iter()) { for (i, token) in (0_i32..).zip(tokens_list.into_iter()) {
// llama_decode will output logits only for the last token of the prompt // llama_decode will output logits only for the last token of the prompt
......
...@@ -40,7 +40,7 @@ async-trait = { workspace = true } ...@@ -40,7 +40,7 @@ async-trait = { workspace = true }
candle-core = { version = "0.8.0" } candle-core = { version = "0.8.0" }
either = { workspace = true } either = { workspace = true }
indexmap = { version = "2.6" } indexmap = { version = "2.6" }
mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", rev = "aaafc2ef" } mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", rev = "ebd50e35e" }
serde_json = { workspace = true } serde_json = { workspace = true }
tokio = { workspace = true } tokio = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
// limitations under the License. // limitations under the License.
use std::collections::HashMap; use std::collections::HashMap;
use std::{num::NonZero, path::Path, sync::Arc}; use std::{num::NonZero, sync::Arc};
use async_openai::types::FinishReason; use async_openai::types::FinishReason;
use async_stream::stream; use async_stream::stream;
...@@ -23,9 +23,10 @@ use either::Either; ...@@ -23,9 +23,10 @@ use either::Either;
use indexmap::IndexMap; use indexmap::IndexMap;
use mistralrs::{ use mistralrs::{
AutoDeviceMapParams, Constraint, DefaultSchedulerMethod, Device, DeviceMapSetting, AutoDeviceMapParams, Constraint, DefaultSchedulerMethod, Device, DeviceMapSetting,
GGUFLoaderBuilder, GGUFSpecificConfig, MemoryGpuConfig, MistralRs, MistralRsBuilder, GGUFLoaderBuilder, GGUFSpecificConfig, IsqType, MemoryGpuConfig, MistralRs, MistralRsBuilder,
ModelDType, NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, PagedAttentionConfig, ModelDType, NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, PagedAttentionConfig,
Request, RequestMessage, ResponseOk, SamplingParams, SchedulerConfig, StopTokens, TokenSource, Request, RequestMessage, ResponseOk, SamplingParams, SchedulerConfig, StopTokens, TokenSource,
VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig,
}; };
use tokio::sync::mpsc::channel; use tokio::sync::mpsc::channel;
...@@ -40,6 +41,7 @@ use dynamo_llm::protocols::openai::{ ...@@ -40,6 +41,7 @@ use dynamo_llm::protocols::openai::{
}; };
use dynamo_llm::engines::{EngineDispatcher, StreamingEngine}; use dynamo_llm::engines::{EngineDispatcher, StreamingEngine};
use dynamo_llm::LocalModel;
/// How many requests mistral will run at once in the paged attention scheduler. /// How many requests mistral will run at once in the paged attention scheduler.
/// It actually runs 1 fewer than this. /// It actually runs 1 fewer than this.
...@@ -51,8 +53,11 @@ const PAGED_ATTENTION_MAX_NUM_SEQS: usize = 10; ...@@ -51,8 +53,11 @@ const PAGED_ATTENTION_MAX_NUM_SEQS: usize = 10;
/// finish_reason=stop and no tokens for one of the requests. /// finish_reason=stop and no tokens for one of the requests.
const EXP_ENABLE_PAGED_ATTENTION: bool = false; const EXP_ENABLE_PAGED_ATTENTION: bool = false;
pub async fn make_engine(gguf_path: &Path) -> pipeline_error::Result<Arc<dyn StreamingEngine>> { /// Initial message we send to mistral.rs to warm it up. We may not need this.
let engine = MistralRsEngine::new(gguf_path).await?; const WARMUP_MESSAGE: &str = "This is a test message. Respond only with 'OK'.";
pub async fn make_engine(model: &LocalModel) -> pipeline_error::Result<Arc<dyn StreamingEngine>> {
let engine = MistralRsEngine::new(model).await?;
let engine: Arc<dyn StreamingEngine> = Arc::new(EngineDispatcher::new(engine)); let engine: Arc<dyn StreamingEngine> = Arc::new(EngineDispatcher::new(engine));
Ok(engine) Ok(engine)
} }
...@@ -74,7 +79,14 @@ struct MistralRsEngine { ...@@ -74,7 +79,14 @@ struct MistralRsEngine {
} }
impl MistralRsEngine { impl MistralRsEngine {
async fn new(model_path: &Path) -> pipeline_error::Result<Self> { async fn new(model: &LocalModel) -> pipeline_error::Result<Self> {
let model_path = model.path();
// Name some None's for clarity
let chat_template = None;
let tokenizer_json = None;
let no_kv_cache = false;
let jinja_explicit = None;
let display_name = model.display_name();
let loader = if model_path.is_file() { let loader = if model_path.is_file() {
// Load from a GGUF // Load from a GGUF
let Some(model_filename) = model_path.file_name() else { let Some(model_filename) = model_path.file_name() else {
...@@ -85,7 +97,7 @@ impl MistralRsEngine { ...@@ -85,7 +97,7 @@ impl MistralRsEngine {
}; };
GGUFLoaderBuilder::new( GGUFLoaderBuilder::new(
None, chat_template,
None, None,
model_dir.display().to_string(), model_dir.display().to_string(),
vec![model_filename.to_string_lossy().into_owned()], vec![model_filename.to_string_lossy().into_owned()],
...@@ -93,24 +105,35 @@ impl MistralRsEngine { ...@@ -93,24 +105,35 @@ impl MistralRsEngine {
prompt_chunksize: None, prompt_chunksize: None,
topology: None, topology: None,
}, },
no_kv_cache,
jinja_explicit,
) )
.build() .build()
} else if is_vision_model(display_name) {
let vlt = if is_gemma3(display_name) {
VisionLoaderType::Gemma3
} else if is_llama4(display_name) {
VisionLoaderType::Llama4
} else {
panic!("Unsupported vision model {display_name}");
};
VisionLoaderBuilder::new(
VisionSpecificConfig::default(),
chat_template,
tokenizer_json,
Some(model_path.display().to_string()),
jinja_explicit,
)
.build(vlt)
} else { } else {
// Load from a HF repo dir // Load from a HF repo dir
NormalLoaderBuilder::new( NormalLoaderBuilder::new(
NormalSpecificConfig { NormalSpecificConfig::default(),
use_flash_attn: false, chat_template,
prompt_chunksize: None, tokenizer_json,
topology: None,
organization: Default::default(),
write_uqff: None,
from_uqff: None,
imatrix: None,
calibration_file: None,
},
None,
None,
Some(model_path.display().to_string()), Some(model_path.display().to_string()),
no_kv_cache,
jinja_explicit,
) )
.build(None)? .build(None)?
}; };
...@@ -127,6 +150,21 @@ impl MistralRsEngine { ...@@ -127,6 +150,21 @@ impl MistralRsEngine {
} else { } else {
None None
}; };
let device_map_params = if is_vision_model(model.display_name()) {
AutoDeviceMapParams::Vision {
max_seq_len,
max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
max_image_shape: (0, 0),
max_num_images: 0,
}
} else {
AutoDeviceMapParams::Text {
max_seq_len,
max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
}
};
// Load, into a Pipeline // Load, into a Pipeline
let pipeline = loader.load_model_from_hf( let pipeline = loader.load_model_from_hf(
None, None,
...@@ -134,11 +172,12 @@ impl MistralRsEngine { ...@@ -134,11 +172,12 @@ impl MistralRsEngine {
&ModelDType::Auto, &ModelDType::Auto,
&best_device()?, &best_device()?,
false, false,
DeviceMapSetting::Auto(AutoDeviceMapParams::Text { DeviceMapSetting::Auto(device_map_params),
max_seq_len, if is_llama4(display_name) {
max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE, Some(IsqType::Q4K)
}), } else {
None, None
},
paged_attention_config, paged_attention_config,
)?; )?;
let scheduler = if cfg!(feature = "cuda") && EXP_ENABLE_PAGED_ATTENTION { let scheduler = if cfg!(feature = "cuda") && EXP_ENABLE_PAGED_ATTENTION {
...@@ -154,14 +193,21 @@ impl MistralRsEngine { ...@@ -154,14 +193,21 @@ impl MistralRsEngine {
config, config,
} }
} else { } else {
tracing::debug!("Using mistralrs DefaultScheduler");
SchedulerConfig::DefaultScheduler { SchedulerConfig::DefaultScheduler {
// Safety: unwrap trivially safe here // Safety: unwrap trivially safe here
method: DefaultSchedulerMethod::Fixed(NonZero::new(max_seq_len).unwrap()), method: DefaultSchedulerMethod::Fixed(NonZero::new(max_seq_len).unwrap()),
} }
}; };
// Create the MistralRs, which is a runner // Create the MistralRs, which is a runner
let builder = MistralRsBuilder::new(pipeline.clone(), scheduler).with_prefix_cache_n(16); let throughput_logging = false;
let search_embedding_model = None;
let builder = MistralRsBuilder::new(
pipeline.clone(),
scheduler,
throughput_logging,
search_embedding_model,
)
.with_prefix_cache_n(16);
let engine = MistralRsEngine { let engine = MistralRsEngine {
mistralrs: builder.build(), mistralrs: builder.build(),
}; };
...@@ -174,21 +220,27 @@ impl MistralRsEngine { ...@@ -174,21 +220,27 @@ impl MistralRsEngine {
let request_id = engine.mistralrs.next_request_id(); let request_id = engine.mistralrs.next_request_id();
let warmup_request = Request::Normal(NormalRequest { let warmup_request = Request::Normal(NormalRequest {
id: request_id, id: request_id,
messages: RequestMessage::Chat(vec![IndexMap::from([ messages: RequestMessage::Chat {
("role".to_string(), Either::Left("user".to_string())), messages: vec![IndexMap::from([
("content".to_string(), Either::Left("test".to_string())), ("role".to_string(), Either::Left("user".to_string())),
])]), (
"content".to_string(),
Either::Left(WARMUP_MESSAGE.to_string()),
),
])],
enable_thinking: Some(false),
},
sampling_params: SamplingParams::deterministic(), sampling_params: SamplingParams::deterministic(),
response: tx, response: tx,
return_logprobs: false, return_logprobs: false,
is_streaming: false, is_streaming: false,
constraint: Constraint::None, constraint: Constraint::None,
suffix: None, suffix: None,
adapters: None,
tools: None, tools: None,
tool_choice: None, tool_choice: None,
logits_processors: None, logits_processors: None,
return_raw_logits: false, return_raw_logits: false,
web_search_options: None,
}); });
// Send warmup request and consume response // Send warmup request and consume response
...@@ -285,18 +337,21 @@ impl ...@@ -285,18 +337,21 @@ impl
let request_id = self.mistralrs.next_request_id(); let request_id = self.mistralrs.next_request_id();
let mistralrs_request = Request::Normal(NormalRequest { let mistralrs_request = Request::Normal(NormalRequest {
id: request_id, id: request_id,
messages: RequestMessage::Chat(messages), messages: RequestMessage::Chat {
messages,
enable_thinking: None,
},
sampling_params, sampling_params,
response: tx, response: tx,
return_logprobs: request.inner.logprobs.unwrap_or_default(), return_logprobs: request.inner.logprobs.unwrap_or_default(),
is_streaming: true, is_streaming: true,
constraint: Constraint::None, constraint: Constraint::None,
suffix: None, suffix: None,
adapters: None,
tools: None, tools: None,
tool_choice: None, tool_choice: None,
logits_processors: None, logits_processors: None,
return_raw_logits: false, return_raw_logits: false,
web_search_options: None,
}); });
self.mistralrs.get_sender()?.send(mistralrs_request).await?; self.mistralrs.get_sender()?.send(mistralrs_request).await?;
...@@ -477,11 +532,11 @@ impl AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionRespon ...@@ -477,11 +532,11 @@ impl AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionRespon
is_streaming: true, is_streaming: true,
constraint: Constraint::None, constraint: Constraint::None,
suffix: None, suffix: None,
adapters: None,
tools: None, tools: None,
tool_choice: None, tool_choice: None,
logits_processors: None, logits_processors: None,
return_raw_logits: false, return_raw_logits: false,
web_search_options: None,
}); });
self.mistralrs.get_sender()?.send(mistralrs_request).await?; self.mistralrs.get_sender()?.send(mistralrs_request).await?;
...@@ -533,3 +588,15 @@ impl AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionRespon ...@@ -533,3 +588,15 @@ impl AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionRespon
Ok(ResponseStream::new(Box::pin(output), ctx)) Ok(ResponseStream::new(Box::pin(output), ctx))
} }
} }
/// Returns `true` when `s` names a Gemma 3 or a Llama 4 model, the two
/// families that are treated as vision models.
///
/// Gemma 3 is checked first; Llama 4 is only checked if that fails.
fn is_vision_model(s: &str) -> bool {
    if is_gemma3(s) {
        return true;
    }
    is_llama4(s)
}
/// Returns `true` when `s` contains the substring `gemma-3`, ignoring case.
///
/// The hyphen is required: a name spelled `gemma3` does not match.
fn is_gemma3(s: &str) -> bool {
    let lowered = s.to_lowercase();
    lowered.contains("gemma-3")
}
/// Returns `true` when `s` contains the substring `llama-4`, ignoring case.
///
/// The hyphen is required: a name spelled `llama4` does not match.
fn is_llama4(s: &str) -> bool {
    let lowered = s.to_lowercase();
    lowered.contains("llama-4")
}
...@@ -99,8 +99,8 @@ toktrie_hf_tokenizers = { version = "0.6.28" } ...@@ -99,8 +99,8 @@ toktrie_hf_tokenizers = { version = "0.6.28" }
bs62 = { version = "0.1" } bs62 = { version = "0.1" }
erased-serde = { version = "0.4" } erased-serde = { version = "0.4" }
itertools = { version = "0.14.0" } itertools = { version = "0.14.0" }
minijinja = { version = "2.3.1", features = ["loader"] } minijinja = { version = "2.10.2", features = ["loader"] }
minijinja-contrib = { version = "2.3.1", features = ["pycompat"] } minijinja-contrib = { version = "2.10.2", features = ["pycompat"] }
# GGUF # GGUF
ggus = "0.4.0" ggus = "0.4.0"
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Adapted from mistral.rs // Adapted from mistral.rs
// //
...@@ -45,6 +33,7 @@ use strum::EnumString; ...@@ -45,6 +33,7 @@ use strum::EnumString;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
pub(crate) use content::Content; pub(crate) use content::Content;
pub(crate) use gguf_metadata::ContentConfig; pub(crate) use gguf_metadata::ContentConfig;
pub use gguf_metadata::ModelConfigLike;
pub(crate) use gguf_tokenizer::convert_gguf_to_hf_tokenizer; pub(crate) use gguf_tokenizer::convert_gguf_to_hf_tokenizer;
use std::str::FromStr; use std::str::FromStr;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Adapted from mistral.rs // Adapted from mistral.rs
// //
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Adapted from mistral.rs // Adapted from mistral.rs
// //
...@@ -46,6 +34,16 @@ use tracing::warn; ...@@ -46,6 +34,16 @@ use tracing::warn;
use crate::gguf::Content; use crate::gguf::Content;
/// Read-only accessors for the basic dimensions of a model configuration.
///
/// Every method takes `&self` and returns a plain `usize`.
pub trait ModelConfigLike {
    /// Maximum sequence length.
    fn max_seq_len(&self) -> usize;

    /// Hidden size.
    fn hidden_size(&self) -> usize;

    /// Number of attention heads.
    fn num_attn_heads(&self) -> usize;

    /// Number of key/value heads.
    fn num_kv_heads(&self) -> usize;

    /// Number of layers.
    fn num_layers(&self) -> usize;

    /// Per-head dimension of the keys.
    fn k_head_dim(&self) -> usize;

    /// Per-head dimension of the values.
    fn v_head_dim(&self) -> usize;
}
#[allow(dead_code)] #[allow(dead_code)]
#[derive(Debug)] #[derive(Debug)]
pub struct ContentConfig { pub struct ContentConfig {
...@@ -87,28 +85,27 @@ impl From<&Content> for ContentConfig { ...@@ -87,28 +85,27 @@ impl From<&Content> for ContentConfig {
} }
} }
#[allow(dead_code)] impl ModelConfigLike for ContentConfig {
impl ContentConfig { fn max_seq_len(&self) -> usize {
pub fn max_seq_len(&self) -> usize {
self.max_seq_len self.max_seq_len
} }
pub fn hidden_size(&self) -> usize { fn hidden_size(&self) -> usize {
self.hidden_size self.hidden_size
} }
pub fn num_attn_heads(&self) -> usize { fn num_attn_heads(&self) -> usize {
self.num_attn_heads self.num_attn_heads
} }
pub fn num_kv_heads(&self) -> usize { fn num_kv_heads(&self) -> usize {
self.num_kv_heads self.num_kv_heads
} }
pub fn num_layers(&self) -> usize { fn num_layers(&self) -> usize {
self.num_layers self.num_layers
} }
pub fn k_head_dim(&self) -> usize { fn k_head_dim(&self) -> usize {
self.key_length self.key_length
.unwrap_or(self.hidden_size / self.num_attn_heads) .unwrap_or(self.hidden_size / self.num_attn_heads)
} }
pub fn v_head_dim(&self) -> usize { fn v_head_dim(&self) -> usize {
self.value_length self.value_length
.unwrap_or(self.hidden_size / self.num_attn_heads) .unwrap_or(self.hidden_size / self.num_attn_heads)
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Adapted from mistral.rs // Adapted from mistral.rs
// //
...@@ -496,7 +484,7 @@ mod tests { ...@@ -496,7 +484,7 @@ mod tests {
add_special_tokens: bool, add_special_tokens: bool,
) -> Result<String> { ) -> Result<String> {
let tokenized = tokenizer let tokenized = tokenizer
.encode(passage, add_special_tokens) .encode_fast(passage, add_special_tokens)
.map_err(anyhow::Error::msg)?; .map_err(anyhow::Error::msg)?;
// NOTE: The special tokens bool param meaning differs between encode() / decode(): // NOTE: The special tokens bool param meaning differs between encode() / decode():
...@@ -515,6 +503,7 @@ mod tests { ...@@ -515,6 +503,7 @@ mod tests {
#[test] #[test]
fn test_encode_decode_llama() -> Result<()> { fn test_encode_decode_llama() -> Result<()> {
use rand::rng;
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
let passage = get_test_passage(); let passage = get_test_passage();
...@@ -539,7 +528,7 @@ mod tests { ...@@ -539,7 +528,7 @@ mod tests {
#[allow(clippy::cast_possible_truncation)] #[allow(clippy::cast_possible_truncation)]
let mut tokens = (0..hf_tokenizer.get_vocab_size(false) as u32).collect::<Vec<_>>(); let mut tokens = (0..hf_tokenizer.get_vocab_size(false) as u32).collect::<Vec<_>>();
tokens.shuffle(&mut rand::rng()); tokens.shuffle(&mut rng());
// Without skipping special tokens // Without skipping special tokens
let hf_decoded = decode(&hf_tokenizer, &tokens, false)?; let hf_decoded = decode(&hf_tokenizer, &tokens, false)?;
...@@ -556,6 +545,7 @@ mod tests { ...@@ -556,6 +545,7 @@ mod tests {
#[test] #[test]
fn test_encode_decode_gpt2() -> Result<()> { fn test_encode_decode_gpt2() -> Result<()> {
use rand::rng;
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
let passage = get_test_passage(); let passage = get_test_passage();
...@@ -580,7 +570,7 @@ mod tests { ...@@ -580,7 +570,7 @@ mod tests {
#[allow(clippy::cast_possible_truncation)] #[allow(clippy::cast_possible_truncation)]
let mut tokens = (0..hf_tokenizer.get_vocab_size(false) as u32).collect::<Vec<_>>(); let mut tokens = (0..hf_tokenizer.get_vocab_size(false) as u32).collect::<Vec<_>>();
tokens.shuffle(&mut rand::rng()); tokens.shuffle(&mut rng());
// Without skipping special tokens // Without skipping special tokens
let hf_decoded = decode(&hf_tokenizer, &tokens, false)?; let hf_decoded = decode(&hf_tokenizer, &tokens, false)?;
......
...@@ -40,7 +40,7 @@ use serde::{Deserialize, Serialize}; ...@@ -40,7 +40,7 @@ use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer as HfTokenizer; use tokenizers::Tokenizer as HfTokenizer;
use url::Url; use url::Url;
use crate::gguf::{Content, ContentConfig}; use crate::gguf::{Content, ContentConfig, ModelConfigLike};
use crate::key_value_store::Versioned; use crate::key_value_store::Versioned;
use crate::protocols::TokenIdType; use crate::protocols::TokenIdType;
...@@ -386,11 +386,6 @@ impl ModelInfoType { ...@@ -386,11 +386,6 @@ impl ModelInfoType {
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
struct HFConfig { struct HFConfig {
bos_token_id: TokenIdType,
#[serde(with = "either::serde_untagged")]
eos_token_id: Either<TokenIdType, Vec<TokenIdType>>,
/// denotes the mixin to the flattened data model which can be present /// denotes the mixin to the flattened data model which can be present
/// in the config.json file /// in the config.json file
architectures: Vec<String>, architectures: Vec<String>,
...@@ -398,6 +393,16 @@ struct HFConfig { ...@@ -398,6 +393,16 @@ struct HFConfig {
/// general model type /// general model type
model_type: String, model_type: String,
text_config: Option<HFTextConfig>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct HFTextConfig {
bos_token_id: TokenIdType,
#[serde(with = "either::serde_untagged")]
eos_token_id: Either<TokenIdType, Vec<TokenIdType>>,
/// max sequence length /// max sequence length
max_position_embeddings: usize, max_position_embeddings: usize,
...@@ -412,9 +417,13 @@ struct HFConfig { ...@@ -412,9 +417,13 @@ struct HFConfig {
} }
impl HFConfig { impl HFConfig {
async fn from_json_file(file: &String) -> Result<Arc<dyn ModelInfo>> { async fn from_json_file(file: &str) -> Result<Arc<dyn ModelInfo>> {
let contents = std::fs::read_to_string(file)?; let contents = std::fs::read_to_string(file)?;
let config: Self = serde_json::from_str(&contents)?; let mut config: Self = serde_json::from_str(&contents)?;
if config.text_config.is_none() {
let text_config: HFTextConfig = serde_json::from_str(&contents)?;
config.text_config = Some(text_config);
}
Ok(Arc::new(config)) Ok(Arc::new(config))
} }
fn from_gguf(gguf_file: &Path) -> Result<Arc<dyn ModelInfo>> { fn from_gguf(gguf_file: &Path) -> Result<Arc<dyn ModelInfo>> {
...@@ -433,19 +442,21 @@ impl HFConfig { ...@@ -433,19 +442,21 @@ impl HFConfig {
let arch = content.arch().to_string(); let arch = content.arch().to_string();
Ok(Arc::new(HFConfig { Ok(Arc::new(HFConfig {
bos_token_id,
eos_token_id: Either::Left(eos_token_id),
architectures: vec![format!("{}ForCausalLM", capitalize(&arch))], architectures: vec![format!("{}ForCausalLM", capitalize(&arch))],
// "general.architecture" // "general.architecture"
model_type: arch, model_type: arch,
// "llama.context_length" text_config: Some(HFTextConfig {
max_position_embeddings: model_config_metadata.max_seq_len(), bos_token_id,
// "llama.block_count" eos_token_id: Either::Left(eos_token_id),
num_hidden_layers, // "llama.context_length"
// "llama.attention.head_count" max_position_embeddings: model_config_metadata.max_seq_len(),
num_attention_heads: model_config_metadata.num_attn_heads(), // "llama.block_count"
// "tokenizer.ggml.tokens".len() num_hidden_layers,
vocab_size, // "llama.attention.head_count"
num_attention_heads: model_config_metadata.num_attn_heads(),
// "tokenizer.ggml.tokens".len()
vocab_size,
}),
})) }))
} }
} }
...@@ -456,22 +467,22 @@ impl ModelInfo for HFConfig { ...@@ -456,22 +467,22 @@ impl ModelInfo for HFConfig {
} }
fn bos_token_id(&self) -> TokenIdType { fn bos_token_id(&self) -> TokenIdType {
self.bos_token_id self.text_config.as_ref().unwrap().bos_token_id
} }
fn eos_token_ids(&self) -> Vec<TokenIdType> { fn eos_token_ids(&self) -> Vec<TokenIdType> {
match &self.eos_token_id { match &self.text_config.as_ref().unwrap().eos_token_id {
Either::Left(eos_token_id) => vec![*eos_token_id], Either::Left(eos_token_id) => vec![*eos_token_id],
Either::Right(eos_token_ids) => eos_token_ids.clone(), Either::Right(eos_token_ids) => eos_token_ids.clone(),
} }
} }
fn max_position_embeddings(&self) -> usize { fn max_position_embeddings(&self) -> usize {
self.max_position_embeddings self.text_config.as_ref().unwrap().max_position_embeddings
} }
fn vocab_size(&self) -> usize { fn vocab_size(&self) -> usize {
self.vocab_size self.text_config.as_ref().unwrap().vocab_size
} }
} }
...@@ -504,3 +515,27 @@ fn capitalize(s: &str) -> String { ...@@ -504,3 +515,27 @@ fn capitalize(s: &str) -> String {
}) })
.collect() .collect()
} }
#[cfg(test)]
mod tests {
    use super::HFConfig;
    use std::path::Path;

    /// Resolves `relative` against the crate's manifest directory and returns
    /// the result as a `String`, ready to hand to `HFConfig::from_json_file`.
    fn manifest_relative(relative: &str) -> String {
        let full_path = Path::new(env!("CARGO_MANIFEST_DIR")).join(relative);
        full_path.display().to_string()
    }

    /// The mock Llama 3.1 `config.json` must load and report its BOS token id.
    #[tokio::test]
    pub async fn test_config_json_llama3() -> anyhow::Result<()> {
        let config_path =
            manifest_relative("tests/data/sample-models/mock-llama-3.1-8b-instruct/config.json");
        let config = HFConfig::from_json_file(&config_path).await?;
        assert_eq!(config.bos_token_id(), 128000);
        Ok(())
    }

    /// The Llama 4 `config.json`, which nests its text settings under
    /// `text_config`, must load and report its BOS token id.
    #[tokio::test]
    pub async fn test_config_json_llama4() -> anyhow::Result<()> {
        let config_path =
            manifest_relative("tests/data/sample-models/Llama-4-Scout-17B-16E-Instruct/config.json");
        let config = HFConfig::from_json_file(&config_path).await?;
        assert_eq!(config.bos_token_id(), 200000);
        Ok(())
    }
}
{
"architectures": [
"Llama4ForConditionalGeneration"
],
"boi_token_index": 200080,
"eoi_token_index": 200081,
"image_token_index": 200092,
"model_type": "llama4",
"text_config": {
"_attn_implementation_autoset": true,
"attention_bias": false,
"attention_chunk_size": 8192,
"attention_dropout": 0.0,
"bos_token_id": 200000,
"eos_token_id": [
200001,
200007,
200008
],
"for_llm_compressor": false,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 5120,
"initializer_range": 0.02,
"interleave_moe_layer_step": 1,
"intermediate_size": 8192,
"intermediate_size_mlp": 16384,
"max_position_embeddings": 10485760,
"model_type": "llama4_text",
"no_rope_layers": [],
"num_attention_heads": 40,
"num_experts_per_tok": 1,
"num_hidden_layers": 48,
"num_key_value_heads": 8,
"num_local_experts": 16,
"output_router_logits": false,
"pad_token_id": 200018,
"rms_norm_eps": 1e-05,
"rope_scaling": {
"factor": 16.0,
"high_freq_factor": 1.0,
"low_freq_factor": 1.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3"
},
"rope_theta": 500000.0,
"router_aux_loss_coef": 0.001,
"router_jitter_noise": 0.0,
"torch_dtype": "bfloat16",
"use_cache": true,
"use_qk_norm": true,
"vocab_size": 202048
},
"torch_dtype": "bfloat16",
"transformers_version": "4.51.0.dev0",
"vision_config": {
"_attn_implementation_autoset": true,
"attention_dropout": 0.0,
"hidden_act": "gelu",
"hidden_size": 1408,
"image_size": 336,
"initializer_range": 0.02,
"intermediate_size": 5632,
"model_type": "llama4_vision_model",
"multi_modal_projector_bias": false,
"norm_eps": 1e-05,
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 34,
"patch_size": 14,
"pixel_shuffle_ratio": 0.5,
"projector_dropout": 0.0,
"projector_input_dim": 4096,
"projector_output_dim": 4096,
"rope_theta": 10000,
"vision_feature_layer": -1,
"vision_feature_select_strategy": "default",
"vision_output_dim": 4096
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment