"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "0cdafbf7ecfe0f13335c2ea10fd92391e30c79f0"
Unverified Commit fb1c62e9 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[`Add Mamba`] Adds support for the `Mamba` models (#28094)



* initial-commit

* start cleaning

* small nits

* small nits

* current updates

* add kernels

* small refactoring little step

* add comments

* styling

* nit

* nits

* Style

* Small changes

* Push dummy mambda simple slow

* nit

* Use original names

* Use original names and remove norm

* Updates for inference params

* Style nd updates

* nits

* Match logits

* Add a test

* Add expected generated text

* nits doc, imports and styling

* style

* oups

* dont install kernels, invite users to install the required kernels

* let use use the original packages

* styling

* nits

* fix some copieds

* update doc

* fix-copies

* styling done

* nits

* fix import check

* run but wrong cuda ress

* mamba CUDA works :)

* fix the fast path

* config naming nits

* conversion script is not required at this stage

* finish fixing the fast path: generation make sense now!

* nit

* Let's start working on the CIs

* style

* better style

* more nits

* test nit

* quick fix for now

* nits

* nit

* nit

* nit

* nits

* update test rest

* fixup

* update test

* nit

* some fixes

* nits

* update test values

* fix styling

* nit

* support peft

* integrations tests require torchg

* also add slow markers

* styling

* chose forward wisely

* nits

* update tests

* fix gradient checkpointing

* fixup

* nit

* fix doc

* check copies

* fix the docstring

* fix some more tests

* style

* fix beam search

* add init schene

* update

* nit

* fix

* fixup the doc

* fix the doc

* fixup

* tentative update but slow is no longer good

* nit

* should we always use float32?

* nits

* revert wrong changes

* res in float32

* cleanup

* skip fmt for now

* update generation values

* update test values running original model

* fixup

* update tests + rename inference_params to cache_params + make sure training does not use cache_params

* small nits

* more nits

* fix final CIs

* style

* nit doc

* I hope final doc nits

* nit

* 🫠

* final touch!

* fix torch import

* Apply suggestions from code review
Co-authored-by: default avatarLysandre Debut <hi@lysand.re>

* Apply suggestions from code review

* fix fix and fix

* fix base model prefix!

* nit

* Update src/transformers/models/mamba/__init__.py

* Update docs/source/en/model_doc/mamba.md
Co-authored-by: default avatarLysandre Debut <hi@lysand.re>

* nit

---------
Co-authored-by: default avatarLysandre Debut <hi@lysand.re>
parent 87a0783d
......@@ -415,6 +415,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h
1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert.
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (from Google) released with the paper [MADLAD-400: A Multilingual And Document-Level Large Audited Dataset](https://arxiv.org/abs/2309.04662) by Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat.
1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (from Albert Gu and Tri Dao) released with the paper [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) by Albert Gu and Tri Dao.
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (from Microsoft Research Asia) released with the paper [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) by Junlong Li, Yiheng Xu, Lei Cui, Furu Wei.
1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (from FAIR and UIUC) released with the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) by Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar.
......
......@@ -388,6 +388,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt
1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert.
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (from Google) released with the paper [MADLAD-400: A Multilingual And Document-Level Large Audited Dataset](https://arxiv.org/abs/2309.04662) by Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat.
1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (from Albert Gu and Tri Dao) released with the paper [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) by Albert Gu and Tri Dao.
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (from Microsoft Research Asia) released with the paper [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) by Junlong Li, Yiheng Xu, Lei Cui, Furu Wei.
1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (from FAIR and UIUC) released with the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) by Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar.
......
......@@ -409,6 +409,7 @@ Nombre actuel de points de contrôle : ![](https://img.shields.io/endpoint?url=h
1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (de Facebook) a été publié dans l'article [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) de Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve et Ronan Collobert.
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (de Facebook) a été publié dans l'article [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) de Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (de Google) a été publié dans l'article [MADLAD-400 : Un ensemble de données multilingue et de niveau document](https://arxiv.org/abs/2309.04662) de Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat.
1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (de Albert Gu and Tri Dao) publié dans l'article [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) parAlbert Gu and Tri Dao.
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Des modèles de traduction automatique formés avec les données [OPUS](http://opus.nlpl.eu/) par Jörg Tiedemann. Le [cadre Marian](https://marian-nmt.github.io/) est en cours de développement par l'équipe Microsoft Translator.
1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (de Microsoft Research Asia) a été publié dans l'article [MarkupLM : Pré-entraînement de texte et de langage de balisage pour la compréhension visuellement riche de documents](https://arxiv.org/abs/2110.08518) de Junlong Li, Yiheng Xu, Lei Cui, Furu Wei.
1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (de FAIR et UIUC) a été publié dans l'article [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) de Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar.
......
......@@ -362,6 +362,7 @@ conda install conda-forge::transformers
1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert.
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (फेसबुक से) साथ देने वाला पेपर [बियॉन्ड इंग्लिश-सेंट्रिक मल्टीलिंगुअल मशीन ट्रांसलेशन](https://arxiv.org/एब्स/2010.11125) एंजेला फैन, श्रुति भोसले, होल्गर श्वेन्क, झी मा, अहमद अल-किश्की, सिद्धार्थ गोयल, मनदीप बैनेस, ओनूर सेलेबी, गुइल्लाम वेन्जेक, विश्रव चौधरी, नमन गोयल, टॉम बर्च, विटाली लिपचिंस्की, सर्गेई एडुनोव, एडौर्ड द्वारा ग्रेव, माइकल औली, आर्मंड जौलिन द्वारा पोस्ट किया गया।
1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (from Google) released with the paper [MADLAD-400: A Multilingual And Document-Level Large Audited Dataset](https://arxiv.org/abs/2309.04662) by Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat.
1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (Albert Gu and Tri Dao से) Albert Gu and Tri Dao. द्वाराअनुसंधान पत्र [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) के साथ जारी किया गया
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Jörg द्वारा [OPUS](http://opus.nlpl.eu/) डेटा से प्रशिक्षित मशीनी अनुवाद मॉडल पोस्ट किया गया टाइडेमैन द्वारा। [मैरियन फ्रेमवर्क](https://marian-nmt.github.io/) माइक्रोसॉफ्ट ट्रांसलेटर टीम द्वारा विकसित।
1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (माइक्रोसॉफ्ट रिसर्च एशिया से) साथ में पेपर [मार्कअपएलएम: विजुअली-रिच डॉक्यूमेंट अंडरस्टैंडिंग के लिए टेक्स्ट और मार्कअप लैंग्वेज का प्री-ट्रेनिंग](https://arxiv.org/abs/2110.08518) जुनलॉन्ग ली, यिहेंग जू, लेई कुई, फुरु द्वारा वी द्वारा पोस्ट किया गया।
1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (FAIR and UIUC से) Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar. द्वाराअनुसंधान पत्र [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) के साथ जारी किया गया
......
......@@ -422,6 +422,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ
1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (Facebook から) Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert から公開された研究論文: [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161)
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (Facebook から) Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin から公開された研究論文: [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125)
1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (from Google) released with the paper [MADLAD-400: A Multilingual And Document-Level Large Audited Dataset](https://arxiv.org/abs/2309.04662) by Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat.
1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (Albert Gu and Tri Dao から) Albert Gu and Tri Dao. から公開された研究論文 [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752)
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Jörg Tiedemann から. [OPUS](http://opus.nlpl.eu/) を使いながら学習された "Machine translation" (マシントランスレーション) モデル. [Marian Framework](https://marian-nmt.github.io/) はMicrosoft Translator Team が現在開発中です.
1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (Microsoft Research Asia から) Junlong Li, Yiheng Xu, Lei Cui, Furu Wei から公開された研究論文: [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518)
1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (FAIR and UIUC から) Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar. から公開された研究論文 [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527)
......
......@@ -337,6 +337,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (Facebook 에서) Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert 의 [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) 논문과 함께 발표했습니다.
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (Facebook 에서) Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin 의 [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) 논문과 함께 발표했습니다.
1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (from Google) released with the paper [MADLAD-400: A Multilingual And Document-Level Large Audited Dataset](https://arxiv.org/abs/2309.04662) by Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat.
1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (Albert Gu and Tri Dao 에서 제공)은 Albert Gu and Tri Dao.의 [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752)논문과 함께 발표했습니다.
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (Microsoft Research Asia 에서) Junlong Li, Yiheng Xu, Lei Cui, Furu Wei 의 [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) 논문과 함께 발표했습니다.
1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (FAIR and UIUC 에서 제공)은 Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar.의 [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527)논문과 함께 발표했습니다.
......
......@@ -361,6 +361,7 @@ conda install conda-forge::transformers
1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (来自 Facebook) 伴随论文 [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) 由 Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert 发布。
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (来自 Facebook) 伴随论文 [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) 由 Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin 发布。
1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (from Google) released with the paper [MADLAD-400: A Multilingual And Document-Level Large Audited Dataset](https://arxiv.org/abs/2309.04662) by Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat.
1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (来自 Albert Gu and Tri Dao) 伴随论文 [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) 由 Albert Gu and Tri Dao 发布。
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)**[OPUS](http://opus.nlpl.eu/) 数据训练的机器翻译模型由 Jörg Tiedemann 发布。[Marian Framework](https://marian-nmt.github.io/) 由微软翻译团队开发。
1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (来自 Microsoft Research Asia) 伴随论文 [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) 由 Junlong Li, Yiheng Xu, Lei Cui, Furu Wei 发布。
1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (来自 FAIR and UIUC) 伴随论文 [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) 由 Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar 发布。
......
......@@ -373,6 +373,7 @@ conda install conda-forge::transformers
1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert.
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (from Google) released with the paper [MADLAD-400: A Multilingual And Document-Level Large Audited Dataset](https://arxiv.org/abs/2309.04662) by Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat.
1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (from Albert Gu and Tri Dao) released with the paper [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) by Albert Gu and Tri Dao.
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (from Microsoft Research Asia) released with the paper [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) by Junlong Li, Yiheng Xu, Lei Cui, Furu Wei.
1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (from FAIR and UIUC) released with the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) by Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar.
......
......@@ -398,6 +398,8 @@
title: M2M100
- local: model_doc/madlad-400
title: MADLAD-400
- local: model_doc/mamba
title: Mamba
- local: model_doc/marian
title: MarianMT
- local: model_doc/markuplm
......
......@@ -180,6 +180,7 @@ Flax), PyTorch, and/or TensorFlow.
| [M-CTC-T](model_doc/mctct) | ✅ | ❌ | ❌ |
| [M2M100](model_doc/m2m_100) | ✅ | ❌ | ❌ |
| [MADLAD-400](model_doc/madlad-400) | ✅ | ✅ | ✅ |
| [Mamba](model_doc/mamba) | ✅ | ❌ | ❌ |
| [Marian](model_doc/marian) | ✅ | ✅ | ✅ |
| [MarkupLM](model_doc/markuplm) | ✅ | ❌ | ❌ |
| [Mask2Former](model_doc/mask2former) | ✅ | ❌ | ❌ |
......
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Mamba
## Overview
The Mamba model was proposed in [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) by Albert Gu and Tri Dao.
This model is a new paradigm architecture based on `state-space-models`. You can read more about the intuition behind these [here](https://srush.github.io/annotated-s4/).
The abstract from the paper is the following:
*Foundation models, now powering most of the exciting applications in deep learning, are almost universally based on the Transformer architecture and its core attention module. Many subquadratic-time architectures such as linear attention, gated convolution and recurrent models, and structured state space models (SSMs) have been developed to address Transformers' computational inefficiency on long sequences, but they have not performed as well as attention on important modalities such as language. We identify that a key weakness of such models is their inability to perform content-based reasoning, and make several improvements. First, simply letting the SSM parameters be functions of the input addresses their weakness with discrete modalities, allowing the model to selectively propagate or forget information along the sequence length dimension depending on the current token. Second, even though this change prevents the use of efficient convolutions, we design a hardware-aware parallel algorithm in recurrent mode. We integrate these selective SSMs into a simplified end-to-end neural network architecture without attention or even MLP blocks (Mamba). Mamba enjoys fast inference (5× higher throughput than Transformers) and linear scaling in sequence length, and its performance improves on real data up to million-length sequences. As a general sequence model backbone, Mamba achieves state-of-the-art performance across several modalities such as language, audio, and genomics. On language modeling, our Mamba-3B model outperforms Transformers of the same size and matches Transformers twice its size, both in pretraining and downstream evaluation.*
Tips:
- Mamba is a new `state space model` architecture that rivals the classic Transformers. It is based on the line of progress on structured state space models, with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
- Mamba stacks `mixer` layers, which are the equivalent of `Attention` layers. The core logic of `mamba` is held in the `MambaMixer` class.
- Two implementations cohabit: one is optimized and uses fast cuda kernels, while the other one is naive but can run on any device!
- The current implementation leverages the original cuda kernels: the equivalent of flash attention for Mamba are hosted in the [`mamba-ssm`](https://github.com/state-spaces/mamba) and the [`causal_conv1d`](https://github.com/Dao-AILab/causal-conv1d) repositories. Make sure to install them if your hardware supports them!
- Contributions to make the naive path faster are welcome 🤗
This model was contributed by [ArthurZ](https://huggingface.co/ArthurZ).
The original code can be found [here](https://github.com/state-spaces/mamba).
# Usage
### A simple generation example:
```python
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
import torch
tokenizer = AutoTokenizer.from_pretrained("ArthurZ/mamba-130m")
tokenizer.pad_token = tokenizer.eos_token
model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-130m", vocab_size=50280, num_hidden_layers=24, torch_dtype=torch.float32)
model.config.use_cache = True
input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"]
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
```
### Peft finetuning
The slow version is not very stable for training, and the fast one needs `float32`!
```python
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
model_id = "ArthurZ/mamba-2.8b"
tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token ="<s>")
model = AutoModelForCausalLM.from_pretrained(model_id)
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=3,
per_device_train_batch_size=4,
logging_dir='./logs',
logging_steps=10,
learning_rate=2e-3
)
lora_config = LoraConfig(
r=8,
target_modules="all-linear",
task_type="CAUSAL_LM",
bias="none"
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
peft_config=lora_config,
train_dataset=dataset,
dataset_text_field="quote",
)
trainer.train()
```
## MambaConfig
[[autodoc]] MambaConfig
## MambaModel
[[autodoc]] MambaModel
- forward
## MambaLMHeadModel
[[autodoc]] MambaForCausalLM
- forward
......@@ -37,7 +37,7 @@ You can finetune other architectures for causal language modeling following the
Choose one of the following architectures:
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [Gemma](../model_doc/gemma), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [StableLm](../model_doc/stablelm), [Starcoder2](../model_doc/starcoder2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [Whisper](../model_doc/whisper), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)
[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [Gemma](../model_doc/gemma), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Mamba](../model_doc/mamba), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [StableLm](../model_doc/stablelm), [Starcoder2](../model_doc/starcoder2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [Whisper](../model_doc/whisper), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)
......
......@@ -571,6 +571,7 @@ _import_structure = {
"LxmertTokenizer",
],
"models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
"models.mamba": ["MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP", "MambaConfig"],
"models.marian": ["MarianConfig"],
"models.markuplm": [
"MARKUPLM_PRETRAINED_CONFIG_ARCHIVE_MAP",
......@@ -2578,6 +2579,14 @@ else:
"M2M100PreTrainedModel",
]
)
_import_structure["models.mamba"].extend(
[
"MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST",
"MambaForCausalLM",
"MambaModel",
"MambaPreTrainedModel",
]
)
_import_structure["models.marian"].extend(["MarianForCausalLM", "MarianModel", "MarianMTModel"])
_import_structure["models.markuplm"].extend(
[
......@@ -5370,6 +5379,7 @@ if TYPE_CHECKING:
LxmertTokenizer,
)
from .models.m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config
from .models.mamba import MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP, MambaConfig
from .models.marian import MarianConfig
from .models.markuplm import (
MARKUPLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
......@@ -7160,6 +7170,12 @@ if TYPE_CHECKING:
M2M100Model,
M2M100PreTrainedModel,
)
from .models.mamba import (
MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST,
MambaForCausalLM,
MambaModel,
MambaPreTrainedModel,
)
from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel
from .models.markuplm import (
MARKUPLM_PRETRAINED_MODEL_ARCHIVE_LIST,
......
......@@ -3183,7 +3183,7 @@ class GenerationMixin:
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)
if model_kwargs["past_key_values"] is not None:
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
model_kwargs["past_key_values"], beam_idx
)
......@@ -3537,7 +3537,7 @@ class GenerationMixin:
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)
if model_kwargs["past_key_values"] is not None:
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
model_kwargs["past_key_values"], beam_idx
)
......@@ -3943,7 +3943,7 @@ class GenerationMixin:
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)
if model_kwargs["past_key_values"] is not None:
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
model_kwargs["past_key_values"], reordering_indices
)
......@@ -4302,7 +4302,7 @@ class GenerationMixin:
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)
if model_kwargs["past_key_values"] is not None:
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
model_kwargs["past_key_values"], beam_idx
)
......
......@@ -128,6 +128,7 @@ from . import (
luke,
lxmert,
m2m_100,
mamba,
marian,
markuplm,
mask2former,
......
......@@ -137,6 +137,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("luke", "LukeConfig"),
("lxmert", "LxmertConfig"),
("m2m_100", "M2M100Config"),
("mamba", "MambaConfig"),
("marian", "MarianConfig"),
("markuplm", "MarkupLMConfig"),
("mask2former", "Mask2FormerConfig"),
......@@ -373,6 +374,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mamba", "MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("markuplm", "MARKUPLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mask2former", "MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("maskformer", "MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
......@@ -614,6 +616,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("lxmert", "LXMERT"),
("m2m_100", "M2M100"),
("madlad-400", "MADLAD-400"),
("mamba", "Mamba"),
("marian", "Marian"),
("markuplm", "MarkupLM"),
("mask2former", "Mask2Former"),
......
......@@ -134,6 +134,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("luke", "LukeModel"),
("lxmert", "LxmertModel"),
("m2m_100", "M2M100Model"),
("mamba", "MambaModel"),
("marian", "MarianModel"),
("markuplm", "MarkupLMModel"),
("mask2former", "Mask2FormerModel"),
......@@ -286,6 +287,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("longformer", "LongformerForMaskedLM"),
("luke", "LukeForMaskedLM"),
("lxmert", "LxmertForPreTraining"),
("mamba", "MambaForCausalLM"),
("mega", "MegaForMaskedLM"),
("megatron-bert", "MegatronBertForPreTraining"),
("mobilebert", "MobileBertForPreTraining"),
......@@ -367,6 +369,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
("longt5", "LongT5ForConditionalGeneration"),
("luke", "LukeForMaskedLM"),
("m2m_100", "M2M100ForConditionalGeneration"),
("mamba", "MambaForCausalLM"),
("marian", "MarianMTModel"),
("mega", "MegaForMaskedLM"),
("megatron-bert", "MegatronBertForCausalLM"),
......@@ -439,6 +442,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
("gptj", "GPTJForCausalLM"),
("llama", "LlamaForCausalLM"),
("mamba", "MambaForCausalLM"),
("marian", "MarianForCausalLM"),
("mbart", "MBartForCausalLM"),
("mega", "MegaForCausalLM"),
......
......@@ -233,6 +233,7 @@ else:
("luke", ("LukeTokenizer", None)),
("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
("mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
(
"mbart",
......
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
_import_structure = {
"configuration_mamba": ["MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP", "MambaConfig", "MambaOnnxConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_mamba"] = [
"MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST",
"MambaForCausalLM",
"MambaModel",
"MambaPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_mamba import MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP, MambaConfig, MambaOnnxConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_mamba import (
MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST,
MambaForCausalLM,
MambaModel,
MambaPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" MAMBA configuration"""
import math
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"state-spaces/mamba-2.8b": "https://huggingface.co/state-spaces/mamba-2.8b/resolve/main/config.json",
}
class MambaConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of a [`MambaModel`]. It is used to instantiate a MAMBA
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the MAMBA
[state-spaces/mamba-2.8b](https://huggingface.co/state-spaces/mamba-2.8b) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 50280):
Vocabulary size of the MAMBA model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`MambaModel`].
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the embeddings and hidden states.
state_size (`int`, *optional*, defaults to 16): shape of the state space latents.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the model.
layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
The epsilon to use in the layer normalization layers.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 0):
The id of the beginning of sentence token in the vocabulary.
eos_token_id (`int`, *optional*, defaults to 0):
The id of the end of sentence token in the vocabulary.
expand (`int`, *optional*, defaults to 2): Expanding factor used to determin the intermediate size.
conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel.
use_bias (`bool`, *optional*, defaults to `False`):
Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block
use_conv_bias (`bool`, *optional*, defaults to `True`):
Whether or not to use bias in the convolution layer of the mixer block.
hidden_act (`str`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
initializer_range (`float`, *optional*, defaults to 0.1):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
residual_in_fp32 (`bool`, *optional*, defaults to `True`):
Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model
time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
Rank of the the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
time_step_scale (`float`, *optional*, defaults to 1.0):
Scale used used to scale `dt_proj.bias`.
time_step_min (`float`, *optional*, defaults to 0.001):
Minimum `time_step` used to bound `dt_proj.bias`.
time_step_max (`float`, *optional*, defaults to 0.1):
Maximum `time_step` used to bound `dt_proj.bias`.
time_step_init_scheme (`float`, *optional*, defaults to `"random"`):
Init scheme used for `dt_proj.weight`. Should be one of `["random","uniform"]`
time_step_floor (`float`, *optional*, defaults to 0.0001):
Minimum clamping value of the `dt_proj.bias` layer initialization.
rescale_prenorm_residual (`bool`, *optional*, defaults to `False`):
Whether or not to rescale `out_proj` weights when initializing.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the cache should be used.
Example:
```python
>>> from transformers import MambaConfig, MambaModel
>>> # Initializing a Mamba configuration
>>> configuration = MambaConfig()
>>> # Initializing a model (with random weights) from the configuration
>>> model = MambaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "mamba"
def __init__(
self,
vocab_size=50280,
hidden_size=768,
state_size=16,
num_hidden_layers=32,
layer_norm_epsilon=1e-5,
pad_token_id=0,
bos_token_id=0,
eos_token_id=0,
expand=2,
conv_kernel=4,
use_bias=False,
use_conv_bias=True,
hidden_act="silu",
initializer_range=0.1,
residual_in_fp32=True,
time_step_rank="auto",
time_step_scale=1.0,
time_step_min=0.001,
time_step_max=0.1,
time_step_init_scheme="random",
time_step_floor=1e-4,
rescale_prenorm_residual=False,
use_cache=True,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.state_size = state_size
self.num_hidden_layers = num_hidden_layers
self.layer_norm_epsilon = layer_norm_epsilon
self.conv_kernel = conv_kernel
self.expand = expand
self.intermediate_size = int(expand * self.hidden_size)
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.use_bias = use_bias
self.use_conv_bias = use_conv_bias
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
self.time_step_scale = time_step_scale
self.time_step_min = time_step_min
self.time_step_max = time_step_max
self.time_step_init_scheme = time_step_init_scheme
self.time_step_floor = time_step_floor
self.rescale_prenorm_residual = rescale_prenorm_residual
self.residual_in_fp32 = residual_in_fp32
self.use_cache = use_cache
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment