Unverified Commit ceaeba3e authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat: Qwen3, Gemma3 and Llama4 support (#1002)

. New mistralrs and llamacpp version
. mistralrs: Handle Gemma 3 and Llama 4 as vision models
. Update the dynamo-run docs to use Qwen 3
. Our pre-processor now supports Llama 4's newer multi-modal `config.json`
. Upgrade minijinja to handle Qwen 3's prompt template

For Llama 4 we'll need to limit the max seq len. vllm says:
> To serve at least one request with the models's max seq len (10485760), (240.00 GiB KV cache is needed,...

I was able to run Llama 4 with llamacpp and a quantized GGUF, with Dynamo doing the pre-processing.
parent 57402e70
......@@ -672,7 +672,7 @@ dependencies = [
[[package]]
name = "candle-core"
version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=496a8d2b#496a8d2bf8f88e3be4ea27332a209d66e8b404f4"
source = "git+https://github.com/EricLBuehler/candle.git?rev=cb2d8f5#cb2d8f59b4fbe1f69ef998233c40a87033a05b0d"
dependencies = [
"byteorder",
"candle-kernels 0.8.0",
......@@ -722,7 +722,7 @@ dependencies = [
[[package]]
name = "candle-kernels"
version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=496a8d2b#496a8d2bf8f88e3be4ea27332a209d66e8b404f4"
source = "git+https://github.com/EricLBuehler/candle.git?rev=cb2d8f5#cb2d8f59b4fbe1f69ef998233c40a87033a05b0d"
dependencies = [
"bindgen_cuda 0.1.5",
]
......@@ -739,7 +739,7 @@ dependencies = [
[[package]]
name = "candle-metal-kernels"
version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=496a8d2b#496a8d2bf8f88e3be4ea27332a209d66e8b404f4"
source = "git+https://github.com/EricLBuehler/candle.git?rev=cb2d8f5#cb2d8f59b4fbe1f69ef998233c40a87033a05b0d"
dependencies = [
"metal",
"once_cell",
......@@ -750,7 +750,7 @@ dependencies = [
[[package]]
name = "candle-nn"
version = "0.8.0"
source = "git+https://github.com/EricLBuehler/candle.git?rev=496a8d2b#496a8d2bf8f88e3be4ea27332a209d66e8b404f4"
source = "git+https://github.com/EricLBuehler/candle.git?rev=cb2d8f5#cb2d8f59b4fbe1f69ef998233c40a87033a05b0d"
dependencies = [
"candle-core 0.8.0",
"candle-metal-kernels",
......@@ -950,7 +950,7 @@ dependencies = [
"encode_unicode",
"libc",
"once_cell",
"unicode-width",
"unicode-width 0.2.0",
"windows-sys 0.59.0",
]
......@@ -1118,6 +1118,29 @@ dependencies = [
"typenum",
]
[[package]]
name = "cssparser"
version = "0.34.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7c66d1cd8ed61bf80b38432613a7a2f09401ab8d0501110655f8b341484a3e3"
dependencies = [
"cssparser-macros",
"dtoa-short",
"itoa",
"phf",
"smallvec",
]
[[package]]
name = "cssparser-macros"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13b588ba4ac1a99f7f2964d24b3d896ddc6bf847ee3855dbd4366f058cfcd331"
dependencies = [
"quote",
"syn 2.0.100",
]
[[package]]
name = "csv"
version = "1.3.1"
......@@ -1364,9 +1387,9 @@ dependencies = [
[[package]]
name = "derivre"
version = "0.3.1"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a3c2606b3ffc46f91fd62d954d55659ba9fb391bb673311b70f50daf9c15e49"
checksum = "4a605f30e6a1460a323cc4de7bc62dea81df1d9d67eb92194d3a983a8a9601c4"
dependencies = [
"anyhow",
"bytemuck",
......@@ -1456,6 +1479,21 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aac81fa3e28d21450aa4d2ac065992ba96a1d7303efbce51a95f4fd175b67562"
[[package]]
name = "dtoa"
version = "1.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6add3b8cff394282be81f3fc1a0605db594ed69890078ca6e2cab1c408bcf04"
[[package]]
name = "dtoa-short"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd1511a7b6a56299bd043a9c167a6d2bfb37bf84a6dfceaba651168adfb43c87"
dependencies = [
"dtoa",
]
[[package]]
name = "dyn-clone"
version = "1.0.19"
......@@ -1718,6 +1756,12 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "ego-tree"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2972feb8dffe7bc8c5463b1dacda1b0dfbed3710e50f977d965429692d74cd8"
[[package]]
name = "either"
version = "1.15.0"
......@@ -2066,6 +2110,16 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7"
[[package]]
name = "futf"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df420e2e84819663797d1ec6544b13c5be84629e7bb00dc960d6917db2987843"
dependencies = [
"mac",
"new_debug_unreachable",
]
[[package]]
name = "futures"
version = "0.3.31"
......@@ -2161,6 +2215,15 @@ dependencies = [
"slab",
]
[[package]]
name = "fxhash"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c"
dependencies = [
"byteorder",
]
[[package]]
name = "galil-seiferas"
version = "0.1.5"
......@@ -2418,6 +2481,15 @@ dependencies = [
"version_check",
]
[[package]]
name = "getopts"
version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14dbbfd5c71d70241ecf9e6f13737f7b5ce823821063188d7e46c41d371eebd5"
dependencies = [
"unicode-width 0.1.14",
]
[[package]]
name = "getrandom"
version = "0.2.16"
......@@ -2599,6 +2671,42 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "html2text"
version = "0.14.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1637acec3b965bab873352189d887b12c87b4f8d7571f4d185e796be5654ad8"
dependencies = [
"html5ever 0.31.0",
"tendril",
"thiserror 2.0.12",
"unicode-width 0.2.0",
]
[[package]]
name = "html5ever"
version = "0.29.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b7410cae13cbc75623c98ac4cbfd1f0bedddf3227afc24f370cf0f50a44a11c"
dependencies = [
"log",
"mac",
"markup5ever 0.14.1",
"match_token",
]
[[package]]
name = "html5ever"
version = "0.31.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "953cbbe631aae7fc0a112702ad5d3aaf09da38beaf45ea84610d6e1c358f569c"
dependencies = [
"log",
"mac",
"markup5ever 0.16.1",
"match_token",
]
[[package]]
name = "http"
version = "0.2.0"
......@@ -3011,7 +3119,7 @@ dependencies = [
"number_prefix",
"portable-atomic",
"rayon",
"unicode-width",
"unicode-width 0.2.0",
"web-time",
]
......@@ -3316,9 +3424,9 @@ checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856"
[[package]]
name = "llama-cpp-2"
version = "0.1.102"
version = "0.1.103"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a419bb48efa0f8389a82301f1f64e2874568a3fbf6f62f8ddab5324382b82768"
checksum = "401c708926326b1ee410735dc348882c73deeab78f1f89ff2c9caf148356feb4"
dependencies = [
"enumflags2",
"llama-cpp-sys-2",
......@@ -3329,9 +3437,9 @@ dependencies = [
[[package]]
name = "llama-cpp-sys-2"
version = "0.1.102"
version = "0.1.103"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0522f9894e22dd988dd2e34222bda7acba53a0dcce744ca6d8ddce905ba33a4e"
checksum = "8b4ae3037b7d9b9fab9fd7905aeb04e214acb300599fa1ee698d6f759ee530f9"
dependencies = [
"bindgen",
"cc",
......@@ -3343,8 +3451,9 @@ dependencies = [
[[package]]
name = "llguidance"
version = "0.7.0"
source = "git+https://github.com/EricLBuehler/llguidance?rev=8d71957#8d7195774a209038ddfbb0d1a5348ed17b387386"
version = "0.7.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20c50686e6724ff55e184447dbc775b94f28f64e5061882b75c84b12b9c1d613"
dependencies = [
"anyhow",
"derivre",
......@@ -3352,7 +3461,7 @@ dependencies = [
"regex-syntax 0.8.5",
"serde",
"serde_json",
"toktrie 0.7.0",
"toktrie 0.7.19",
]
[[package]]
......@@ -3410,6 +3519,12 @@ dependencies = [
"vob",
]
[[package]]
name = "mac"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c41e0c4fef86961ac6d6f8a82609f55f31b05e4fce149ac5710e439df7619ba4"
[[package]]
name = "macro_rules_attribute"
version = "0.2.0"
......@@ -3435,6 +3550,42 @@ dependencies = [
"libc",
]
[[package]]
name = "markup5ever"
version = "0.14.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7a7213d12e1864c0f002f52c2923d4556935a43dec5e71355c2760e0f6e7a18"
dependencies = [
"log",
"phf",
"phf_codegen",
"string_cache",
"string_cache_codegen",
"tendril",
]
[[package]]
name = "markup5ever"
version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0a8096766c229e8c88a3900c9b44b7e06aa7f7343cc229158c3e58ef8f9973a"
dependencies = [
"log",
"tendril",
"web_atoms",
]
[[package]]
name = "match_token"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88a9689d8d44bf9964484516275f5cd4c9b59457a6940c1d5d0ecbb94510a36b"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "matchers"
version = "0.1.0"
......@@ -3549,9 +3700,9 @@ dependencies = [
[[package]]
name = "minijinja"
version = "2.9.0"
version = "2.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "98642a6dfca91122779a307b77cd07a4aa951fbe32232aaf5bad9febc66be754"
checksum = "dd72e8b4e42274540edabec853f607c015c73436159b06c39c7af85a20433155"
dependencies = [
"memo-map",
"self_cell",
......@@ -3561,9 +3712,9 @@ dependencies = [
[[package]]
name = "minijinja-contrib"
version = "2.9.0"
version = "2.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd4a0f6e171c7bb92ed2caf446fa3de4e26561cea1d97085103e9cb42359dd59"
checksum = "457f85f9c4c5b17d11fcf9bbe7c0dbba64843c5ee040005956f1a510b6679fe2"
dependencies = [
"minijinja",
"serde",
......@@ -3641,8 +3792,8 @@ dependencies = [
[[package]]
name = "mistralrs"
version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=aaafc2ef#aaafc2efc6305c1a79eee632b177d76586df1646"
version = "0.5.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=ebd50e35e#ebd50e35efacb082462b757fb448d4ab54473775"
dependencies = [
"anyhow",
"candle-core 0.8.0",
......@@ -3662,8 +3813,8 @@ dependencies = [
[[package]]
name = "mistralrs-core"
version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=aaafc2ef#aaafc2efc6305c1a79eee632b177d76586df1646"
version = "0.5.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=ebd50e35e#ebd50e35efacb082462b757fb448d4ab54473775"
dependencies = [
"akin",
"anyhow",
......@@ -3688,6 +3839,7 @@ dependencies = [
"galil-seiferas",
"half",
"hf-hub",
"html2text",
"image",
"indexmap 2.9.0",
"indicatif",
......@@ -3703,6 +3855,7 @@ dependencies = [
"mistralrs-vision",
"objc",
"once_cell",
"ordered-float",
"radix_trie",
"rand 0.9.1",
"rand_isaac",
......@@ -3713,6 +3866,7 @@ dependencies = [
"rustc-hash 2.1.1",
"safetensors",
"schemars",
"scraper",
"serde",
"serde-big-array",
"serde_json",
......@@ -3724,11 +3878,12 @@ dependencies = [
"tokenizers",
"tokio",
"tokio-rayon",
"toktrie_hf_tokenizers 0.7.0",
"toktrie_hf_tokenizers 0.7.19",
"toml",
"tqdm",
"tracing",
"tracing-subscriber",
"urlencoding",
"uuid 1.16.0",
"variantly",
"vob",
......@@ -3736,8 +3891,8 @@ dependencies = [
[[package]]
name = "mistralrs-paged-attn"
version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=aaafc2ef#aaafc2efc6305c1a79eee632b177d76586df1646"
version = "0.5.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=ebd50e35e#ebd50e35efacb082462b757fb448d4ab54473775"
dependencies = [
"anyhow",
"bindgen_cuda 0.1.6",
......@@ -3751,8 +3906,8 @@ dependencies = [
[[package]]
name = "mistralrs-quant"
version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=aaafc2ef#aaafc2efc6305c1a79eee632b177d76586df1646"
version = "0.5.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=ebd50e35e#ebd50e35efacb082462b757fb448d4ab54473775"
dependencies = [
"bindgen_cuda 0.1.5",
"byteorder",
......@@ -3778,11 +3933,12 @@ dependencies = [
[[package]]
name = "mistralrs-vision"
version = "0.4.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=aaafc2ef#aaafc2efc6305c1a79eee632b177d76586df1646"
version = "0.5.0"
source = "git+https://github.com/EricLBuehler/mistral.rs.git?rev=ebd50e35e#ebd50e35efacb082462b757fb448d4ab54473775"
dependencies = [
"candle-core 0.8.0",
"image",
"rayon",
]
[[package]]
......@@ -3863,6 +4019,12 @@ dependencies = [
"winapi 0.3.9",
]
[[package]]
name = "new_debug_unreachable"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086"
[[package]]
name = "nibble_vec"
version = "0.1.0"
......@@ -4154,6 +4316,15 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
[[package]]
name = "ordered-float"
version = "5.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2c1f9f56e534ac6a9b8a4600bdf0f530fb393b5f393e7b4d03489c3cf0c3f01"
dependencies = [
"num-traits",
]
[[package]]
name = "overload"
version = "0.1.1"
......@@ -4178,7 +4349,7 @@ checksum = "b915f831b85d984193fdc3d3611505871dc139b2534530fa01c1a6a6707b6723"
dependencies = [
"bytecount",
"fnv",
"unicode-width",
"unicode-width 0.2.0",
]
[[package]]
......@@ -4309,6 +4480,58 @@ dependencies = [
"indexmap 2.9.0",
]
[[package]]
name = "phf"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078"
dependencies = [
"phf_macros",
"phf_shared",
]
[[package]]
name = "phf_codegen"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a"
dependencies = [
"phf_generator",
"phf_shared",
]
[[package]]
name = "phf_generator"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d"
dependencies = [
"phf_shared",
"rand 0.8.5",
]
[[package]]
name = "phf_macros"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f84ac04429c13a7ff43785d75ad27569f2951ce0ffd30a3321230db2fc727216"
dependencies = [
"phf_generator",
"phf_shared",
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "phf_shared"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5"
dependencies = [
"siphasher",
]
[[package]]
name = "pin-project"
version = "1.1.10"
......@@ -4400,6 +4623,12 @@ dependencies = [
"zerocopy",
]
[[package]]
name = "precomputed-hash"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c"
[[package]]
name = "prettyplease"
version = "0.2.32"
......@@ -5423,6 +5652,21 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "scraper"
version = "0.23.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "527e65d9d888567588db4c12da1087598d0f6f8b346cc2c5abc91f05fc2dffe2"
dependencies = [
"cssparser",
"ego-tree",
"getopts",
"html5ever 0.29.1",
"precomputed-hash",
"selectors",
"tendril",
]
[[package]]
name = "secrecy"
version = "0.10.3"
......@@ -5469,6 +5713,25 @@ dependencies = [
"libc",
]
[[package]]
name = "selectors"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd568a4c9bb598e291a08244a5c1f5a8a6650bee243b5b0f8dbb3d9cc1d87fe8"
dependencies = [
"bitflags 2.9.0",
"cssparser",
"derive_more",
"fxhash",
"log",
"new_debug_unreachable",
"phf",
"phf_codegen",
"precomputed-hash",
"servo_arc",
"smallvec",
]
[[package]]
name = "self_cell"
version = "1.2.0"
......@@ -5639,6 +5902,15 @@ dependencies = [
"unsafe-libyaml",
]
[[package]]
name = "servo_arc"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae65c4249478a2647db249fb43e23cec56a2c8974a427e7bd8cb5a1d0964921a"
dependencies = [
"stable_deref_trait",
]
[[package]]
name = "sha2"
version = "0.10.8"
......@@ -5735,6 +6007,12 @@ version = "2.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa"
[[package]]
name = "siphasher"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d"
[[package]]
name = "slab"
version = "0.4.9"
......@@ -5811,6 +6089,31 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
[[package]]
name = "string_cache"
version = "0.8.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf776ba3fa74f83bf4b63c3dcbbf82173db2632ed8452cb2d891d33f459de70f"
dependencies = [
"new_debug_unreachable",
"parking_lot",
"phf_shared",
"precomputed-hash",
"serde",
]
[[package]]
name = "string_cache_codegen"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c711928715f1fe0fe509c53b43e993a9a557babc2d0a3567d0a3006f1ac931a0"
dependencies = [
"phf_generator",
"phf_shared",
"proc-macro2",
"quote",
]
[[package]]
name = "strsim"
version = "0.10.0"
......@@ -6049,6 +6352,17 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "tendril"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d24a120c5fc464a3458240ee02c299ebcb9d67b5249c8848b09d639dca8d7bb0"
dependencies = [
"futf",
"mac",
"utf-8",
]
[[package]]
name = "terminal_size"
version = "0.4.2"
......@@ -6320,8 +6634,9 @@ dependencies = [
[[package]]
name = "toktrie"
version = "0.7.0"
source = "git+https://github.com/EricLBuehler/llguidance?rev=8d71957#8d7195774a209038ddfbb0d1a5348ed17b387386"
version = "0.7.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69200397314cd578aa0623a5161342c44877ac2ebcea1e9d20e2f65120e81e8d"
dependencies = [
"anyhow",
"bytemuck",
......@@ -6346,15 +6661,16 @@ dependencies = [
[[package]]
name = "toktrie_hf_tokenizers"
version = "0.7.0"
source = "git+https://github.com/EricLBuehler/llguidance?rev=8d71957#8d7195774a209038ddfbb0d1a5348ed17b387386"
version = "0.7.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4c03d846b7e3cd072e955583da31732076eed2f7be6881d7605fd4c3e5a8247"
dependencies = [
"anyhow",
"log",
"serde",
"serde_json",
"tokenizers",
"toktrie 0.7.0",
"toktrie 0.7.19",
]
[[package]]
......@@ -6694,6 +7010,12 @@ version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
[[package]]
name = "unicode-width"
version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af"
[[package]]
name = "unicode-width"
version = "0.2.0"
......@@ -6755,6 +7077,18 @@ dependencies = [
"serde",
]
[[package]]
name = "urlencoding"
version = "2.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]]
name = "utf16_iter"
version = "1.0.5"
......@@ -7011,6 +7345,18 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "web_atoms"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b9c5f0bc545ea3b20b423e33b9b457764de0b3730cd957f6c6aa6c301785f6e"
dependencies = [
"phf",
"phf_codegen",
"string_cache",
"string_cache_codegen",
]
[[package]]
name = "webpki-roots"
version = "0.26.8"
......
......@@ -26,7 +26,7 @@ Usage:
dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn://<path> [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin]
```
Example: `dynamo run Qwen/Qwen2.5-3B-Instruct`.
Example: `dynamo run Qwen/Qwen3-0.6B`
Set environment variable `DYN_LOG` to adjust logging level, e.g. `export DYN_LOG=debug`. It has the same syntax as `RUST_LOG`, ask AI for details.
......@@ -38,9 +38,9 @@ The vllm and sglang engines require [etcd](https://etcd.io/) and [nats](https://
### Use model from Hugging Face
This will automatically download Qwen2.5 3B from Hugging Face (6 GiB download) and start it in interactive text mode:
This will automatically download Qwen3 4B from Hugging Face (16 GiB download) and start it in interactive text mode:
```
dynamo run out=vllm Qwen/Qwen2.5-3B-Instruct
dynamo run out=vllm Qwen/Qwen3-4B
```
General format for HF download:
......@@ -65,12 +65,12 @@ curl -L -o Llama-3.2-3B-Instruct-Q4_K_M.gguf "https://huggingface.co/bartowski/L
#### Run model from local file
**Text interface**
```
dynamo run out=vllm Llama-3.2-3B-Instruct-Q4_K_M.gguf # or path to a Hugging Face repo checkout instead of the GGUF
dynamo run Llama-3.2-3B-Instruct-Q4_K_M.gguf # or path to a Hugging Face repo checkout instead of the GGUF
```
**HTTP interface**
```
dynamo run in=http out=vllm Llama-3.2-3B-Instruct-Q4_K_M.gguf
dynamo run in=http out=mistralrs Llama-3.2-3B-Instruct-Q4_K_M.gguf
```
**List the models**
......@@ -94,7 +94,7 @@ You will need [etcd](https://etcd.io/) and [nats](https://nats.io) with jetstrea
OpenAI compliant HTTP server, optional pre-processing, worker discovery.
```
dynamo run in=http out=dyn://llama3B_pool
dynamo-run in=http out=dyn://llama3B_pool
```
**Node 2:**
......@@ -109,11 +109,11 @@ This will use etcd to auto-discover the model and NATS to talk to it. You can ru
The `llama3B_pool` name is purely symbolic, pick anything as long as it matches the other node.
Run `dynamo run --help` for more options.
Run `dynamo-run --help` for more options.
## Full usage details
`dynamo-run` is what `dynamo run` executes. It is an example of what you can build in Rust with the `dynamo-llm` and `dynamo-runtime`. The following guide demonstrates how you can build from source with all the features.
`dynamo-run` is what `dynamo run` executes. It is also an example of what you can build in Rust with the `dynamo-llm` and `dynamo-runtime` crates. The following guide demonstrates how you can build from source with all the features.
### Setup
......@@ -181,13 +181,13 @@ Build with `--release` for a smaller binary and better performance, but longer b
[mistral.rs](https://github.com/EricLBuehler/mistral.rs) is a pure Rust engine that is fast to run, fast to load, supports GGUF as well as safetensors, and runs well on CPU as well as GPU. For those reasons it is the default engine.
```
dynamo-run Qwen/Qwen2.5-3B-Instruct
dynamo-run Qwen/Qwen3-4B
```
is equivalent to
```
dynamo-run in=text out=mistralrs Qwen/Qwen2.5-3B-Instruct
dynamo-run in=text out=mistralrs Qwen/Qwen3-4B
```
If you have multiple GPUs, mistral.rs does automatic tensor parallelism. You do not need to pass any extra flags to dynamo-run to enable it.
......@@ -204,7 +204,14 @@ cargo build --features llamacpp[,cuda|metal|vulkan] -p dynamo-run
dynamo-run out=llamacpp ~/llms/Llama-3.2-3B-Instruct-Q6_K.gguf
```
llamacpp is best for single-GPU inference with a quantized GGUF model file.
Note that in some cases we are unable to extract the tokenizer from the GGUF, and so a Hugging Face checkout of a matching model must also be passed. Dynamo will use the weights from the GGUF and the pre-processor (`tokenizer.json`, etc) from the `--model-config`:
```
dynamo-run out=llamacpp ~/llms/gemma-3-1b-it-q4_0.gguf --model-config ~/llms/gemma-3-1b-it
dynamo-run out=llamacpp ~/llms/Llama-4-Scout-17B-16E-Instruct-UD-IQ1_S.gguf --model-config ~/llms/Llama-4-Scout-17B-16E-Instruct
```
If you have multiple GPUs, llama.cpp does automatic tensor parallelism. You do not need to pass any extra flags to dynamo-run to enable it.
### sglang
......@@ -233,9 +240,13 @@ To pass extra arguments to the sglang engine see *Extra engine arguments* below.
**Multi-GPU**
Pass `--tensor-parallel-size <NUM-GPUS>` to `dynamo-run`. To specify which GPU to start from pass `--base-gpu-id <num>`.
Pass `--tensor-parallel-size <NUM-GPUS>` to `dynamo-run`.
```
dynamo-run out=sglang ~/llms/Llama-4-Scout-17B-16E-Instruct/ --tensor-parallel-size 8
```
For example on a shared eight GPU machine where GPUs 0-3 are already in use:
To specify which GPU to start from pass `--base-gpu-id <num>`, for example on a shared eight GPU machine where GPUs 0-3 are already in use:
```
dynamo-run out=sglang <model> --tensor-parallel-size 4 --base-gpu-id 4
```
......
......@@ -16,10 +16,6 @@ pub use dynamo_llm::request_template::RequestTemplate;
pub use opt::{Input, Output};
mod subprocess;
/// When `in=text` the user doesn't need to know the model name, and doesn't need to provide it on
/// the command line. Hence it's optional, and defaults to this.
const INVISIBLE_MODEL_NAME: &str = "dynamo-run";
const CHILD_STOP_TIMEOUT: Duration = Duration::from_secs(2);
/// How we identify a python string endpoint
......@@ -64,15 +60,10 @@ pub async fn run(
_ => {
match &maybe_path {
Some(model_path) => {
let maybe_model_name = if in_opt == Input::Text {
Some(INVISIBLE_MODEL_NAME.to_string())
} else {
flags.model_name.clone()
};
LocalModel::prepare(
model_path.to_str().context("Invalid UTF-8 in model path")?,
flags.model_config.as_deref(),
maybe_model_name,
flags.model_name.clone(),
)
.await?
}
......@@ -121,7 +112,7 @@ pub async fn run(
}
#[cfg(feature = "mistralrs")]
Output::MistralRs => EngineConfig::StaticFull {
engine: dynamo_engine_mistralrs::make_engine(local_model.path()).await?,
engine: dynamo_engine_mistralrs::make_engine(&local_model).await?,
model: Box::new(local_model),
},
Output::SgLang => {
......
......@@ -2511,9 +2511,9 @@ dependencies = [
[[package]]
name = "minijinja"
version = "2.9.0"
version = "2.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "98642a6dfca91122779a307b77cd07a4aa951fbe32232aaf5bad9febc66be754"
checksum = "dd72e8b4e42274540edabec853f607c015c73436159b06c39c7af85a20433155"
dependencies = [
"memo-map",
"self_cell",
......@@ -2522,9 +2522,9 @@ dependencies = [
[[package]]
name = "minijinja-contrib"
version = "2.9.0"
version = "2.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd4a0f6e171c7bb92ed2caf446fa3de4e26561cea1d97085103e9cb42359dd59"
checksum = "457f85f9c4c5b17d11fcf9bbe7c0dbba64843c5ee040005956f1a510b6679fe2"
dependencies = [
"minijinja",
"serde",
......
......@@ -38,4 +38,4 @@ async-stream = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
llama-cpp-2 = { version = "0.1.102" }
llama-cpp-2 = { version = "0.1.103" }
......@@ -228,9 +228,8 @@ fn run_request(
limit,
);
// create a llama_batch with size 512
// we use this object to submit token data for decoding
let mut batch = LlamaBatch::new(512, 1);
let mut batch = LlamaBatch::new(std::cmp::max(512, max_output_tokens as usize), 1);
let last_index: i32 = (tokens_list.len() - 1) as i32;
for (i, token) in (0_i32..).zip(tokens_list.into_iter()) {
// llama_decode will output logits only for the last token of the prompt
......
......@@ -40,7 +40,7 @@ async-trait = { workspace = true }
candle-core = { version = "0.8.0" }
either = { workspace = true }
indexmap = { version = "2.6" }
mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", rev = "aaafc2ef" }
mistralrs = { git = "https://github.com/EricLBuehler/mistral.rs.git", rev = "ebd50e35e" }
serde_json = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
......@@ -14,7 +14,7 @@
// limitations under the License.
use std::collections::HashMap;
use std::{num::NonZero, path::Path, sync::Arc};
use std::{num::NonZero, sync::Arc};
use async_openai::types::FinishReason;
use async_stream::stream;
......@@ -23,9 +23,10 @@ use either::Either;
use indexmap::IndexMap;
use mistralrs::{
AutoDeviceMapParams, Constraint, DefaultSchedulerMethod, Device, DeviceMapSetting,
GGUFLoaderBuilder, GGUFSpecificConfig, MemoryGpuConfig, MistralRs, MistralRsBuilder,
GGUFLoaderBuilder, GGUFSpecificConfig, IsqType, MemoryGpuConfig, MistralRs, MistralRsBuilder,
ModelDType, NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, PagedAttentionConfig,
Request, RequestMessage, ResponseOk, SamplingParams, SchedulerConfig, StopTokens, TokenSource,
VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig,
};
use tokio::sync::mpsc::channel;
......@@ -40,6 +41,7 @@ use dynamo_llm::protocols::openai::{
};
use dynamo_llm::engines::{EngineDispatcher, StreamingEngine};
use dynamo_llm::LocalModel;
/// How many requests mistral will run at once in the paged attention scheduler.
/// It actually runs 1 fewer than this.
......@@ -51,8 +53,11 @@ const PAGED_ATTENTION_MAX_NUM_SEQS: usize = 10;
/// finish_reason=stop and no tokens for one of the requests.
const EXP_ENABLE_PAGED_ATTENTION: bool = false;
pub async fn make_engine(gguf_path: &Path) -> pipeline_error::Result<Arc<dyn StreamingEngine>> {
let engine = MistralRsEngine::new(gguf_path).await?;
/// Initial message we send to mistral.rs to warm it up. We may not need this.
const WARMUP_MESSAGE: &str = "This is a test message. Respond only with 'OK'.";
pub async fn make_engine(model: &LocalModel) -> pipeline_error::Result<Arc<dyn StreamingEngine>> {
let engine = MistralRsEngine::new(model).await?;
let engine: Arc<dyn StreamingEngine> = Arc::new(EngineDispatcher::new(engine));
Ok(engine)
}
......@@ -74,7 +79,14 @@ struct MistralRsEngine {
}
impl MistralRsEngine {
async fn new(model_path: &Path) -> pipeline_error::Result<Self> {
async fn new(model: &LocalModel) -> pipeline_error::Result<Self> {
let model_path = model.path();
// Name some None's for clarity
let chat_template = None;
let tokenizer_json = None;
let no_kv_cache = false;
let jinja_explicit = None;
let display_name = model.display_name();
let loader = if model_path.is_file() {
// Load from a GGUF
let Some(model_filename) = model_path.file_name() else {
......@@ -85,7 +97,7 @@ impl MistralRsEngine {
};
GGUFLoaderBuilder::new(
None,
chat_template,
None,
model_dir.display().to_string(),
vec![model_filename.to_string_lossy().into_owned()],
......@@ -93,24 +105,35 @@ impl MistralRsEngine {
prompt_chunksize: None,
topology: None,
},
no_kv_cache,
jinja_explicit,
)
.build()
} else if is_vision_model(display_name) {
let vlt = if is_gemma3(display_name) {
VisionLoaderType::Gemma3
} else if is_llama4(display_name) {
VisionLoaderType::Llama4
} else {
panic!("Unsupported vision model {display_name}");
};
VisionLoaderBuilder::new(
VisionSpecificConfig::default(),
chat_template,
tokenizer_json,
Some(model_path.display().to_string()),
jinja_explicit,
)
.build(vlt)
} else {
// Load from a HF repo dir
NormalLoaderBuilder::new(
NormalSpecificConfig {
use_flash_attn: false,
prompt_chunksize: None,
topology: None,
organization: Default::default(),
write_uqff: None,
from_uqff: None,
imatrix: None,
calibration_file: None,
},
None,
None,
NormalSpecificConfig::default(),
chat_template,
tokenizer_json,
Some(model_path.display().to_string()),
no_kv_cache,
jinja_explicit,
)
.build(None)?
};
......@@ -127,6 +150,21 @@ impl MistralRsEngine {
} else {
None
};
let device_map_params = if is_vision_model(model.display_name()) {
AutoDeviceMapParams::Vision {
max_seq_len,
max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
max_image_shape: (0, 0),
max_num_images: 0,
}
} else {
AutoDeviceMapParams::Text {
max_seq_len,
max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
}
};
// Load, into a Pipeline
let pipeline = loader.load_model_from_hf(
None,
......@@ -134,11 +172,12 @@ impl MistralRsEngine {
&ModelDType::Auto,
&best_device()?,
false,
DeviceMapSetting::Auto(AutoDeviceMapParams::Text {
max_seq_len,
max_batch_size: AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE,
}),
None,
DeviceMapSetting::Auto(device_map_params),
if is_llama4(display_name) {
Some(IsqType::Q4K)
} else {
None
},
paged_attention_config,
)?;
let scheduler = if cfg!(feature = "cuda") && EXP_ENABLE_PAGED_ATTENTION {
......@@ -154,14 +193,21 @@ impl MistralRsEngine {
config,
}
} else {
tracing::debug!("Using mistralrs DefaultScheduler");
SchedulerConfig::DefaultScheduler {
// Safety: unwrap trivially safe here
method: DefaultSchedulerMethod::Fixed(NonZero::new(max_seq_len).unwrap()),
}
};
// Create the MistralRs, which is a runner
let builder = MistralRsBuilder::new(pipeline.clone(), scheduler).with_prefix_cache_n(16);
let throughput_logging = false;
let search_embedding_model = None;
let builder = MistralRsBuilder::new(
pipeline.clone(),
scheduler,
throughput_logging,
search_embedding_model,
)
.with_prefix_cache_n(16);
let engine = MistralRsEngine {
mistralrs: builder.build(),
};
......@@ -174,21 +220,27 @@ impl MistralRsEngine {
let request_id = engine.mistralrs.next_request_id();
let warmup_request = Request::Normal(NormalRequest {
id: request_id,
messages: RequestMessage::Chat(vec![IndexMap::from([
("role".to_string(), Either::Left("user".to_string())),
("content".to_string(), Either::Left("test".to_string())),
])]),
messages: RequestMessage::Chat {
messages: vec![IndexMap::from([
("role".to_string(), Either::Left("user".to_string())),
(
"content".to_string(),
Either::Left(WARMUP_MESSAGE.to_string()),
),
])],
enable_thinking: Some(false),
},
sampling_params: SamplingParams::deterministic(),
response: tx,
return_logprobs: false,
is_streaming: false,
constraint: Constraint::None,
suffix: None,
adapters: None,
tools: None,
tool_choice: None,
logits_processors: None,
return_raw_logits: false,
web_search_options: None,
});
// Send warmup request and consume response
......@@ -285,18 +337,21 @@ impl
let request_id = self.mistralrs.next_request_id();
let mistralrs_request = Request::Normal(NormalRequest {
id: request_id,
messages: RequestMessage::Chat(messages),
messages: RequestMessage::Chat {
messages,
enable_thinking: None,
},
sampling_params,
response: tx,
return_logprobs: request.inner.logprobs.unwrap_or_default(),
is_streaming: true,
constraint: Constraint::None,
suffix: None,
adapters: None,
tools: None,
tool_choice: None,
logits_processors: None,
return_raw_logits: false,
web_search_options: None,
});
self.mistralrs.get_sender()?.send(mistralrs_request).await?;
......@@ -477,11 +532,11 @@ impl AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionRespon
is_streaming: true,
constraint: Constraint::None,
suffix: None,
adapters: None,
tools: None,
tool_choice: None,
logits_processors: None,
return_raw_logits: false,
web_search_options: None,
});
self.mistralrs.get_sender()?.send(mistralrs_request).await?;
......@@ -533,3 +588,15 @@ impl AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionRespon
Ok(ResponseStream::new(Box::pin(output), ctx))
}
}
/// Returns true when the model name belongs to a family that must be loaded
/// through the vision (multi-modal) loader rather than the text loader.
fn is_vision_model(s: &str) -> bool {
    // Currently Gemma 3 and Llama 4 are the only vision families handled.
    if is_gemma3(s) {
        return true;
    }
    is_llama4(s)
}
/// Case-insensitive check: does the model name identify a Gemma 3 model
/// (i.e. contain the substring "gemma-3")?
fn is_gemma3(s: &str) -> bool {
    let lowered = s.to_lowercase();
    lowered.as_str().contains("gemma-3")
}
/// Case-insensitive check: does the model name identify a Llama 4 model
/// (i.e. contain the substring "llama-4")?
fn is_llama4(s: &str) -> bool {
    let lowered = s.to_lowercase();
    lowered.as_str().contains("llama-4")
}
......@@ -99,8 +99,8 @@ toktrie_hf_tokenizers = { version = "0.6.28" }
bs62 = { version = "0.1" }
erased-serde = { version = "0.4" }
itertools = { version = "0.14.0" }
minijinja = { version = "2.3.1", features = ["loader"] }
minijinja-contrib = { version = "2.3.1", features = ["pycompat"] }
minijinja = { version = "2.10.2", features = ["loader"] }
minijinja-contrib = { version = "2.10.2", features = ["pycompat"] }
# GGUF
ggus = "0.4.0"
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Adapted from mistral.rs
//
......@@ -45,6 +33,7 @@ use strum::EnumString;
use anyhow::{Context, Result};
pub(crate) use content::Content;
pub(crate) use gguf_metadata::ContentConfig;
pub use gguf_metadata::ModelConfigLike;
pub(crate) use gguf_tokenizer::convert_gguf_to_hf_tokenizer;
use std::str::FromStr;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Adapted from mistral.rs
//
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Adapted from mistral.rs
//
......@@ -46,6 +34,16 @@ use tracing::warn;
use crate::gguf::Content;
/// Read-only view of a model's core architecture hyper-parameters.
/// Implemented by `ContentConfig` (parsed from GGUF metadata); presumably
/// intended to abstract over other config sources as well — confirm at
/// the other impl sites.
pub trait ModelConfigLike {
    /// Maximum sequence length (context window) the model supports.
    fn max_seq_len(&self) -> usize;
    /// Number of transformer layers / blocks.
    fn num_layers(&self) -> usize;
    /// Hidden (embedding) dimension of the model.
    fn hidden_size(&self) -> usize;
    /// Number of key/value heads.
    fn num_kv_heads(&self) -> usize;
    /// Number of attention (query) heads.
    fn num_attn_heads(&self) -> usize;
    /// Per-head key dimension; `ContentConfig` falls back to
    /// `hidden_size / num_attn_heads` when no explicit value is present.
    fn k_head_dim(&self) -> usize;
    /// Per-head value dimension; same fallback convention as `k_head_dim`.
    fn v_head_dim(&self) -> usize;
}
#[allow(dead_code)]
#[derive(Debug)]
pub struct ContentConfig {
......@@ -87,28 +85,27 @@ impl From<&Content> for ContentConfig {
}
}
#[allow(dead_code)]
impl ContentConfig {
pub fn max_seq_len(&self) -> usize {
impl ModelConfigLike for ContentConfig {
fn max_seq_len(&self) -> usize {
self.max_seq_len
}
pub fn hidden_size(&self) -> usize {
fn hidden_size(&self) -> usize {
self.hidden_size
}
pub fn num_attn_heads(&self) -> usize {
fn num_attn_heads(&self) -> usize {
self.num_attn_heads
}
pub fn num_kv_heads(&self) -> usize {
fn num_kv_heads(&self) -> usize {
self.num_kv_heads
}
pub fn num_layers(&self) -> usize {
fn num_layers(&self) -> usize {
self.num_layers
}
pub fn k_head_dim(&self) -> usize {
fn k_head_dim(&self) -> usize {
self.key_length
.unwrap_or(self.hidden_size / self.num_attn_heads)
}
pub fn v_head_dim(&self) -> usize {
fn v_head_dim(&self) -> usize {
self.value_length
.unwrap_or(self.hidden_size / self.num_attn_heads)
}
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Adapted from mistral.rs
//
......@@ -496,7 +484,7 @@ mod tests {
add_special_tokens: bool,
) -> Result<String> {
let tokenized = tokenizer
.encode(passage, add_special_tokens)
.encode_fast(passage, add_special_tokens)
.map_err(anyhow::Error::msg)?;
// NOTE: The special tokens bool param meaning differs between encode() / decode():
......@@ -515,6 +503,7 @@ mod tests {
#[test]
fn test_encode_decode_llama() -> Result<()> {
use rand::rng;
use rand::seq::SliceRandom;
let passage = get_test_passage();
......@@ -539,7 +528,7 @@ mod tests {
#[allow(clippy::cast_possible_truncation)]
let mut tokens = (0..hf_tokenizer.get_vocab_size(false) as u32).collect::<Vec<_>>();
tokens.shuffle(&mut rand::rng());
tokens.shuffle(&mut rng());
// Without skipping special tokens
let hf_decoded = decode(&hf_tokenizer, &tokens, false)?;
......@@ -556,6 +545,7 @@ mod tests {
#[test]
fn test_encode_decode_gpt2() -> Result<()> {
use rand::rng;
use rand::seq::SliceRandom;
let passage = get_test_passage();
......@@ -580,7 +570,7 @@ mod tests {
#[allow(clippy::cast_possible_truncation)]
let mut tokens = (0..hf_tokenizer.get_vocab_size(false) as u32).collect::<Vec<_>>();
tokens.shuffle(&mut rand::rng());
tokens.shuffle(&mut rng());
// Without skipping special tokens
let hf_decoded = decode(&hf_tokenizer, &tokens, false)?;
......
......@@ -40,7 +40,7 @@ use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer as HfTokenizer;
use url::Url;
use crate::gguf::{Content, ContentConfig};
use crate::gguf::{Content, ContentConfig, ModelConfigLike};
use crate::key_value_store::Versioned;
use crate::protocols::TokenIdType;
......@@ -386,11 +386,6 @@ impl ModelInfoType {
#[derive(Debug, Clone, Serialize, Deserialize)]
struct HFConfig {
bos_token_id: TokenIdType,
#[serde(with = "either::serde_untagged")]
eos_token_id: Either<TokenIdType, Vec<TokenIdType>>,
/// denotes the mixin to the flattened data model which can be present
/// in the config.json file
architectures: Vec<String>,
......@@ -398,6 +393,16 @@ struct HFConfig {
/// general model type
model_type: String,
text_config: Option<HFTextConfig>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct HFTextConfig {
bos_token_id: TokenIdType,
#[serde(with = "either::serde_untagged")]
eos_token_id: Either<TokenIdType, Vec<TokenIdType>>,
/// max sequence length
max_position_embeddings: usize,
......@@ -412,9 +417,13 @@ struct HFConfig {
}
impl HFConfig {
async fn from_json_file(file: &String) -> Result<Arc<dyn ModelInfo>> {
async fn from_json_file(file: &str) -> Result<Arc<dyn ModelInfo>> {
let contents = std::fs::read_to_string(file)?;
let config: Self = serde_json::from_str(&contents)?;
let mut config: Self = serde_json::from_str(&contents)?;
if config.text_config.is_none() {
let text_config: HFTextConfig = serde_json::from_str(&contents)?;
config.text_config = Some(text_config);
}
Ok(Arc::new(config))
}
fn from_gguf(gguf_file: &Path) -> Result<Arc<dyn ModelInfo>> {
......@@ -433,19 +442,21 @@ impl HFConfig {
let arch = content.arch().to_string();
Ok(Arc::new(HFConfig {
bos_token_id,
eos_token_id: Either::Left(eos_token_id),
architectures: vec![format!("{}ForCausalLM", capitalize(&arch))],
// "general.architecture"
model_type: arch,
// "llama.context_length"
max_position_embeddings: model_config_metadata.max_seq_len(),
// "llama.block_count"
num_hidden_layers,
// "llama.attention.head_count"
num_attention_heads: model_config_metadata.num_attn_heads(),
// "tokenizer.ggml.tokens".len()
vocab_size,
text_config: Some(HFTextConfig {
bos_token_id,
eos_token_id: Either::Left(eos_token_id),
// "llama.context_length"
max_position_embeddings: model_config_metadata.max_seq_len(),
// "llama.block_count"
num_hidden_layers,
// "llama.attention.head_count"
num_attention_heads: model_config_metadata.num_attn_heads(),
// "tokenizer.ggml.tokens".len()
vocab_size,
}),
}))
}
}
......@@ -456,22 +467,22 @@ impl ModelInfo for HFConfig {
}
fn bos_token_id(&self) -> TokenIdType {
self.bos_token_id
self.text_config.as_ref().unwrap().bos_token_id
}
fn eos_token_ids(&self) -> Vec<TokenIdType> {
match &self.eos_token_id {
match &self.text_config.as_ref().unwrap().eos_token_id {
Either::Left(eos_token_id) => vec![*eos_token_id],
Either::Right(eos_token_ids) => eos_token_ids.clone(),
}
}
fn max_position_embeddings(&self) -> usize {
self.max_position_embeddings
self.text_config.as_ref().unwrap().max_position_embeddings
}
fn vocab_size(&self) -> usize {
self.vocab_size
self.text_config.as_ref().unwrap().vocab_size
}
}
......@@ -504,3 +515,27 @@ fn capitalize(s: &str) -> String {
})
.collect()
}
#[cfg(test)]
mod tests {
    use super::HFConfig;
    use std::path::Path;

    /// A "flat" config.json (Llama 3 style): the token/text fields live at
    /// the top level, so `from_json_file` back-fills `text_config` from the
    /// root object.
    #[tokio::test]
    pub async fn test_config_json_llama3() -> anyhow::Result<()> {
        let config_file = Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/data/sample-models/mock-llama-3.1-8b-instruct/config.json");
        let config = HFConfig::from_json_file(&config_file.display().to_string()).await?;
        assert_eq!(config.bos_token_id(), 128000);
        Ok(())
    }

    /// A multi-modal config.json (Llama 4 style): token/text fields are
    /// nested under a `text_config` object, which should be parsed directly.
    #[tokio::test]
    pub async fn test_config_json_llama4() -> anyhow::Result<()> {
        let config_file = Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests/data/sample-models/Llama-4-Scout-17B-16E-Instruct/config.json");
        let config = HFConfig::from_json_file(&config_file.display().to_string()).await?;
        assert_eq!(config.bos_token_id(), 200000);
        Ok(())
    }
}
{
"architectures": [
"Llama4ForConditionalGeneration"
],
"boi_token_index": 200080,
"eoi_token_index": 200081,
"image_token_index": 200092,
"model_type": "llama4",
"text_config": {
"_attn_implementation_autoset": true,
"attention_bias": false,
"attention_chunk_size": 8192,
"attention_dropout": 0.0,
"bos_token_id": 200000,
"eos_token_id": [
200001,
200007,
200008
],
"for_llm_compressor": false,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 5120,
"initializer_range": 0.02,
"interleave_moe_layer_step": 1,
"intermediate_size": 8192,
"intermediate_size_mlp": 16384,
"max_position_embeddings": 10485760,
"model_type": "llama4_text",
"no_rope_layers": [],
"num_attention_heads": 40,
"num_experts_per_tok": 1,
"num_hidden_layers": 48,
"num_key_value_heads": 8,
"num_local_experts": 16,
"output_router_logits": false,
"pad_token_id": 200018,
"rms_norm_eps": 1e-05,
"rope_scaling": {
"factor": 16.0,
"high_freq_factor": 1.0,
"low_freq_factor": 1.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3"
},
"rope_theta": 500000.0,
"router_aux_loss_coef": 0.001,
"router_jitter_noise": 0.0,
"torch_dtype": "bfloat16",
"use_cache": true,
"use_qk_norm": true,
"vocab_size": 202048
},
"torch_dtype": "bfloat16",
"transformers_version": "4.51.0.dev0",
"vision_config": {
"_attn_implementation_autoset": true,
"attention_dropout": 0.0,
"hidden_act": "gelu",
"hidden_size": 1408,
"image_size": 336,
"initializer_range": 0.02,
"intermediate_size": 5632,
"model_type": "llama4_vision_model",
"multi_modal_projector_bias": false,
"norm_eps": 1e-05,
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 34,
"patch_size": 14,
"pixel_shuffle_ratio": 0.5,
"projector_dropout": 0.0,
"projector_input_dim": 4096,
"projector_output_dim": 4096,
"rope_theta": 10000,
"vision_feature_layer": -1,
"vision_feature_select_strategy": "default",
"vision_output_dim": 4096
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment