Unverified Commit d18ed5cf authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Mllama flash version (#2585)

* Working loading state.

* Preprocessing.

* Working state ? (Broke idefics1 temporarily).

* Cleaner condition.

* Fix idefics.

* Updating config, removing TODO

* Mllama

* Ugrade transformers 4.45

* Flashing mllama.

* Starting to get there.

* Working state.

* Integrations tests for mllama (cutting to 10 tokens because there seems'
to be instability after (meaning size of the batch matters.

* Updating model link.

* Earlier assert.

* Fix vlm ?

* remove log.

* Force ignore all images but last.

* Default dtype bfloat16.

* Update integration test after switch to bf16.

* Remove dead code.

* Removed dead code.

* Upgrade the flake to latest transformers/tokenizers

* Move to hf tgi-nix

* Upgrade to 0.5.0
parent 584b4d7a
......@@ -133,7 +133,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -172,7 +172,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -183,7 +183,7 @@ checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -205,9 +205,9 @@ dependencies = [
[[package]]
name = "autocfg"
version = "1.3.0"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
[[package]]
name = "av1-grain"
......@@ -316,12 +316,12 @@ dependencies = [
[[package]]
name = "axum"
version = "0.7.6"
version = "0.7.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f43644eed690f5374f1af436ecd6aea01cd201f6fbdf0178adaf6907afb2cec"
checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae"
dependencies = [
"async-trait",
"axum-core 0.4.4",
"axum-core 0.4.5",
"bytes",
"futures-util",
"http 1.1.0",
......@@ -367,9 +367,9 @@ dependencies = [
[[package]]
name = "axum-core"
version = "0.4.4"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e6b8ba012a258d63c9adfa28b9ddcf66149da6f986c5b5452e629d5ee64bf00"
checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199"
dependencies = [
"async-trait",
"bytes",
......@@ -392,7 +392,7 @@ version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bdad298231394729042d1f155b93f9fdf0b5ee1aea0b62404c4d7341f7d8fe08"
dependencies = [
"axum 0.7.6",
"axum 0.7.7",
"futures-core",
"futures-util",
"http 1.1.0",
......@@ -456,7 +456,7 @@ dependencies = [
"regex",
"rustc-hash",
"shlex",
"syn 2.0.77",
"syn 2.0.79",
"which",
]
......@@ -605,9 +605,9 @@ dependencies = [
[[package]]
name = "cc"
version = "1.1.21"
version = "1.1.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07b1695e2c7e8fc85310cde85aeaab7e3097f593c91d209d3f9df76c928100f0"
checksum = "9540e661f81799159abee814118cc139a2004b3a3aa3ea37724a1b66530b90e0"
dependencies = [
"jobserver",
"libc",
......@@ -704,7 +704,7 @@ dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -971,7 +971,7 @@ dependencies = [
"proc-macro2",
"quote",
"scratch",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -988,7 +988,7 @@ checksum = "98532a60dedaebc4848cb2cba5023337cc9ea3af16a5b062633fabfd9f18fb60"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -1012,7 +1012,7 @@ dependencies = [
"proc-macro2",
"quote",
"strsim",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -1023,7 +1023,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806"
dependencies = [
"darling_core",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -1053,7 +1053,7 @@ dependencies = [
"darling",
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -1063,7 +1063,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4abae7035bf79b9877b779505d8cf3749285b80c43941eda66604841889451dc"
dependencies = [
"derive_builder_core",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -1192,9 +1192,9 @@ checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6"
[[package]]
name = "fdeflate"
version = "0.3.4"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f9bfee30e4dedf0ab8b422f03af778d9612b63f502710fc500a334ebe2de645"
checksum = "d8090f921a24b04994d9929e204f50b498a33ea6ba559ffaa05e04f7ee7fb5ab"
dependencies = [
"simd-adler32",
]
......@@ -1207,9 +1207,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
[[package]]
name = "flate2"
version = "1.0.33"
version = "1.0.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "324a1be68054ef05ad64b861cc9eaf1d623d2d8cb25b4bf2cb9cdd902b4bf253"
checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0"
dependencies = [
"crc32fast",
"miniz_oxide 0.8.0",
......@@ -1338,7 +1338,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -1864,7 +1864,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b23a0c8dfe501baac4adf6ebbfa6eddf8f0c07f56b058cc1288017e32397846c"
dependencies = [
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -1884,7 +1884,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -2270,7 +2270,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08"
dependencies = [
"adler",
"simd-adler32",
]
[[package]]
......@@ -2280,6 +2279,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1"
dependencies = [
"adler2",
"simd-adler32",
]
[[package]]
......@@ -2319,7 +2319,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -2519,7 +2519,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -2599,9 +2599,12 @@ dependencies = [
[[package]]
name = "once_cell"
version = "1.19.0"
version = "1.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
checksum = "82881c4be219ab5faaf2ad5e5e5ecdff8c66bd7402ca3160975c93b24961afd1"
dependencies = [
"portable-atomic",
]
[[package]]
name = "onig"
......@@ -2654,7 +2657,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -2808,7 +2811,7 @@ dependencies = [
"glob",
"once_cell",
"opentelemetry 0.21.0",
"ordered-float 4.2.2",
"ordered-float 4.3.0",
"percent-encoding",
"rand",
"thiserror",
......@@ -2828,7 +2831,7 @@ dependencies = [
"lazy_static",
"once_cell",
"opentelemetry 0.23.0",
"ordered-float 4.2.2",
"ordered-float 4.3.0",
"percent-encoding",
"rand",
"thiserror",
......@@ -2851,9 +2854,9 @@ dependencies = [
[[package]]
name = "ordered-float"
version = "4.2.2"
version = "4.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a91171844676f8c7990ce64959210cd2eaef32c2612c50f9fae9f8aaa6065a6"
checksum = "44d501f1a72f71d3c063a6bbc8f7271fa73aa09fe5d6283b6571e2ed176a2537"
dependencies = [
"num-traits",
]
......@@ -2937,7 +2940,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -2988,22 +2991,22 @@ dependencies = [
[[package]]
name = "png"
version = "0.17.13"
version = "0.17.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06e4b0d3d1312775e782c86c91a111aa1f910cbb65e1337f9975b5f9a554b5e1"
checksum = "52f9d46a34a05a6a57566bc2bfae066ef07585a6e3fa30fbbdff5936380623f0"
dependencies = [
"bitflags 1.3.2",
"crc32fast",
"fdeflate",
"flate2",
"miniz_oxide 0.7.4",
"miniz_oxide 0.8.0",
]
[[package]]
name = "portable-atomic"
version = "1.8.0"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d30538d42559de6b034bc76fd6dd4c38961b1ee5c6c56e3808c50128fdbc22ce"
checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2"
[[package]]
name = "powerfmt"
......@@ -3027,7 +3030,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "479cf940fbbb3426c32c5d5176f62ad57549a0bb84773423ba8be9d089f5faba"
dependencies = [
"proc-macro2",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -3079,7 +3082,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd"
dependencies = [
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -3119,7 +3122,7 @@ dependencies = [
"prost 0.12.6",
"prost-types",
"regex",
"syn 2.0.77",
"syn 2.0.79",
"tempfile",
]
......@@ -3146,7 +3149,7 @@ dependencies = [
"itertools 0.12.1",
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -3205,7 +3208,7 @@ dependencies = [
"proc-macro2",
"pyo3-macros-backend",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -3218,7 +3221,7 @@ dependencies = [
"proc-macro2",
"pyo3-build-config",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -3402,9 +3405,9 @@ dependencies = [
[[package]]
name = "redox_syscall"
version = "0.5.5"
version = "0.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62871f2d65009c0256aed1b9cfeeb8ac272833c404e13d53d400cd0dad7a2ac0"
checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f"
dependencies = [
"bitflags 2.6.0",
]
......@@ -3422,14 +3425,14 @@ dependencies = [
[[package]]
name = "regex"
version = "1.10.6"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619"
checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8"
dependencies = [
"aho-corasick",
"memchr",
"regex-automata 0.4.7",
"regex-syntax 0.8.4",
"regex-automata 0.4.8",
"regex-syntax 0.8.5",
]
[[package]]
......@@ -3443,13 +3446,13 @@ dependencies = [
[[package]]
name = "regex-automata"
version = "0.4.7"
version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3"
dependencies = [
"aho-corasick",
"memchr",
"regex-syntax 0.8.4",
"regex-syntax 0.8.5",
]
[[package]]
......@@ -3460,9 +3463,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
[[package]]
name = "regex-syntax"
version = "0.8.4"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "reqwest"
......@@ -3563,7 +3566,7 @@ dependencies = [
"proc-macro2",
"quote",
"rust-embed-utils",
"syn 2.0.77",
"syn 2.0.79",
"walkdir",
]
......@@ -3686,9 +3689,9 @@ dependencies = [
[[package]]
name = "rustls-pki-types"
version = "1.8.0"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0"
checksum = "0e696e35370c65c9c541198af4543ccd580cf17fc25d8e05c5a242b202488c55"
[[package]]
name = "rustls-webpki"
......@@ -3813,7 +3816,7 @@ checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -3840,9 +3843,9 @@ dependencies = [
[[package]]
name = "serde_spanned"
version = "0.6.7"
version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb5b1b31579f3811bf615c144393417496f152e12ac8b7663bf664f4a815306d"
checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1"
dependencies = [
"serde",
]
......@@ -4028,7 +4031,7 @@ dependencies = [
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -4050,9 +4053,9 @@ dependencies = [
[[package]]
name = "syn"
version = "2.0.77"
version = "2.0.79"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed"
checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590"
dependencies = [
"proc-macro2",
"quote",
......@@ -4152,9 +4155,9 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
[[package]]
name = "tempfile"
version = "3.12.0"
version = "3.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64"
checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b"
dependencies = [
"cfg-if",
"fastrand",
......@@ -4259,7 +4262,7 @@ version = "2.3.1-dev0"
dependencies = [
"async-stream",
"async-trait",
"axum 0.7.6",
"axum 0.7.7",
"axum-tracing-opentelemetry",
"base64 0.22.1",
"clap 4.5.18",
......@@ -4308,7 +4311,7 @@ version = "2.3.1-dev0"
dependencies = [
"async-stream",
"async-trait",
"axum 0.7.6",
"axum 0.7.7",
"axum-tracing-opentelemetry",
"base64 0.22.1",
"clap 4.5.18",
......@@ -4357,7 +4360,7 @@ version = "2.3.1-dev0"
dependencies = [
"async-stream",
"async-trait",
"axum 0.7.6",
"axum 0.7.7",
"axum-tracing-opentelemetry",
"base64 0.22.1",
"clap 4.5.18",
......@@ -4428,7 +4431,7 @@ checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -4533,7 +4536,7 @@ dependencies = [
"rayon",
"rayon-cond",
"regex",
"regex-syntax 0.8.4",
"regex-syntax 0.8.5",
"serde",
"serde_json",
"spm_precompiled",
......@@ -4566,7 +4569,7 @@ dependencies = [
"rayon",
"rayon-cond",
"regex",
"regex-syntax 0.8.4",
"regex-syntax 0.8.5",
"serde",
"serde_json",
"spm_precompiled",
......@@ -4612,7 +4615,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -4771,7 +4774,7 @@ dependencies = [
"proc-macro2",
"prost-build",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -4858,7 +4861,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -5151,7 +5154,7 @@ dependencies = [
"proc-macro2",
"quote",
"regex",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -5160,7 +5163,7 @@ version = "6.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b39868d43c011961e04b41623e050aedf2cc93652562ff7935ce0f819aaf2da"
dependencies = [
"axum 0.7.6",
"axum 0.7.7",
"mime_guess",
"regex",
"rust-embed",
......@@ -5189,7 +5192,7 @@ checksum = "ee1cd046f83ea2c4e920d6ee9f7c3537ef928d75dce5d84a87c2c5d6b3999a3a"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......@@ -5290,7 +5293,7 @@ dependencies = [
"once_cell",
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
"wasm-bindgen-shared",
]
......@@ -5324,7 +5327,7 @@ checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
......@@ -5668,9 +5671,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "winnow"
version = "0.6.19"
version = "0.6.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c52ac009d615e79296318c1bcce2d422aaca15ad08515e344feeda07df67a587"
checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b"
dependencies = [
"memchr",
]
......@@ -5703,7 +5706,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.77",
"syn 2.0.79",
]
[[package]]
......
......@@ -28,11 +28,17 @@ class ToolCall(BaseModel):
function: dict
class Chunk(BaseModel):
type: str
text: Optional[str] = None
image_url: Any = None
class Message(BaseModel):
# Role of the message sender
role: str
# Content of the message
content: Optional[str] = None
content: Optional[Union[str, List[Chunk]]] = None
# Optional name of the message sender
name: Optional[str] = None
# Tool calls associated with the chat completion
......
......@@ -35,6 +35,7 @@ Text Generation Inference enables serving optimized models on specific hardware
- [Gpt Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
- [Gptj](https://huggingface.co/EleutherAI/gpt-j-6b)
- [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b) (Multimodal)
- [Mllama](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) (Multimodal)
If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:
......
......@@ -497,11 +497,11 @@
"systems": "systems_7"
},
"locked": {
"lastModified": 1710146030,
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
"lastModified": 1726560853,
"narHash": "sha256-X6rJYSESBVr3hBoH0WbKE5KvhPU5bloyZ2L4K60/fPQ=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
"rev": "c1dfcf08411b08f6b8615f7d8971a2bfa81d5e8a",
"type": "github"
},
"original": {
......@@ -718,11 +718,11 @@
},
"nixpkgs_6": {
"locked": {
"lastModified": 1724915739,
"narHash": "sha256-7PgRge4mn5akFvhPwefuaLQGbF5BnmxlwZJEf7CgbrE=",
"lastModified": 1727675176,
"narHash": "sha256-xIjBFMYldWvj+g8ahxMPofsj+OqxvKJN6YylNHQ7gn4=",
"owner": "nixos",
"repo": "nixpkgs",
"rev": "85be051bb60943d3328d91aaf2598798f87e19af",
"rev": "a6d0207fea9212d28cd3d487efe6bc699663b93a",
"type": "github"
},
"original": {
......@@ -853,11 +853,11 @@
]
},
"locked": {
"lastModified": 1726626348,
"narHash": "sha256-sYV7e1B1yLcxo8/h+/hTwzZYmaju2oObNiy5iRI0C30=",
"lastModified": 1727836133,
"narHash": "sha256-JE0zciM5IGWvK8J/pE2VldNBf7oyMH5WrU8tZArefbg=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "6fd52ad8bd88f39efb2c999cc971921c2fb9f3a2",
"rev": "02321540b0c8000b36889b1b974d1fec585b25a4",
"type": "github"
},
"original": {
......@@ -978,17 +978,16 @@
"nixpkgs": "nixpkgs_6"
},
"locked": {
"lastModified": 1727710820,
"narHash": "sha256-BuSafCxoFQhkp7lnvNtpquxSK43rIbnouL2HypIUC+o=",
"owner": "danieldk",
"repo": "tgi-nix",
"rev": "4f4dc4b85dd856fd7904e8e3e486a2ff153584a2",
"lastModified": 1727859277,
"narHash": "sha256-AsrPuQqhg8x5RRR3aX0vvDDRQb+HREq2wGxXOpZnWus=",
"owner": "huggingface",
"repo": "text-generation-inference-nix",
"rev": "14196ab62f31d005f46207f7a251f82a81d0a09f",
"type": "github"
},
"original": {
"owner": "danieldk",
"ref": "moe-kernels-0.5.0",
"repo": "tgi-nix",
"owner": "huggingface",
"repo": "text-generation-inference-nix",
"type": "github"
}
}
......
......@@ -5,7 +5,7 @@
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
};
nix-filter.url = "github:numtide/nix-filter";
tgi-nix.url = "github:danieldk/tgi-nix/moe-kernels-0.5.0";
tgi-nix.url = "github:huggingface/text-generation-inference-nix";
nixpkgs.follows = "tgi-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils";
rust-overlay = {
......
[
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a bustling city, a chicken named Cluck",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727773835,
"id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.3.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60
}
},
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a world where even chickens could dream big,",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727773835,
"id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.3.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60
}
},
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a world where even chickens could dream big,",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727773835,
"id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.3.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60
}
},
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a world where even chickens could dream big,",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727773835,
"id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.3.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60
}
}
]
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a bustling city, a chicken named Cluck",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727556016,
"id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.3.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60
}
}
import pytest
import base64
import asyncio
@pytest.fixture(scope="module")
def mllama_handle(launcher):
with launcher("meta-llama/Llama-3.2-11B-Vision-Instruct", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def mllama(mllama_handle):
await mllama_handle.health(300)
return mllama_handle.client
# TODO fix the server parsser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.mark.asyncio
async def test_mllama_simpl(mllama, response_snapshot):
# chicken = get_chicken()
response = await mllama.chat(
max_tokens=10,
temperature=0.0,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "Can you tell me a very short story based on the image?",
},
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
},
},
],
},
],
)
assert response.usage == {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60,
}
assert (
response.choices[0].message.content
== "In a bustling city, a chicken named Cluck"
)
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_mllama_load(mllama, generate_load, response_snapshot):
futures = [
mllama.chat(
max_tokens=10,
temperature=0.0,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "Can you tell me a very short story based on the image?",
},
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
},
},
],
},
],
)
for i in range(4)
]
responses = await asyncio.gather(*futures)
generated_texts = [response.choices[0].message.content for response in responses]
assert generated_texts[0] == "In a bustling city, a chicken named Cluck"
assert len(generated_texts) == 4
assert generated_texts, all(
[text == generated_texts[0] for text in generated_texts]
)
assert responses == response_snapshot
......@@ -146,6 +146,7 @@ pub enum Config {
ClipVisionModel(ClipVisionModel),
Mistral,
Idefics,
Mllama,
Idefics2(Idefics2),
Ssm,
GptBigcode,
......
......@@ -29,7 +29,7 @@ impl ChatTemplate {
env.set_unknown_method_callback(pycompat::unknown_method_callback);
let template_str = template.into_boxed_str();
env.add_function("raise_exception", raise_exception);
tracing::debug!("Loading template: {:#?}", template_str);
tracing::debug!("Loading template: {}", template_str);
// leaking env and template_str as read-only, static resources for performance.
let template = Box::leak(env)
......
......@@ -567,6 +567,7 @@ fn image_tokens(
use HubPreprocessorConfig::*;
match config {
Idefics => "<image>".to_string(),
Mllama => "<|image|>".to_string(),
Idefics2(config) => {
const FAKE: &str = "<fake_token_around_image>";
const IMAGE: &str = "<image>";
......@@ -618,7 +619,7 @@ fn prepare_input(
use Config::*;
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let (tokenizer_query, input_chunks) = match config {
Some(config @ (Idefics | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
Some(config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0;
......
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -23,10 +23,10 @@ opentelemetry-api = "^1.25.0"
opentelemetry-exporter-otlp = "^1.25.0"
opentelemetry-instrumentation-grpc = "^0.46b0"
hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
tokenizers = "^0.19.1"
sentencepiece = "^0.2"
tokenizers = "^0.20"
huggingface-hub = "^0.23"
transformers = "^4.43"
transformers = "^4.45"
einops = "^0.6.1"
texttable = { version = "^1.6.7", optional = true }
datasets = { version = "^2.14.0", optional = true }
......
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
......@@ -32,23 +32,23 @@ opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
......@@ -32,23 +32,23 @@ opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
......@@ -32,23 +32,23 @@ opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"
......@@ -76,6 +76,7 @@ FLASH_ATTENTION = True
try:
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
FlashDeepseekV2ForCausalLM,
DeepseekV2Config,
......@@ -112,7 +113,11 @@ try:
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
FlashPhiForCausalLM,
)
from text_generation_server.models.idefics import IDEFICSSharded
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
from text_generation_server.models.custom_modeling.mllama import (
MllamaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration,
)
......@@ -149,7 +154,7 @@ except ImportError as e:
if FLASH_ATTENTION:
__all__.append(FlashCausalLM)
__all__.append(IDEFICSSharded)
__all__.append(IdeficsCausalLM)
MAMBA_AVAILABLE = True
try:
......@@ -316,6 +321,12 @@ class ModelType(enum.Enum):
"url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
"multimodal": True,
}
MLLAMA = {
"type": "mllama",
"name": "Mllama",
"url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
"multimodal": True,
}
__GLOBALS = locals()
......@@ -1116,7 +1127,7 @@ def get_model(
)
if model_type == IDEFICS:
if FLASH_ATTENTION:
return IDEFICSSharded(
return IdeficsCausalLM(
model_id,
revision,
quantize=quantize,
......@@ -1126,6 +1137,22 @@ def get_model(
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == MLLAMA:
if FLASH_ATTENTION:
return MllamaCausalLM(
model_id=model_id,
model_class=MllamaForConditionalGeneration,
batch_class=MllamaCausalLMBatch,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
if model_type == IDEFICS2:
if FLASH_ATTENTION:
return VlmCausalLM(
......
......@@ -450,6 +450,7 @@ class FlashLlamaLayer(nn.Module):
seqlen,
max_s,
adapter_data,
cross_attention_states,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
......@@ -487,6 +488,7 @@ class FlashLlamaModel(torch.nn.Module):
# Skip fp8 quant for first and last layers
self.layers = nn.ModuleList()
self.cross_attention_layers = getattr(config, "cross_attention_layers", [])
with no_fp8(weights):
self.layers.append(
FlashLlamaLayer(
......@@ -499,22 +501,38 @@ class FlashLlamaModel(torch.nn.Module):
)
)
self.layers.extend(
[
FlashLlamaLayer(
index=layer_id,
prefix=(
f"model.layers.{layer_id}"
if not prefix
else f"{prefix}.model.layers.{layer_id}"
),
config=config,
weights=weights,
# Skip first and last layers
for layer_id in range(1, config.num_hidden_layers - 1):
if layer_id in self.cross_attention_layers:
from text_generation_server.models.custom_modeling.mllama import (
FlashLlamaCrossLayer,
)
self.layers.append(
FlashLlamaCrossLayer(
index=layer_id,
prefix=(
f"model.layers.{layer_id}"
if not prefix
else f"{prefix}.model.layers.{layer_id}"
),
config=config,
weights=weights,
)
)
else:
self.layers.append(
FlashLlamaLayer(
index=layer_id,
prefix=(
f"model.layers.{layer_id}"
if not prefix
else f"{prefix}.model.layers.{layer_id}"
),
config=config,
weights=weights,
)
)
# Skip first and last layers
for layer_id in range(1, config.num_hidden_layers - 1)
]
)
with no_fp8(weights):
last_layer_id = config.num_hidden_layers - 1
......@@ -556,6 +574,7 @@ class FlashLlamaModel(torch.nn.Module):
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
adapter_data,
cross_attention_states=None,
) -> torch.Tensor:
hidden_states = inputs_embeds
......@@ -579,6 +598,7 @@ class FlashLlamaModel(torch.nn.Module):
seqlen,
max_s,
adapter_data,
cross_attention_states,
)
hidden_states, _ = self.norm(hidden_states, residual)
......@@ -625,6 +645,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states=None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
......@@ -639,6 +660,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
adapter_data=adapter_data,
cross_attention_states=cross_attention_states,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
......
......@@ -48,7 +48,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
bias=True,
)
self.vocab_size = config.vocab_size
self.vocab_size = config.text_config.vocab_size
self.config = config
text_config = config.text_config
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment