Unverified Commit d18ed5cf authored by Nicolas Patry, committed by GitHub

Mllama flash version (#2585)

* Working loading state.

* Preprocessing.

* Working state ? (Broke idefics1 temporarily).

* Cleaner condition.

* Fix idefics.

* Updating config, removing TODO

* Mllama

* Upgrade transformers 4.45

* Flashing mllama.

* Starting to get there.

* Working state.

* Integration tests for mllama (cutting to 10 tokens because there seems to be instability after that, meaning the size of the batch matters).

* Updating model link.

* Earlier assert.

* Fix vlm ?

* remove log.

* Force ignore all images but last.

* Default dtype bfloat16.

* Update integration test after switch to bf16.

* Remove dead code.

* Removed dead code.

* Upgrade the flake to latest transformers/tokenizers

* Move to hf tgi-nix

* Upgrade to 0.5.0
parent 584b4d7a
...@@ -133,7 +133,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" ...@@ -133,7 +133,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -172,7 +172,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" ...@@ -172,7 +172,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -183,7 +183,7 @@ checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" ...@@ -183,7 +183,7 @@ checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -205,9 +205,9 @@ dependencies = [ ...@@ -205,9 +205,9 @@ dependencies = [
[[package]] [[package]]
name = "autocfg" name = "autocfg"
version = "1.3.0" version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
[[package]] [[package]]
name = "av1-grain" name = "av1-grain"
...@@ -316,12 +316,12 @@ dependencies = [ ...@@ -316,12 +316,12 @@ dependencies = [
[[package]] [[package]]
name = "axum" name = "axum"
version = "0.7.6" version = "0.7.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f43644eed690f5374f1af436ecd6aea01cd201f6fbdf0178adaf6907afb2cec" checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"axum-core 0.4.4", "axum-core 0.4.5",
"bytes", "bytes",
"futures-util", "futures-util",
"http 1.1.0", "http 1.1.0",
...@@ -367,9 +367,9 @@ dependencies = [ ...@@ -367,9 +367,9 @@ dependencies = [
[[package]] [[package]]
name = "axum-core" name = "axum-core"
version = "0.4.4" version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e6b8ba012a258d63c9adfa28b9ddcf66149da6f986c5b5452e629d5ee64bf00" checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"bytes", "bytes",
...@@ -392,7 +392,7 @@ version = "0.16.0" ...@@ -392,7 +392,7 @@ version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bdad298231394729042d1f155b93f9fdf0b5ee1aea0b62404c4d7341f7d8fe08" checksum = "bdad298231394729042d1f155b93f9fdf0b5ee1aea0b62404c4d7341f7d8fe08"
dependencies = [ dependencies = [
"axum 0.7.6", "axum 0.7.7",
"futures-core", "futures-core",
"futures-util", "futures-util",
"http 1.1.0", "http 1.1.0",
...@@ -456,7 +456,7 @@ dependencies = [ ...@@ -456,7 +456,7 @@ dependencies = [
"regex", "regex",
"rustc-hash", "rustc-hash",
"shlex", "shlex",
"syn 2.0.77", "syn 2.0.79",
"which", "which",
] ]
...@@ -605,9 +605,9 @@ dependencies = [ ...@@ -605,9 +605,9 @@ dependencies = [
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.1.21" version = "1.1.22"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07b1695e2c7e8fc85310cde85aeaab7e3097f593c91d209d3f9df76c928100f0" checksum = "9540e661f81799159abee814118cc139a2004b3a3aa3ea37724a1b66530b90e0"
dependencies = [ dependencies = [
"jobserver", "jobserver",
"libc", "libc",
...@@ -704,7 +704,7 @@ dependencies = [ ...@@ -704,7 +704,7 @@ dependencies = [
"heck 0.5.0", "heck 0.5.0",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -971,7 +971,7 @@ dependencies = [ ...@@ -971,7 +971,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"scratch", "scratch",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -988,7 +988,7 @@ checksum = "98532a60dedaebc4848cb2cba5023337cc9ea3af16a5b062633fabfd9f18fb60" ...@@ -988,7 +988,7 @@ checksum = "98532a60dedaebc4848cb2cba5023337cc9ea3af16a5b062633fabfd9f18fb60"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -1012,7 +1012,7 @@ dependencies = [ ...@@ -1012,7 +1012,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"strsim", "strsim",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -1023,7 +1023,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" ...@@ -1023,7 +1023,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806"
dependencies = [ dependencies = [
"darling_core", "darling_core",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -1053,7 +1053,7 @@ dependencies = [ ...@@ -1053,7 +1053,7 @@ dependencies = [
"darling", "darling",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -1063,7 +1063,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -1063,7 +1063,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4abae7035bf79b9877b779505d8cf3749285b80c43941eda66604841889451dc" checksum = "4abae7035bf79b9877b779505d8cf3749285b80c43941eda66604841889451dc"
dependencies = [ dependencies = [
"derive_builder_core", "derive_builder_core",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -1192,9 +1192,9 @@ checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" ...@@ -1192,9 +1192,9 @@ checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6"
[[package]] [[package]]
name = "fdeflate" name = "fdeflate"
version = "0.3.4" version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f9bfee30e4dedf0ab8b422f03af778d9612b63f502710fc500a334ebe2de645" checksum = "d8090f921a24b04994d9929e204f50b498a33ea6ba559ffaa05e04f7ee7fb5ab"
dependencies = [ dependencies = [
"simd-adler32", "simd-adler32",
] ]
...@@ -1207,9 +1207,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" ...@@ -1207,9 +1207,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
[[package]] [[package]]
name = "flate2" name = "flate2"
version = "1.0.33" version = "1.0.34"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "324a1be68054ef05ad64b861cc9eaf1d623d2d8cb25b4bf2cb9cdd902b4bf253" checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0"
dependencies = [ dependencies = [
"crc32fast", "crc32fast",
"miniz_oxide 0.8.0", "miniz_oxide 0.8.0",
...@@ -1338,7 +1338,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" ...@@ -1338,7 +1338,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -1864,7 +1864,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -1864,7 +1864,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b23a0c8dfe501baac4adf6ebbfa6eddf8f0c07f56b058cc1288017e32397846c" checksum = "b23a0c8dfe501baac4adf6ebbfa6eddf8f0c07f56b058cc1288017e32397846c"
dependencies = [ dependencies = [
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -1884,7 +1884,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" ...@@ -1884,7 +1884,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -2270,7 +2270,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -2270,7 +2270,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08"
dependencies = [ dependencies = [
"adler", "adler",
"simd-adler32",
] ]
[[package]] [[package]]
...@@ -2280,6 +2279,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -2280,6 +2279,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1"
dependencies = [ dependencies = [
"adler2", "adler2",
"simd-adler32",
] ]
[[package]] [[package]]
...@@ -2319,7 +2319,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" ...@@ -2319,7 +2319,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -2519,7 +2519,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" ...@@ -2519,7 +2519,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -2599,9 +2599,12 @@ dependencies = [ ...@@ -2599,9 +2599,12 @@ dependencies = [
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.19.0" version = "1.20.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" checksum = "82881c4be219ab5faaf2ad5e5e5ecdff8c66bd7402ca3160975c93b24961afd1"
dependencies = [
"portable-atomic",
]
[[package]] [[package]]
name = "onig" name = "onig"
...@@ -2654,7 +2657,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" ...@@ -2654,7 +2657,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -2808,7 +2811,7 @@ dependencies = [ ...@@ -2808,7 +2811,7 @@ dependencies = [
"glob", "glob",
"once_cell", "once_cell",
"opentelemetry 0.21.0", "opentelemetry 0.21.0",
"ordered-float 4.2.2", "ordered-float 4.3.0",
"percent-encoding", "percent-encoding",
"rand", "rand",
"thiserror", "thiserror",
...@@ -2828,7 +2831,7 @@ dependencies = [ ...@@ -2828,7 +2831,7 @@ dependencies = [
"lazy_static", "lazy_static",
"once_cell", "once_cell",
"opentelemetry 0.23.0", "opentelemetry 0.23.0",
"ordered-float 4.2.2", "ordered-float 4.3.0",
"percent-encoding", "percent-encoding",
"rand", "rand",
"thiserror", "thiserror",
...@@ -2851,9 +2854,9 @@ dependencies = [ ...@@ -2851,9 +2854,9 @@ dependencies = [
[[package]] [[package]]
name = "ordered-float" name = "ordered-float"
version = "4.2.2" version = "4.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a91171844676f8c7990ce64959210cd2eaef32c2612c50f9fae9f8aaa6065a6" checksum = "44d501f1a72f71d3c063a6bbc8f7271fa73aa09fe5d6283b6571e2ed176a2537"
dependencies = [ dependencies = [
"num-traits", "num-traits",
] ]
...@@ -2937,7 +2940,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" ...@@ -2937,7 +2940,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -2988,22 +2991,22 @@ dependencies = [ ...@@ -2988,22 +2991,22 @@ dependencies = [
[[package]] [[package]]
name = "png" name = "png"
version = "0.17.13" version = "0.17.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06e4b0d3d1312775e782c86c91a111aa1f910cbb65e1337f9975b5f9a554b5e1" checksum = "52f9d46a34a05a6a57566bc2bfae066ef07585a6e3fa30fbbdff5936380623f0"
dependencies = [ dependencies = [
"bitflags 1.3.2", "bitflags 1.3.2",
"crc32fast", "crc32fast",
"fdeflate", "fdeflate",
"flate2", "flate2",
"miniz_oxide 0.7.4", "miniz_oxide 0.8.0",
] ]
[[package]] [[package]]
name = "portable-atomic" name = "portable-atomic"
version = "1.8.0" version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d30538d42559de6b034bc76fd6dd4c38961b1ee5c6c56e3808c50128fdbc22ce" checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2"
[[package]] [[package]]
name = "powerfmt" name = "powerfmt"
...@@ -3027,7 +3030,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -3027,7 +3030,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "479cf940fbbb3426c32c5d5176f62ad57549a0bb84773423ba8be9d089f5faba" checksum = "479cf940fbbb3426c32c5d5176f62ad57549a0bb84773423ba8be9d089f5faba"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -3079,7 +3082,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -3079,7 +3082,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd" checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd"
dependencies = [ dependencies = [
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -3119,7 +3122,7 @@ dependencies = [ ...@@ -3119,7 +3122,7 @@ dependencies = [
"prost 0.12.6", "prost 0.12.6",
"prost-types", "prost-types",
"regex", "regex",
"syn 2.0.77", "syn 2.0.79",
"tempfile", "tempfile",
] ]
...@@ -3146,7 +3149,7 @@ dependencies = [ ...@@ -3146,7 +3149,7 @@ dependencies = [
"itertools 0.12.1", "itertools 0.12.1",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -3205,7 +3208,7 @@ dependencies = [ ...@@ -3205,7 +3208,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-macros-backend", "pyo3-macros-backend",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -3218,7 +3221,7 @@ dependencies = [ ...@@ -3218,7 +3221,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-build-config", "pyo3-build-config",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -3402,9 +3405,9 @@ dependencies = [ ...@@ -3402,9 +3405,9 @@ dependencies = [
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.5.5" version = "0.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62871f2d65009c0256aed1b9cfeeb8ac272833c404e13d53d400cd0dad7a2ac0" checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f"
dependencies = [ dependencies = [
"bitflags 2.6.0", "bitflags 2.6.0",
] ]
...@@ -3422,14 +3425,14 @@ dependencies = [ ...@@ -3422,14 +3425,14 @@ dependencies = [
[[package]] [[package]]
name = "regex" name = "regex"
version = "1.10.6" version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8"
dependencies = [ dependencies = [
"aho-corasick", "aho-corasick",
"memchr", "memchr",
"regex-automata 0.4.7", "regex-automata 0.4.8",
"regex-syntax 0.8.4", "regex-syntax 0.8.5",
] ]
[[package]] [[package]]
...@@ -3443,13 +3446,13 @@ dependencies = [ ...@@ -3443,13 +3446,13 @@ dependencies = [
[[package]] [[package]]
name = "regex-automata" name = "regex-automata"
version = "0.4.7" version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3"
dependencies = [ dependencies = [
"aho-corasick", "aho-corasick",
"memchr", "memchr",
"regex-syntax 0.8.4", "regex-syntax 0.8.5",
] ]
[[package]] [[package]]
...@@ -3460,9 +3463,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" ...@@ -3460,9 +3463,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
[[package]] [[package]]
name = "regex-syntax" name = "regex-syntax"
version = "0.8.4" version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]] [[package]]
name = "reqwest" name = "reqwest"
...@@ -3563,7 +3566,7 @@ dependencies = [ ...@@ -3563,7 +3566,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"rust-embed-utils", "rust-embed-utils",
"syn 2.0.77", "syn 2.0.79",
"walkdir", "walkdir",
] ]
...@@ -3686,9 +3689,9 @@ dependencies = [ ...@@ -3686,9 +3689,9 @@ dependencies = [
[[package]] [[package]]
name = "rustls-pki-types" name = "rustls-pki-types"
version = "1.8.0" version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" checksum = "0e696e35370c65c9c541198af4543ccd580cf17fc25d8e05c5a242b202488c55"
[[package]] [[package]]
name = "rustls-webpki" name = "rustls-webpki"
...@@ -3813,7 +3816,7 @@ checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" ...@@ -3813,7 +3816,7 @@ checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -3840,9 +3843,9 @@ dependencies = [ ...@@ -3840,9 +3843,9 @@ dependencies = [
[[package]] [[package]]
name = "serde_spanned" name = "serde_spanned"
version = "0.6.7" version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb5b1b31579f3811bf615c144393417496f152e12ac8b7663bf664f4a815306d" checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1"
dependencies = [ dependencies = [
"serde", "serde",
] ]
...@@ -4028,7 +4031,7 @@ dependencies = [ ...@@ -4028,7 +4031,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"rustversion", "rustversion",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -4050,9 +4053,9 @@ dependencies = [ ...@@ -4050,9 +4053,9 @@ dependencies = [
[[package]] [[package]]
name = "syn" name = "syn"
version = "2.0.77" version = "2.0.79"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
...@@ -4152,9 +4155,9 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" ...@@ -4152,9 +4155,9 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
[[package]] [[package]]
name = "tempfile" name = "tempfile"
version = "3.12.0" version = "3.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64" checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"fastrand", "fastrand",
...@@ -4259,7 +4262,7 @@ version = "2.3.1-dev0" ...@@ -4259,7 +4262,7 @@ version = "2.3.1-dev0"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",
"axum 0.7.6", "axum 0.7.7",
"axum-tracing-opentelemetry", "axum-tracing-opentelemetry",
"base64 0.22.1", "base64 0.22.1",
"clap 4.5.18", "clap 4.5.18",
...@@ -4308,7 +4311,7 @@ version = "2.3.1-dev0" ...@@ -4308,7 +4311,7 @@ version = "2.3.1-dev0"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",
"axum 0.7.6", "axum 0.7.7",
"axum-tracing-opentelemetry", "axum-tracing-opentelemetry",
"base64 0.22.1", "base64 0.22.1",
"clap 4.5.18", "clap 4.5.18",
...@@ -4357,7 +4360,7 @@ version = "2.3.1-dev0" ...@@ -4357,7 +4360,7 @@ version = "2.3.1-dev0"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",
"axum 0.7.6", "axum 0.7.7",
"axum-tracing-opentelemetry", "axum-tracing-opentelemetry",
"base64 0.22.1", "base64 0.22.1",
"clap 4.5.18", "clap 4.5.18",
...@@ -4428,7 +4431,7 @@ checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" ...@@ -4428,7 +4431,7 @@ checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -4533,7 +4536,7 @@ dependencies = [ ...@@ -4533,7 +4536,7 @@ dependencies = [
"rayon", "rayon",
"rayon-cond", "rayon-cond",
"regex", "regex",
"regex-syntax 0.8.4", "regex-syntax 0.8.5",
"serde", "serde",
"serde_json", "serde_json",
"spm_precompiled", "spm_precompiled",
...@@ -4566,7 +4569,7 @@ dependencies = [ ...@@ -4566,7 +4569,7 @@ dependencies = [
"rayon", "rayon",
"rayon-cond", "rayon-cond",
"regex", "regex",
"regex-syntax 0.8.4", "regex-syntax 0.8.5",
"serde", "serde",
"serde_json", "serde_json",
"spm_precompiled", "spm_precompiled",
...@@ -4612,7 +4615,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" ...@@ -4612,7 +4615,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -4771,7 +4774,7 @@ dependencies = [ ...@@ -4771,7 +4774,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"prost-build", "prost-build",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -4858,7 +4861,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" ...@@ -4858,7 +4861,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -5151,7 +5154,7 @@ dependencies = [ ...@@ -5151,7 +5154,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"regex", "regex",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -5160,7 +5163,7 @@ version = "6.0.0" ...@@ -5160,7 +5163,7 @@ version = "6.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b39868d43c011961e04b41623e050aedf2cc93652562ff7935ce0f819aaf2da" checksum = "0b39868d43c011961e04b41623e050aedf2cc93652562ff7935ce0f819aaf2da"
dependencies = [ dependencies = [
"axum 0.7.6", "axum 0.7.7",
"mime_guess", "mime_guess",
"regex", "regex",
"rust-embed", "rust-embed",
...@@ -5189,7 +5192,7 @@ checksum = "ee1cd046f83ea2c4e920d6ee9f7c3537ef928d75dce5d84a87c2c5d6b3999a3a" ...@@ -5189,7 +5192,7 @@ checksum = "ee1cd046f83ea2c4e920d6ee9f7c3537ef928d75dce5d84a87c2c5d6b3999a3a"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
...@@ -5290,7 +5293,7 @@ dependencies = [ ...@@ -5290,7 +5293,7 @@ dependencies = [
"once_cell", "once_cell",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
"wasm-bindgen-shared", "wasm-bindgen-shared",
] ]
...@@ -5324,7 +5327,7 @@ checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" ...@@ -5324,7 +5327,7 @@ checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
"wasm-bindgen-backend", "wasm-bindgen-backend",
"wasm-bindgen-shared", "wasm-bindgen-shared",
] ]
...@@ -5668,9 +5671,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" ...@@ -5668,9 +5671,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]] [[package]]
name = "winnow" name = "winnow"
version = "0.6.19" version = "0.6.20"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c52ac009d615e79296318c1bcce2d422aaca15ad08515e344feeda07df67a587" checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b"
dependencies = [ dependencies = [
"memchr", "memchr",
] ]
...@@ -5703,7 +5706,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" ...@@ -5703,7 +5706,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.77", "syn 2.0.79",
] ]
[[package]] [[package]]
......
@@ -28,11 +28,17 @@ class ToolCall(BaseModel):
     function: dict
 
+
+class Chunk(BaseModel):
+    type: str
+    text: Optional[str] = None
+    image_url: Any = None
+
 
 class Message(BaseModel):
     # Role of the message sender
     role: str
     # Content of the message
-    content: Optional[str] = None
+    content: Optional[Union[str, List[Chunk]]] = None
     # Optional name of the message sender
     name: Optional[str] = None
     # Tool calls associated with the chat completion
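The schema change above lets Message.content carry either a plain string or a list of typed chunks, which is what allows image inputs in chat requests. Below is a minimal, self-contained sketch (assuming Pydantic is installed; fields not shown in the hunk are omitted) of how a text-only and a multimodal message look under the new types:

from typing import Any, List, Optional, Union

from pydantic import BaseModel


class Chunk(BaseModel):
    type: str
    text: Optional[str] = None
    image_url: Any = None


class Message(BaseModel):
    role: str
    # content may now be a plain string or a list of chunks
    content: Optional[Union[str, List[Chunk]]] = None
    name: Optional[str] = None


# A text-only message still validates exactly as before...
text_only = Message(role="user", content="Describe the image.")

# ...while a multimodal message mixes text and image_url chunks.
multimodal = Message(
    role="user",
    content=[
        Chunk(type="text", text="Can you tell me a very short story based on the image?"),
        Chunk(
            type="image_url",
            image_url={
                "url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
            },
        ),
    ],
)
print(multimodal.content[1].image_url)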
......
@@ -35,6 +35,7 @@ Text Generation Inference enables serving optimized models on specific hardware
 - [Gpt Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
 - [Gptj](https://huggingface.co/EleutherAI/gpt-j-6b)
 - [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b) (Multimodal)
+- [Mllama](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) (Multimodal)
 
 If the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:
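Once a server exposes this model (for example via the launcher: text-generation-launcher --model-id meta-llama/Llama-3.2-11B-Vision-Instruct --num-shard 2), the OpenAI-compatible chat endpoint accepts the same text/image_url chunks used by the integration tests further down. A minimal sketch, assuming the router listens on its default port 3000 and that the requests package is available:

import requests

# Adjust host/port to wherever the router is actually serving.
url = "http://localhost:3000/v1/chat/completions"

payload = {
    "model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
    "max_tokens": 10,
    "temperature": 0.0,
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Can you tell me a very short story based on the image?",
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
                    },
                },
            ],
        }
    ],
}

response = requests.post(url, json=payload, timeout=60)
# The response mirrors the snapshots below: choices[0].message.content holds the text.
print(response.json()["choices"][0]["message"]["content"])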
......
@@ -497,11 +497,11 @@
       "systems": "systems_7"
     },
     "locked": {
-      "lastModified": 1710146030,
-      "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
+      "lastModified": 1726560853,
+      "narHash": "sha256-X6rJYSESBVr3hBoH0WbKE5KvhPU5bloyZ2L4K60/fPQ=",
       "owner": "numtide",
       "repo": "flake-utils",
-      "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
+      "rev": "c1dfcf08411b08f6b8615f7d8971a2bfa81d5e8a",
       "type": "github"
     },
     "original": {
@@ -718,11 +718,11 @@
     },
     "nixpkgs_6": {
       "locked": {
-        "lastModified": 1724915739,
-        "narHash": "sha256-7PgRge4mn5akFvhPwefuaLQGbF5BnmxlwZJEf7CgbrE=",
+        "lastModified": 1727675176,
+        "narHash": "sha256-xIjBFMYldWvj+g8ahxMPofsj+OqxvKJN6YylNHQ7gn4=",
         "owner": "nixos",
         "repo": "nixpkgs",
-        "rev": "85be051bb60943d3328d91aaf2598798f87e19af",
+        "rev": "a6d0207fea9212d28cd3d487efe6bc699663b93a",
         "type": "github"
       },
       "original": {
@@ -853,11 +853,11 @@
       ]
     },
     "locked": {
-      "lastModified": 1726626348,
-      "narHash": "sha256-sYV7e1B1yLcxo8/h+/hTwzZYmaju2oObNiy5iRI0C30=",
+      "lastModified": 1727836133,
+      "narHash": "sha256-JE0zciM5IGWvK8J/pE2VldNBf7oyMH5WrU8tZArefbg=",
       "owner": "oxalica",
       "repo": "rust-overlay",
-      "rev": "6fd52ad8bd88f39efb2c999cc971921c2fb9f3a2",
+      "rev": "02321540b0c8000b36889b1b974d1fec585b25a4",
       "type": "github"
     },
     "original": {
@@ -978,17 +978,16 @@
       "nixpkgs": "nixpkgs_6"
     },
     "locked": {
-      "lastModified": 1727710820,
-      "narHash": "sha256-BuSafCxoFQhkp7lnvNtpquxSK43rIbnouL2HypIUC+o=",
-      "owner": "danieldk",
-      "repo": "tgi-nix",
-      "rev": "4f4dc4b85dd856fd7904e8e3e486a2ff153584a2",
+      "lastModified": 1727859277,
+      "narHash": "sha256-AsrPuQqhg8x5RRR3aX0vvDDRQb+HREq2wGxXOpZnWus=",
+      "owner": "huggingface",
+      "repo": "text-generation-inference-nix",
+      "rev": "14196ab62f31d005f46207f7a251f82a81d0a09f",
       "type": "github"
     },
     "original": {
-      "owner": "danieldk",
-      "ref": "moe-kernels-0.5.0",
-      "repo": "tgi-nix",
+      "owner": "huggingface",
+      "repo": "text-generation-inference-nix",
       "type": "github"
     }
   }
......
@@ -5,7 +5,7 @@
       inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
     };
     nix-filter.url = "github:numtide/nix-filter";
-    tgi-nix.url = "github:danieldk/tgi-nix/moe-kernels-0.5.0";
+    tgi-nix.url = "github:huggingface/text-generation-inference-nix";
     nixpkgs.follows = "tgi-nix/nixpkgs";
     flake-utils.url = "github:numtide/flake-utils";
     rust-overlay = {
......
[
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a bustling city, a chicken named Cluck",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727773835,
"id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.3.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60
}
},
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a world where even chickens could dream big,",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727773835,
"id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.3.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60
}
},
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a world where even chickens could dream big,",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727773835,
"id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.3.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60
}
},
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a world where even chickens could dream big,",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727773835,
"id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.3.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60
}
}
]
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a bustling city, a chicken named Cluck",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727556016,
"id": "",
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"object": "chat.completion",
"system_fingerprint": "2.3.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60
}
}
import pytest
import base64
import asyncio
@pytest.fixture(scope="module")
def mllama_handle(launcher):
with launcher("meta-llama/Llama-3.2-11B-Vision-Instruct", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def mllama(mllama_handle):
await mllama_handle.health(300)
return mllama_handle.client
# TODO fix the server parser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.mark.asyncio
async def test_mllama_simpl(mllama, response_snapshot):
# chicken = get_chicken()
response = await mllama.chat(
max_tokens=10,
temperature=0.0,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "Can you tell me a very short story based on the image?",
},
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
},
},
],
},
],
)
assert response.usage == {
"completion_tokens": 10,
"prompt_tokens": 50,
"total_tokens": 60,
}
assert (
response.choices[0].message.content
== "In a bustling city, a chicken named Cluck"
)
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_mllama_load(mllama, generate_load, response_snapshot):
futures = [
mllama.chat(
max_tokens=10,
temperature=0.0,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "Can you tell me a very short story based on the image?",
},
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
},
},
],
},
],
)
for i in range(4)
]
responses = await asyncio.gather(*futures)
generated_texts = [response.choices[0].message.content for response in responses]
assert generated_texts[0] == "In a bustling city, a chicken named Cluck"
assert len(generated_texts) == 4
assert all(text == generated_texts[0] for text in generated_texts)
assert responses == response_snapshot
@@ -146,6 +146,7 @@ pub enum Config {
     ClipVisionModel(ClipVisionModel),
     Mistral,
     Idefics,
+    Mllama,
     Idefics2(Idefics2),
     Ssm,
     GptBigcode,
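The new Mllama variant means the router can now recognize checkpoints whose config.json declares this architecture. A rough Python illustration of the same dispatch (the Rust enum is tagged by model_type; the snake_case name "mllama" for the new arm is an assumption here, as is the trimmed config):

import json

# Hypothetical, heavily trimmed config.json for a vision-capable checkpoint.
raw_config = '{"model_type": "mllama"}'
config = json.loads(raw_config)

# Multimodal model types get image tokens injected during prompt preparation
# (see the image_tokens / prepare_input hunks below).
MULTIMODAL_TYPES = {"idefics", "mllama", "idefics2", "paligemma", "llava_next"}
print(config["model_type"] in MULTIMODAL_TYPES)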
......
@@ -29,7 +29,7 @@ impl ChatTemplate {
         env.set_unknown_method_callback(pycompat::unknown_method_callback);
         let template_str = template.into_boxed_str();
         env.add_function("raise_exception", raise_exception);
-        tracing::debug!("Loading template: {:#?}", template_str);
+        tracing::debug!("Loading template: {}", template_str);
         // leaking env and template_str as read-only, static resources for performance.
         let template = Box::leak(env)
......
@@ -567,6 +567,7 @@ fn image_tokens(
     use HubPreprocessorConfig::*;
     match config {
         Idefics => "<image>".to_string(),
+        Mllama => "<|image|>".to_string(),
         Idefics2(config) => {
             const FAKE: &str = "<fake_token_around_image>";
             const IMAGE: &str = "<image>";
@@ -618,7 +619,7 @@ fn prepare_input(
     use Config::*;
     static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
     let (tokenizer_query, input_chunks) = match config {
-        Some(config @ (Idefics | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
+        Some(config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
             let mut input_chunks = Vec::new();
             let mut tokenizer_query = String::with_capacity(inputs.len());
             let mut start = 0;
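For Mllama, prepare_input therefore rewrites every inline markdown image placeholder into the <|image|> token before tokenization, while the image itself travels as a separate input chunk. A rough Python sketch of just the text-substitution half, reusing the regex from the hunk above (the real router also extracts the URL into an image chunk, which is not shown here):

import re

# Same pattern as the router's RE above: an empty-alt markdown image, ![](url).
IMAGE_RE = re.compile(r"!\[\]\([^)]*\)")

prompt = (
    "Can you tell me a very short story based on the image?"
    "![](https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png)"
)

# For Mllama each placeholder becomes the <|image|> token in the tokenizer query.
tokenizer_query = IMAGE_RE.sub("<|image|>", prompt)
print(tokenizer_query)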
......
This source diff could not be displayed because it is too large.
@@ -23,10 +23,10 @@ opentelemetry-api = "^1.25.0"
 opentelemetry-exporter-otlp = "^1.25.0"
 opentelemetry-instrumentation-grpc = "^0.46b0"
 hf-transfer = "^0.1.2"
-sentencepiece = "^0.1.97"
-tokenizers = "^0.19.1"
+sentencepiece = "^0.2"
+tokenizers = "^0.20"
 huggingface-hub = "^0.23"
-transformers = "^4.43"
+transformers = "^4.45"
 einops = "^0.6.1"
 texttable = { version = "^1.6.7", optional = true }
 datasets = { version = "^2.14.0", optional = true }
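These bumps matter because Mllama support only ships with transformers 4.45 (hence the "Upgrade transformers 4.45" step in the commit message). A small sanity-check sketch for the server environment; the class name MllamaForConditionalGeneration is the one introduced in that release, and the check assumes the packaging module is available (it is pulled in by transformers):

import transformers
from packaging.version import Version

# transformers >= 4.45 is required for the Mllama architecture.
assert Version(transformers.__version__) >= Version("4.45.0"), transformers.__version__

# This import only resolves on 4.45+; older releases raise ImportError.
from transformers import MllamaForConditionalGeneration  # noqa: F401

print("Mllama support available in transformers", transformers.__version__)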
......
-certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
+certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
 charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
 click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
 colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
 deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
 einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
+filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
-fsspec==2024.5.0 ; python_version >= "3.9" and python_version < "3.13"
+fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
-googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
 grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.62.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.65.1 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.66.1 ; python_version >= "3.9" and python_version < "3.13"
 hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
 huggingface-hub==0.23.5 ; python_version >= "3.9" and python_version < "3.13"
-idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
+idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
 importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
 markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -32,23 +32,23 @@ opentelemetry-semantic-conventions==0.46b0 ; python_version >= "3.9" and python_version < "3.13"
 packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
 pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
 prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
-protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
+protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
 py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
 pygments==2.18.0 ; python_version >= "3.9" and python_version < "3.13"
-pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
+pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
-regex==2024.5.15 ; python_version >= "3.9" and python_version < "3.13"
+regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
-rich==13.7.1 ; python_version >= "3.9" and python_version < "3.13"
+rich==13.8.1 ; python_version >= "3.9" and python_version < "3.13"
-safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13"
+safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
-sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
+sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==71.1.0 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==75.1.0 ; python_version >= "3.9" and python_version < "3.13"
-tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
-tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13"
+tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.43.1 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.45.0 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
-urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
+urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
 win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
 wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
-zipp==3.19.2 ; python_version >= "3.9" and python_version < "3.13"
+zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"
...@@ -76,6 +76,7 @@ FLASH_ATTENTION = True ...@@ -76,6 +76,7 @@ FLASH_ATTENTION = True
try: try:
from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.vlm_causal_lm import VlmCausalLM from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import ( from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
FlashDeepseekV2ForCausalLM, FlashDeepseekV2ForCausalLM,
DeepseekV2Config, DeepseekV2Config,
...@@ -112,7 +113,11 @@ try: ...@@ -112,7 +113,11 @@ try:
from text_generation_server.models.custom_modeling.flash_phi_modeling import ( from text_generation_server.models.custom_modeling.flash_phi_modeling import (
FlashPhiForCausalLM, FlashPhiForCausalLM,
) )
from text_generation_server.models.idefics import IDEFICSSharded from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
from text_generation_server.models.custom_modeling.mllama import (
MllamaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.llava_next import ( from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration, LlavaNextForConditionalGeneration,
) )
...@@ -149,7 +154,7 @@ except ImportError as e: ...@@ -149,7 +154,7 @@ except ImportError as e:
if FLASH_ATTENTION: if FLASH_ATTENTION:
__all__.append(FlashCausalLM) __all__.append(FlashCausalLM)
__all__.append(IDEFICSSharded) __all__.append(IdeficsCausalLM)
MAMBA_AVAILABLE = True MAMBA_AVAILABLE = True
try: try:
...@@ -316,6 +321,12 @@ class ModelType(enum.Enum): ...@@ -316,6 +321,12 @@ class ModelType(enum.Enum):
"url": "https://huggingface.co/HuggingFaceM4/idefics-9b", "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
"multimodal": True, "multimodal": True,
} }
MLLAMA = {
"type": "mllama",
"name": "Mllama",
"url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
"multimodal": True,
}
__GLOBALS = locals() __GLOBALS = locals()
...@@ -1116,7 +1127,7 @@ def get_model( ...@@ -1116,7 +1127,7 @@ def get_model(
) )
if model_type == IDEFICS: if model_type == IDEFICS:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return IDEFICSSharded( return IdeficsCausalLM(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
...@@ -1126,6 +1137,22 @@ def get_model( ...@@ -1126,6 +1137,22 @@ def get_model(
) )
else: else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == MLLAMA:
if FLASH_ATTENTION:
return MllamaCausalLM(
model_id=model_id,
model_class=MllamaForConditionalGeneration,
batch_class=MllamaCausalLMBatch,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
if model_type == IDEFICS2: if model_type == IDEFICS2:
if FLASH_ATTENTION: if FLASH_ATTENTION:
return VlmCausalLM( return VlmCausalLM(
......
...@@ -450,6 +450,7 @@ class FlashLlamaLayer(nn.Module): ...@@ -450,6 +450,7 @@ class FlashLlamaLayer(nn.Module):
seqlen, seqlen,
max_s, max_s,
adapter_data, adapter_data,
cross_attention_states,
): ):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual) normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
...@@ -487,6 +488,7 @@ class FlashLlamaModel(torch.nn.Module): ...@@ -487,6 +488,7 @@ class FlashLlamaModel(torch.nn.Module):
# Skip fp8 quant for first and last layers # Skip fp8 quant for first and last layers
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
self.cross_attention_layers = getattr(config, "cross_attention_layers", [])
with no_fp8(weights): with no_fp8(weights):
self.layers.append( self.layers.append(
FlashLlamaLayer( FlashLlamaLayer(
...@@ -499,22 +501,38 @@ class FlashLlamaModel(torch.nn.Module): ...@@ -499,22 +501,38 @@ class FlashLlamaModel(torch.nn.Module):
) )
) )
self.layers.extend( # Skip first and last layers
[ for layer_id in range(1, config.num_hidden_layers - 1):
FlashLlamaLayer( if layer_id in self.cross_attention_layers:
index=layer_id, from text_generation_server.models.custom_modeling.mllama import (
prefix=( FlashLlamaCrossLayer,
f"model.layers.{layer_id}" )
if not prefix
else f"{prefix}.model.layers.{layer_id}" self.layers.append(
), FlashLlamaCrossLayer(
config=config, index=layer_id,
weights=weights, prefix=(
f"model.layers.{layer_id}"
if not prefix
else f"{prefix}.model.layers.{layer_id}"
),
config=config,
weights=weights,
)
)
else:
self.layers.append(
FlashLlamaLayer(
index=layer_id,
prefix=(
f"model.layers.{layer_id}"
if not prefix
else f"{prefix}.model.layers.{layer_id}"
),
config=config,
weights=weights,
)
) )
# Skip first and last layers
for layer_id in range(1, config.num_hidden_layers - 1)
]
)
with no_fp8(weights): with no_fp8(weights):
last_layer_id = config.num_hidden_layers - 1 last_layer_id = config.num_hidden_layers - 1
...@@ -556,6 +574,7 @@ class FlashLlamaModel(torch.nn.Module): ...@@ -556,6 +574,7 @@ class FlashLlamaModel(torch.nn.Module):
true_max_s: int, true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
adapter_data, adapter_data,
cross_attention_states=None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -579,6 +598,7 @@ class FlashLlamaModel(torch.nn.Module): ...@@ -579,6 +598,7 @@ class FlashLlamaModel(torch.nn.Module):
seqlen, seqlen,
max_s, max_s,
adapter_data, adapter_data,
cross_attention_states,
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
...@@ -625,6 +645,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): ...@@ -625,6 +645,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
prefill_cache_indices: Optional[torch.Tensor] = None, prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
cross_attention_states=None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model( hidden_states = self.model(
...@@ -639,6 +660,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): ...@@ -639,6 +660,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
true_max_s=max_s, true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices, prefill_cache_indices=prefill_cache_indices,
adapter_data=adapter_data, adapter_data=adapter_data,
cross_attention_states=cross_attention_states,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
......
...@@ -48,7 +48,7 @@ class PaliGemmaForConditionalGeneration(nn.Module): ...@@ -48,7 +48,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
bias=True, bias=True,
) )
self.vocab_size = config.vocab_size self.vocab_size = config.text_config.vocab_size
self.config = config self.config = config
text_config = config.text_config text_config = config.text_config
......
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mllama model."""
from typing import Optional, Tuple, List
import torch
import torch.utils.checkpoint
from torch import nn
import flash_attn_2_cuda
from transformers.activations import ACT2FN
import torch.nn.functional as F
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
FastLinear,
)
from text_generation_server.layers.attention import (
Seqlen,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
def _prepare_aspect_ratio_attention_mask(
aspect_ratio_mask: torch.Tensor,
num_patches: int,
target_length: int,
dtype: torch.dtype,
) -> torch.Tensor:
# Expand aspect ratio mask to target_length
batch_size, max_num_tiles = aspect_ratio_mask.shape
attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype)
attention_mask = attention_mask.repeat(1, 1, target_length, 1)
# Mask padding patches
pad_patches = target_length - num_patches
attention_mask[:, :, -pad_patches:] = 0
# Invert the mask (0 -> 1, 1 -> 0)
attention_mask = 1 - attention_mask
# Reshape to 2D and create 4D attention mask
# (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
attention_mask = attention_mask.reshape(
batch_size, max_num_tiles * target_length, 1
)
attention_mask = (
attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min
)
attention_mask = attention_mask.unsqueeze(1)
return attention_mask
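# A minimal sketch (illustration only, not used by the model) of the shapes involved,
# assuming a toy setup with one image, 4 tiles, 5 real patches padded to a length of 8:
#
#   mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])   # (batch=1, max_num_tiles=4)
#   attn = _prepare_aspect_ratio_attention_mask(
#       aspect_ratio_mask=mask, num_patches=5, target_length=8, dtype=torch.float32
#   )
#   # attn.shape == (1, 1, 32, 32); an entry is torch.finfo(dtype).min when both its
#   # query and key position fall in padding (padded tiles or padded patches), 0 otherwise.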
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`; if the input `attention_mask` is already 4D, it is returned unchanged.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with a static cache, the mask should be as long as the static cache to account for the 0 padding, i.e. the part of the cache that is not yet filled.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`int`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length),
fill_value=min_dtype,
dtype=dtype,
device=device,
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(
target_length, device=device
) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = (
causal_mask.clone()
) # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = (
causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[
:, :, :, :mask_length
].masked_fill(padding_mask, min_dtype)
return causal_mask
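# A minimal sketch (illustration only, not used by the model) of a call to the helper
# above, assuming a toy batch of 2 sequences, 3 query tokens and a cache length of 5:
#
#   mask_2d = torch.ones(2, 5)
#   causal = _prepare_4d_causal_attention_mask_with_cache_position(
#       attention_mask=mask_2d,
#       sequence_length=3,
#       target_length=5,
#       dtype=torch.float32,
#       device=torch.device("cpu"),
#       min_dtype=torch.finfo(torch.float32).min,
#       cache_position=torch.arange(3),
#       batch_size=2,
#   )
#   # causal.shape == (2, 1, 3, 5): future positions hold min_dtype, visible ones hold 0.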
def _prepare_cross_attention_mask(
cross_attention_mask: torch.Tensor,
num_vision_tokens: int,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
# reshape so it can be used by attn module
batch_size, text_total_length, *_ = cross_attention_mask.shape
cross_attention_mask = cross_attention_mask.repeat_interleave(
num_vision_tokens, dim=3
)
cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)
cross_attention_mask = cross_attention_mask.unsqueeze(1)
# invert the mask
inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
cross_attention_mask = inverted_cross_attn_mask.masked_fill(
inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
)
# apply full-row bias, which returns a 4D tensor of shape [B, H, S1, 1] whose value is 0 if a full row in the
# cross attention mask's last dimension contains negative infinity values, otherwise it's 1
negative_inf_value = torch.finfo(dtype).min
full_text_row_masked_out_mask = (
(cross_attention_mask != negative_inf_value)
.any(dim=-1)
.type_as(cross_attention_mask)[..., None]
)
cross_attention_mask *= full_text_row_masked_out_mask
return cross_attention_mask, full_text_row_masked_out_mask
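# Rough shape sketch (illustration only), assuming the incoming mask is the
# (batch, text_len, num_images, num_tiles) tensor produced by the Mllama processor
# and each tile contributes `num_vision_tokens` vision tokens (e.g. 1601 for a
# 560x560 image with 14x14 patches):
#
#   mask = torch.ones(1, 10, 1, 4)                 # 1 image, 4 tiles, 10 text tokens
#   cross_mask, full_row = _prepare_cross_attention_mask(
#       mask, num_vision_tokens=1601, dtype=torch.float32
#   )
#   # cross_mask.shape == (1, 1, 10, 4 * 1601); full_row.shape == (1, 1, 10, 1)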
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision
class MllamaVisionMLP(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
)
self.fc2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class MllamaVisionSdpaAttention(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.embed_dim = config.hidden_size
self.head_dim = config.hidden_size // config.attention_heads
self.num_heads = config.attention_heads // weights.process_group.size()
self.qkv_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
def forward(
self,
hidden_state: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qkv = self.qkv_proj(hidden_state)
query, key, value = qkv.split(
[
self.head_dim * self.num_heads,
self.head_dim * self.num_heads,
self.head_dim * self.num_heads,
],
dim=2,
)
batch_size, q_seq_len, _ = query.shape
_, kv_seq_len, _ = key.shape
query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim)
key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
attn_output = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
output = self.o_proj(attn_output)
return output
class MllamaVisionEncoderLayer(nn.Module):
def __init__(self, *, prefix, config, weights, is_gated: bool):
super().__init__()
self.hidden_size = config.hidden_size
self.num_attention_heads = config.attention_heads
self.is_gated = is_gated
self.intermediate_size = config.intermediate_size
self.self_attn = MllamaVisionSdpaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.mlp = MllamaVisionMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights
)
self.input_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05
)
self.post_attention_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=1e-05
)
# There used to be an if/else here, but the non-gated branch had no code path.
if is_gated:
self.gate_attn = nn.Parameter(
weights.get_tensor(f"{prefix}.gate_attn"), requires_grad=False
)
self.gate_ffn = nn.Parameter(
weights.get_tensor(f"{prefix}.gate_ffn"), requires_grad=False
)
def forward(
self,
hidden_state: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
):
# Self Attention
residual = hidden_state
hidden_state = self.input_layernorm(hidden_state)
hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask)
gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
hidden_state = residual + gate_attn * hidden_state
# Feed forward
residual = hidden_state
hidden_state = self.post_attention_layernorm(hidden_state)
hidden_state = self.mlp(hidden_state)
gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
hidden_state = residual + gate_ffn * hidden_state
return hidden_state
class MllamaVisionEncoder(nn.Module):
def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int):
super().__init__()
self.config = config
self.layers = [
MllamaVisionEncoderLayer(
prefix=f"{prefix}.layers.{i}",
config=config,
weights=weights,
is_gated=is_gated,
)
for i in range(num_layers)
]
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
):
encoder_states = [hidden_states]
for encoder_layer in self.layers:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
)
hidden_states = layer_outputs
encoder_states.append(hidden_states)
return hidden_states, encoder_states
class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.max_num_tiles = config.max_num_tiles
self.hidden_size = config.hidden_size
self.max_aspect_ratio_id = config.max_aspect_ratio_id
self.embedding = TensorParallelEmbedding(
prefix=f"{prefix}.embedding", weights=weights
)
self.gate = nn.Parameter(
weights.get_tensor(f"{prefix}.gate"), requires_grad=False
)
def forward(
self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
) -> torch.Tensor:
embeddings = self.embedding(aspect_ratio_ids)
embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size)
# Always gated.
embeddings = embeddings * self.gate.tanh()
hidden_state = hidden_state + embeddings
return hidden_state
class MllamaPrecomputedPositionEmbedding(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.max_num_tiles = config.max_num_tiles
self.max_aspect_ratio_id = config.max_aspect_ratio_id
self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
self.hidden_size = config.hidden_size
self.scale = config.hidden_size**-0.5
self.gate = nn.Parameter(
weights.get_tensor(f"{prefix}.gate"), requires_grad=False
)
# position embedding
embedding = nn.Parameter(
weights.get_tensor(f"{prefix}.embedding"), requires_grad=False
)
self.gated_position_embedding = (1 - self.gate.tanh()) * embedding
self.tile_embedding = TensorParallelEmbedding(
prefix=f"{prefix}.tile_embedding", weights=weights
)
def forward(
self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
) -> torch.Tensor:
# position embeddings
hidden_state = hidden_state + self.gated_position_embedding.view(
1, 1, self.num_patches, self.hidden_size
)
# precomputed tile position embeddings
tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
batch_size = hidden_state.shape[0]
tile_position_embedding = tile_position_embedding.reshape(
batch_size, self.max_num_tiles, self.num_patches, self.hidden_size
)
gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding
hidden_state = hidden_state + gated_tile_position_embedding
return hidden_state
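# Illustrative summary of the two gated terms above (not extra code): with a single
# scalar gate g, the output is
#   hidden + (1 - tanh(g)) * learned_position_embedding
#          + tanh(g) * tile_position_embedding(aspect_ratio_ids)
# so the gate interpolates between the shared and the per-tile position embeddings.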
class MllamaVisionModel(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.image_size = config.image_size
self.patch_size = config.patch_size
self.max_num_tiles = config.max_num_tiles
self.hidden_size = config.hidden_size
self.num_channels = config.num_channels
self.intermediate_layers_indices = config.intermediate_layers_indices
self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
self.scale = config.hidden_size**-0.5
self.dtype = weights.dtype
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.hidden_size,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
bias=False,
)
self.patch_embedding.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
)
self.class_embedding = nn.Parameter(
weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False
)
self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
prefix=f"{prefix}.gated_positional_embedding",
config=config,
weights=weights,
)
self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
prefix=f"{prefix}.pre_tile_positional_embedding",
config=config,
weights=weights,
)
self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
prefix=f"{prefix}.post_tile_positional_embedding",
config=config,
weights=weights,
)
## layer norms
self.layernorm_pre = nn.LayerNorm.load(
prefix=f"{prefix}.layernorm_pre",
weights=weights,
# torch default
eps=1e-05,
)
self.layernorm_post = nn.LayerNorm.load(
prefix=f"{prefix}.layernorm_post",
weights=weights,
# torch default
eps=1e-05,
)
## encoders
self.transformer = MllamaVisionEncoder(
prefix=f"{prefix}.transformer",
config=config,
weights=weights,
is_gated=False,
num_layers=config.num_hidden_layers,
)
self.global_transformer = MllamaVisionEncoder(
prefix=f"{prefix}.global_transformer",
config=config,
weights=weights,
is_gated=True,
num_layers=config.num_global_layers,
)
def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
batch_size, _, hidden_size = hidden_state.shape
class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
return hidden_state
def forward(
self,
pixel_values: torch.Tensor,
aspect_ratio_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
batch_size, num_concurrent_media, num_tiles, num_channels, height, width = (
pixel_values.shape
)
pixel_values = pixel_values.reshape(
batch_size * num_concurrent_media * num_tiles, num_channels, height, width
)
aspect_ratio_ids = aspect_ratio_ids.reshape(
batch_size * num_concurrent_media, -1
)
# patch embedding
patch_embeds = self.patch_embedding(pixel_values)
hidden_state = patch_embeds.flatten(2).transpose(1, 2)
# tile embeddings
_, num_patches, dim = hidden_state.shape
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media, num_tiles, -1, dim
)
hidden_state = self.pre_tile_positional_embedding(
hidden_state, aspect_ratio_ids
)
# apply cls token
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media * num_tiles, num_patches, dim
)
hidden_state = self.apply_class_embedding(hidden_state)
num_patches += 1
# apply position embeddings
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media, num_tiles, num_patches, dim
)
hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)
# apply encoder
hidden_state = self.layernorm_pre(hidden_state)
# Compute the number of tokens to pad
num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
# Compute padding tuple for pad function
padding = (
0,
0,
0,
num_padding_patches,
) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
# Pad the tensor
hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
slice_index = -num_padding_patches if num_padding_patches > 0 else None
if attention_mask is not None:
attention_mask = attention_mask.reshape(
batch_size * num_concurrent_media, -1
)
attention_mask = _prepare_aspect_ratio_attention_mask(
aspect_ratio_mask=attention_mask,
num_patches=self.num_patches,
target_length=hidden_state.shape[2],
dtype=self.dtype,
)
hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim)
hidden_state, all_intermediate_hidden_states = self.transformer(
hidden_state,
attention_mask=attention_mask,
)
intermediate_hidden_states = [
hidden_state
for idx, hidden_state in enumerate(all_intermediate_hidden_states)
if idx in self.intermediate_layers_indices
]
intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1)
# apply global encoder
hidden_state = self.layernorm_post(hidden_state)
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media,
num_tiles,
num_patches + num_padding_patches,
dim,
)
hidden_state = self.post_tile_positional_embedding(
hidden_state, aspect_ratio_ids
)
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media,
num_tiles * (num_patches + num_padding_patches),
dim,
)
hidden_state, _ = self.global_transformer(
hidden_state, attention_mask=attention_mask
)
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media,
num_tiles,
num_patches + num_padding_patches,
dim,
)
hidden_state = hidden_state[:, :, :slice_index]
# adding intermediate layer outputs
hidden_state = hidden_state.reshape(
batch_size, num_concurrent_media, num_tiles, num_patches, dim
)
intermediate_hidden_states = intermediate_hidden_states.reshape(
batch_size * num_concurrent_media,
num_tiles,
num_patches + num_padding_patches,
-1,
)
intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
intermediate_hidden_states = intermediate_hidden_states.reshape(
batch_size, num_concurrent_media, num_tiles, num_patches, -1
)
hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
return hidden_state
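# Output shape note (illustrative): the returned tensor is
#   (batch, num_concurrent_media, num_tiles, num_patches,
#    hidden_size * (1 + len(intermediate_layers_indices)))
# i.e. the final hidden state concatenated with the selected intermediate layers,
# which is what multi_modal_projector in MllamaForConditionalGeneration consumes.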
class MllamaTextCrossAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, *, prefix, config, weights, layer_idx):
super().__init__()
self.config = config
self.num_heads = self.config.num_attention_heads
self.num_key_value_heads = self.config.num_key_value_heads
self.dropout = config.dropout
self.hidden_size = config.hidden_size
self.head_size = config.hidden_size // self.num_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.layer_idx = layer_idx
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
self.num_key_value_heads // weights.process_group.size()
)
self.q_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_proj",
weights=weights,
bias=False,
)
self.k_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.k_proj",
weights=weights,
bias=False,
)
self.v_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.v_proj",
weights=weights,
bias=False,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.q_norm = MllamaTextRMSNorm.load(
prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
)
self.k_norm = MllamaTextRMSNorm.load(
prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
)
self.softmax_scale = self.head_size**-0.5
def forward(
self,
hidden_states: torch.Tensor,
cross_attention_states: Optional[torch.Tensor] = None,
# past_key_value=None,
# attention_mask: Optional[torch.Tensor] = None,
# cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# hidden_states = hidden_states.unsqueeze(0)
# bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
query_states = query_states.view(-1, self.num_heads, self.head_size)
query_states = self.q_norm(query_states)
(
cross_attention_states,
cu_seqlen_q,
cu_seqlen_k,
max_q,
max_k,
indices,
) = cross_attention_states
key_states = self.k_proj(cross_attention_states)
value_states = self.v_proj(cross_attention_states)
key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)
key_states = self.k_norm(key_states)
# key_states = key_states.repeat(1, self.num_key_value_groups, 1)
# value_states = value_states.repeat(1, self.num_key_value_groups, 1)
causal = False
# logger.info(
# f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
# )
attn_output = flash_attn_2_cuda.varlen_fwd(
query_states,
key_states,
value_states,
None,
cu_seqlen_q,
cu_seqlen_k,
None,
None,
None, # block_tables
None,
max_q,
max_k,
0.0,
self.softmax_scale,
False,
causal, # Causal
-1, # window_size_left,
-1,
0.0, # softcap
False,
None,
)[0]
attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
return attn_output
# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText
class MllamaTextMLP(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
shape = x.shape
gate_up_states = self.gate_up_proj(x)
gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size)
result = self.down_proj(
self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1]
)
return result
class FlashLlamaCrossLayer(torch.nn.Module):
"""Cross-attention transformer block with tanh-gated attention and feedforward."""
def __init__(self, *, prefix, config, weights, index) -> None:
layer_idx = index
super().__init__()
self.cross_attn = MllamaTextCrossAttention(
prefix=f"{prefix}.cross_attn",
config=config,
weights=weights,
layer_idx=layer_idx,
)
self.input_layernorm = MllamaTextRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.cross_attn_attn_gate = torch.nn.Parameter(
weights.get_tensor(f"{prefix}.cross_attn_attn_gate"), requires_grad=False
)
self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.post_attention_layernorm = MllamaTextRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.cross_attn_mlp_gate = torch.nn.Parameter(
weights.get_tensor(f"{prefix}.cross_attn_mlp_gate"), requires_grad=False
)
self.layer_idx = layer_idx
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
adapter_data,
cross_attention_states, # [ IB, ...]
) -> Tuple[torch.Tensor, torch.Tensor]:
if cross_attention_states is None:
return hidden_states, residual
if residual is not None:
hidden_states += residual
indices = cross_attention_states[-1]
out_hidden_states = hidden_states[:]
if len(indices) > 0:
assert max(indices) < hidden_states.shape[0]
hidden_states = hidden_states[indices]
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.cross_attn(
hidden_states=hidden_states,
# attention_mask=cross_attention_mask,
cross_attention_states=cross_attention_states,
)
hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
out_hidden_states[indices] = hidden_states
hidden_states = out_hidden_states
return hidden_states, None
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText
class MllamaTextRMSNorm(nn.Module):
def __init__(self, weight, eps):
super().__init__()
self.weight = weight
self.variance_epsilon = eps
@classmethod
def load(cls, *, prefix, weights, eps):
weight = nn.Parameter(
weights.get_tensor(f"{prefix}.weight"), requires_grad=False
)
return cls(weight=weight, eps=eps)
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class MllamaForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = None
config.vision_config.speculator = config.speculator
config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator
config.text_config._attn_implementation = "sdpa"
self.hidden_size = config.text_config.hidden_size
self.vision_model = MllamaVisionModel(
prefix="vision_model", config=config.vision_config, weights=weights
)
self.multi_modal_projector = FastLinear.load(
prefix="multi_modal_projector", config=config, weights=weights, bias=True
)
self.text_model = FlashLlamaForCausalLM(
prefix="language_model", config=config.text_config, weights=weights
)
self.config = config
self.dtype = weights.dtype
self.device = weights.device
def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask):
if aspect_ratio_ids is None:
raise ValueError(
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
)
# logger.info(f"PIxel values {pixel_values.shape}")
batch_size = pixel_values.shape[0]
vision_states = self.vision_model(
pixel_values, aspect_ratio_ids, aspect_ratio_mask
)
cross_attention_states = self.multi_modal_projector(vision_states).reshape(
-1, vision_states.shape[-2], self.hidden_size
)
_, _, h = cross_attention_states.shape
cross_attention_states = cross_attention_states.view(batch_size, -1, h)
# logger.info(f"cross {cross_attention_states.shape}")
return cross_attention_states
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor],
adapter_data: Optional[torch.Tensor] = None,
# XXX: Putting these as optional so that the cuda warmup calls can go through.
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
):
if cross_attention_states is not None:
seqlen_q = len(image_indices)
n_images = cross_attention_states.shape[0]
seqlen_k = cross_attention_states.shape[1]
device = cross_attention_states.device
if cu_seqlen_prefill is not None:
offset = 0
cu_q = []
indices = []
for index in image_indices:
cu_q.append(offset)
length = seqlen.input_lengths[index].item()
assert index < seqlen.cu_seqlen_q.shape[0]
input_ids_offset = seqlen.cu_seqlen_q[index]
indices.extend(range(input_ids_offset, input_ids_offset + length))
offset += length
cu_q.append(offset)
cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32)
assert max(indices) < input_ids.shape[0]
cu_seqlen_k = (
torch.arange(
n_images + 1,
device=device,
dtype=torch.int32,
)
* seqlen_k
)
max_q = cu_seqlen_q[-1].item()
max_k = seqlen_k
else:
cu_seqlen_q = torch.arange(
seqlen_q + 1, device=device, dtype=torch.int32
)
seqlen_k = cross_attention_states.shape[1]
n_images = cross_attention_states.shape[0]
cu_seqlen_k = (
torch.arange(
n_images + 1,
device=device,
dtype=torch.int32,
)
* seqlen_k
)
max_q = seqlen_q
max_k = seqlen_k
indices = image_indices[:]
cross_attention_states = (
cross_attention_states,
cu_seqlen_q,
cu_seqlen_k,
max_q,
max_k,
indices,
)
outputs = self.text_model(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
lm_head_indices=lm_head_indices,
adapter_data=adapter_data,
cross_attention_states=cross_attention_states,
)
return outputs
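# A small worked example (illustration only) of the prefill branch above, assuming a
# batch of two requests with input lengths [5, 7] where only request 1 carries an
# image, i.e. image_indices == [1] and the vision tower produced seqlen_k tokens:
#
#   cu_seqlen_q == [0, 7]            # query tokens of request 1 only
#   cu_seqlen_k == [0, seqlen_k]     # one image -> one key/value segment
#   indices     == range(5, 12)      # rows of input_ids that attend to the image
#
# so the cross-attention layers only touch the rows listed in `indices` and leave the
# remaining hidden states untouched (see FlashLlamaCrossLayer.forward above).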